Merge pull request #841 from modelscope/qwen-image-lora-hotload

support qwen-image lora hotload
This commit is contained in:
Zhongjie Duan
2025-08-21 15:14:46 +08:00
committed by GitHub

View File

@@ -77,10 +77,63 @@ class QwenImagePipeline(BasePipeline):
self.model_fn = model_fn_qwen_image
def load_lora(self, module, path, alpha=1):
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def load_lora(
self,
module: torch.nn.Module,
lora_config: Union[ModelConfig, str] = None,
alpha=1,
hotload=False,
state_dict=None,
):
if state_dict is None:
if isinstance(lora_config, str):
lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
else:
lora_config.download_if_necessary()
lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
else:
lora = state_dict
if hotload:
for name, module in module.named_modules():
if isinstance(module, AutoWrappedLinear):
lora_a_name = f'{name}.lora_A.default.weight'
lora_b_name = f'{name}.lora_B.default.weight'
if lora_a_name in lora and lora_b_name in lora:
module.lora_A_weights.append(lora[lora_a_name] * alpha)
module.lora_B_weights.append(lora[lora_b_name])
else:
loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
loader.load(module, lora, alpha=alpha)
def clear_lora(self):
for name, module in self.named_modules():
if isinstance(module, AutoWrappedLinear):
if hasattr(module, "lora_A_weights"):
module.lora_A_weights.clear()
if hasattr(module, "lora_B_weights"):
module.lora_B_weights.clear()
def enable_lora_magic(self):
if self.dit is not None:
if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
dtype = next(iter(self.dit.parameters())).dtype
enable_vram_management(
self.dit,
module_map = {
torch.nn.Linear: AutoWrappedLinear,
},
module_config = dict(
offload_dtype=dtype,
offload_device=self.device,
onload_dtype=dtype,
onload_device=self.device,
computation_dtype=self.torch_dtype,
computation_device=self.device,
),
vram_limit=None,
)
def training_loss(self, **inputs):