From f85af085df499eaab27cf9ec7122614f3beee254 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 18 Nov 2025 22:43:51 +0800 Subject: [PATCH] fix wan vram bug --- diffsynth/core/vram/layers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py index 74b7c4d..cc2b4d9 100644 --- a/diffsynth/core/vram/layers.py +++ b/diffsynth/core/vram/layers.py @@ -170,6 +170,9 @@ class AutoWrappedModule(AutoTorchModule): elif self.preparing_device != "disk": self.to(dtype=self.preparing_dtype, device=self.preparing_device) self.state = 2 + + def cast_to(self, module, dtype, device): + return copy.deepcopy(module).to(dtype=dtype, device=device) def computation(self): # onload / preparing -> computation (temporary) @@ -182,7 +185,7 @@ class AutoWrappedModule(AutoTorchModule): elif self.disk_offload and device == "disk": module = self.load_from_disk(self.computation_dtype, self.computation_device, copy_module=True) else: - module = copy.deepcopy(self.module).to(dtype=self.computation_dtype, device=self.computation_device) + module = self.cast_to(self.module, dtype=self.computation_dtype, device=self.computation_device) return module def forward(self, *args, **kwargs): @@ -251,6 +254,11 @@ class AutoWrappedNonRecurseModule(AutoWrappedModule): for name in self.required_params: getattr(self, name).to("meta") + def cast_to(self, module, dtype, device): + for name in self.required_params: + getattr(module, name).to(dtype=dtype, device=device) + return module + def __getattr__(self, name): if name in self.__dict__ or name == "module": return super().__getattr__(name)