Merge pull request #841 from modelscope/qwen-image-lora-hotload
support qwen-image lora hotload
@@ -77,10 +77,63 @@ class QwenImagePipeline(BasePipeline):
         self.model_fn = model_fn_qwen_image
 
 
-    def load_lora(self, module, path, alpha=1):
-        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
-        loader.load(module, lora, alpha=alpha)
+    def load_lora(
+        self,
+        module: torch.nn.Module,
+        lora_config: Union[ModelConfig, str] = None,
+        alpha=1,
+        hotload=False,
+        state_dict=None,
+    ):
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            lora = state_dict
+        if hotload:
+            for name, module in module.named_modules():
+                if isinstance(module, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        module.lora_B_weights.append(lora[lora_b_name])
+        else:
+            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+            loader.load(module, lora, alpha=alpha)
+
+
+    def clear_lora(self):
+        for name, module in self.named_modules():
+            if isinstance(module, AutoWrappedLinear):
+                if hasattr(module, "lora_A_weights"):
+                    module.lora_A_weights.clear()
+                if hasattr(module, "lora_B_weights"):
+                    module.lora_B_weights.clear()
+
+
+    def enable_lora_magic(self):
+        if self.dit is not None:
+            if not (hasattr(self.dit, "vram_management_enabled") and self.dit.vram_management_enabled):
+                dtype = next(iter(self.dit.parameters())).dtype
+                enable_vram_management(
+                    self.dit,
+                    module_map = {
+                        torch.nn.Linear: AutoWrappedLinear,
+                    },
+                    module_config = dict(
+                        offload_dtype=dtype,
+                        offload_device=self.device,
+                        onload_dtype=dtype,
+                        onload_device=self.device,
+                        computation_dtype=self.torch_dtype,
+                        computation_device=self.device,
+                    ),
+                    vram_limit=None,
+                )
 
 
     def training_loss(self, **inputs):
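Note on the mechanism: the hunk only shows the lora_A_weights / lora_B_weights lists being filled (in load_lora with hotload=True) and emptied (in clear_lora); how AutoWrappedLinear consumes them at forward time is not part of this diff. Below is a minimal sketch of the presumed mechanism, assuming the standard LoRA formulation y = Wx + B(Ax) with alpha pre-folded into A, as load_lora does when it appends lora[lora_a_name] * alpha. HotloadLinear is a hypothetical stand-in, not the real AutoWrappedLinear.

import torch

class HotloadLinear(torch.nn.Linear):
    # Hypothetical stand-in for AutoWrappedLinear, illustrating how the
    # hot-loaded LoRA lists could be applied without touching the base weight.
    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.lora_A_weights = []  # each entry: (rank, in_features), already scaled by alpha
        self.lora_B_weights = []  # each entry: (out_features, rank)

    def forward(self, x):
        out = super().forward(x)  # frozen base projection Wx
        for A, B in zip(self.lora_A_weights, self.lora_B_weights):
            # additive low-rank update B(Ax); alpha was folded into A by load_lora
            out = out + torch.nn.functional.linear(torch.nn.functional.linear(x, A), B)
        return out

Because the update stays additive and the lists are plain Python lists, switching LoRAs reduces to clear_lora() plus a new load, with no re-fusing of the base weights as in the GeneralLoRALoader path.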
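Under the same assumptions, a hot-load round trip on the pipeline might look like the sketch below. enable_lora_magic() must run first so the torch.nn.Linear layers become AutoWrappedLinear (otherwise the isinstance check in the hotload branch matches nothing); pipe, the LoRA filenames, the prompt, and the pipeline calling convention are placeholders.

# Hypothetical usage of the new hotload path.
pipe.enable_lora_magic()  # one-time: wraps torch.nn.Linear layers as AutoWrappedLinear
pipe.load_lora(pipe.dit, "style_a.safetensors", alpha=1, hotload=True)
image_a = pipe(prompt="a cat in watercolor style")

pipe.clear_lora()  # drop hot-loaded weights; the base model stays untouched
pipe.load_lora(pipe.dit, "style_b.safetensors", alpha=1, hotload=True)
image_b = pipe(prompt="a cat in watercolor style")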