fix wan vram bug

This commit is contained in:
Artiprocher
2025-11-18 22:43:51 +08:00
parent 416ff5df74
commit f85af085df

View File

@@ -170,6 +170,9 @@ class AutoWrappedModule(AutoTorchModule):
elif self.preparing_device != "disk": elif self.preparing_device != "disk":
self.to(dtype=self.preparing_dtype, device=self.preparing_device) self.to(dtype=self.preparing_dtype, device=self.preparing_device)
self.state = 2 self.state = 2
def cast_to(self, module, dtype, device):
return copy.deepcopy(module).to(dtype=dtype, device=device)
def computation(self): def computation(self):
# onload / preparing -> computation (temporary) # onload / preparing -> computation (temporary)
@@ -182,7 +185,7 @@ class AutoWrappedModule(AutoTorchModule):
elif self.disk_offload and device == "disk": elif self.disk_offload and device == "disk":
module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True) module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True)
else: else:
module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device)
return module return module
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
@@ -251,6 +254,11 @@ class AutoWrappedNonRecurseModule(AutoWrappedModule):
for name in self.required_params: for name in self.required_params:
getattr(self, name).to("meta") getattr(self, name).to("meta")
def cast_to(self, module, dtype, device):
for name in self.required_params:
getattr(module, name).to(dtype=dtype, device=device)
return module
def __getattr__(self, name): def __getattr__(self, name):
if name in self.__dict__ or name == "module": if name in self.__dict__ or name == "module":
return super().__getattr__(name) return super().__getattr__(name)