mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
fix wan vram bug
This commit is contained in:
@@ -170,6 +170,9 @@ class AutoWrappedModule(AutoTorchModule):
|
|||||||
elif self.preparing_device != "disk":
|
elif self.preparing_device != "disk":
|
||||||
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
self.to(dtype=self.preparing_dtype, device=self.preparing_device)
|
||||||
self.state = 2
|
self.state = 2
|
||||||
|
|
||||||
|
def cast_to(self, module, dtype, device):
|
||||||
|
return copy.deepcopy(module).to(dtype=dtype, device=device)
|
||||||
|
|
||||||
def computation(self):
|
def computation(self):
|
||||||
# onload / preparing -> computation (temporary)
|
# onload / preparing -> computation (temporary)
|
||||||
@@ -182,7 +185,7 @@ class AutoWrappedModule(AutoTorchModule):
|
|||||||
elif self.disk_offload and device == "disk":
|
elif self.disk_offload and device == "disk":
|
||||||
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
|
||||||
else:
|
else:
|
||||||
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device)
|
module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
|
||||||
return module
|
return module
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@@ -251,6 +254,11 @@ class AutoWrappedNonRecurseModule(AutoWrappedModule):
|
|||||||
for name in self.required_params:
|
for name in self.required_params:
|
||||||
getattr(self, name).to("meta")
|
getattr(self, name).to("meta")
|
||||||
|
|
||||||
|
def cast_to(self, module, dtype, device):
|
||||||
|
for name in self.required_params:
|
||||||
|
getattr(module, name).to(dtype=dtype, device=device)
|
||||||
|
return module
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
if name in self.__dict__ or name == "module":
|
if name in self.__dict__ or name == "module":
|
||||||
return super().__getattr__(name)
|
return super().__getattr__(name)
|
||||||
|
|||||||
Reference in New Issue
Block a user