mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
fix wan vram bug
This commit is contained in:
@@ -170,6 +170,9 @@ class AutoWrappedModule(AutoTorchModule):
|
||||
elif self.preparing_device != "disk":
|
||||
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||
self.state = 2
|
||||
|
||||
def cast_to(self, module, dtype, device):
|
||||
return copy.deepcopy(module).to(dtype=dtype, device=device)
|
||||
|
||||
def computation(self):
|
||||
# onload / preparing -> computation (temporary)
|
||||
@@ -182,7 +185,7 @@ class AutoWrappedModule(AutoTorchModule):
|
||||
elif self.disk_offload and device == "disk":
|
||||
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
||||
else:
|
||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
||||
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
|
||||
return module
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
@@ -251,6 +254,11 @@ class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
||||
for name in self.required_params:
|
||||
getattr(self, name).to("meta")
|
||||
|
||||
def cast_to(self, module, dtype, device):
|
||||
for name in self.required_params:
|
||||
getattr(module, name).to(dtype=dtype, device=device)
|
||||
return module
|
||||
|
||||
def __getattr__(self, name):
|
||||
if name in self.__dict__ or name == "module":
|
||||
return super().__getattr__(name)
|
||||
|
||||
Reference in New Issue
Block a user