diff --git a/diffsynth/pipelines/qwen_image.py b/diffsynth/pipelines/qwen_image.py
index 94d98cc..09d3402 100644
--- a/diffsynth/pipelines/qwen_image.py
+++ b/diffsynth/pipelines/qwen_image.py
@@ -76,10 +76,63 @@ class QwenImagePipeline(BasePipeline):
         self.model_fn = model_fn_qwen_image
 
 
-    def load_lora(self, module, path, alpha=1):
-        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
-        loader.load(module, lora, alpha=alpha)
+    def load_lora(
+        self,
+        module: torch.nn.Module,
+        lora_config: Union[ModelConfig, str] = None,
+        alpha=1,
+        hotload=False,
+        state_dict=None,
+    ):
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            lora = state_dict
+        if hotload:
+            for name, module in module.named_modules():
+                if isinstance(module, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        module.lora_B_weights.append(lora[lora_b_name])
+        else:
+            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+            loader.load(module, lora, alpha=alpha)
+
+
+    def clear_lora(self):
+        for name, module in self.named_modules():
+            if isinstance(module, AutoWrappedLinear):
+                if hasattr(module, "lora_A_weights"):
+                    module.lora_A_weights.clear()
+                if hasattr(module, "lora_B_weights"):
+                    module.lora_B_weights.clear()
+
+
+    def enable_lora_magic(self):
+        if self.dit is not None:
+            if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
+                dtype = next(iter(self.dit.parameters())).dtype
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device=self.device,
+                        onload_dtype=dtype,
+                        onload_device=self.device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=None,
+                )
 
 
     def training_loss(self, **inputs):
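
For context, a minimal sketch of how the new hotload path is intended to fit together. Only `load_lora`, `clear_lora`, and `enable_lora_magic` come from this patch; the pipeline construction, call signature, and file paths below are placeholder assumptions, not part of the diff:

```python
# Hypothetical usage sketch; paths and prompts are placeholders.
from diffsynth.pipelines.qwen_image import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained(...)  # construction args elided

# Wrap each torch.nn.Linear in the DiT as AutoWrappedLinear so per-layer
# LoRA A/B matrices can be attached without merging into base weights.
pipe.enable_lora_magic()

# Hot-load: matching "{name}.lora_A.default.weight" / lora_B pairs are
# appended to each wrapped layer's weight lists; alpha is pre-multiplied
# into the A matrix at load time.
pipe.load_lora(pipe.dit, "style_a_lora.safetensors", alpha=1, hotload=True)
image_a = pipe(prompt="...")  # call signature elided

# Detach all hot-loaded LoRAs, then swap in another without reloading or
# restoring base weights (hotload never modifies them).
pipe.clear_lora()
pipe.load_lora(pipe.dit, "style_b_lora.safetensors", alpha=1, hotload=True)
image_b = pipe(prompt="...")
```

Two consequences of this design seem worth noting for review: the hotload path trades the one-time cost of merging LoRA weights (the `GeneralLoRALoader` path, still taken when `hotload=False`) for an extra per-forward low-rank computation, presumably applied inside `AutoWrappedLinear.forward`; and because `alpha` is folded into the stored A matrix at load time, changing a LoRA's strength requires a `clear_lora` followed by a reload rather than adjusting it in place.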