bug fix
@@ -220,7 +220,7 @@ class BasePipeline(torch.nn.Module):
         module: torch.nn.Module,
         lora_config: Union[ModelConfig, str] = None,
         alpha=1,
-        hotload=False,
+        hotload=None,
         state_dict=None,
     ):
         if state_dict is None:
@@ -233,12 +233,15 @@ class BasePipeline(torch.nn.Module):
             lora = state_dict
         lora_loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
         lora = lora_loader.convert_state_dict(lora)
+        if hotload is None:
+            hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
         if hotload:
             if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
                 raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
             updated_num = 0
-            for name, module in module.named_modules():
+            for _, module in module.named_modules():
                 if isinstance(module, AutoWrappedLinear):
+                    name = module.name
                     lora_a_name = f'{name}.lora_A.weight'
                     lora_b_name = f'{name}.lora_B.weight'
                     if lora_a_name in lora and lora_b_name in lora:
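In short, the fix does two things: `hotload` now defaults to `None` and is auto-enabled when the module has VRAM management turned on, and the loop stops trusting the key returned by `named_modules()`, using the name recorded on the wrapper (`module.name`) to index the LoRA state dict instead, since wrapping can change the module path. Below is a minimal, self-contained sketch of the new control flow. It is not the actual DiffSynth-Studio implementation: `AutoWrappedLinear` here is a stand-in for the real VRAM-management wrapper, its `.name` attribute is inferred from the diff, `load_lora_hotload` is a hypothetical name, and the in-place merge `weight += alpha * (B @ A)` is a standard LoRA fusion assumed for illustration, since the body after the final `if` is not shown in the hunk.

import torch


class AutoWrappedLinear(torch.nn.Linear):
    """Stand-in for the VRAM-management wrapper; records the original name."""
    def __init__(self, in_features, out_features, name=""):
        super().__init__(in_features, out_features)
        self.name = name  # original fully qualified name, e.g. "blocks.0.to_q"


def load_lora_hotload(module, lora, alpha=1, hotload=None):
    # New default: decide from the module itself instead of hard-coding False.
    if hotload is None:
        hotload = hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")
    if not hotload:
        return 0
    if not (hasattr(module, "vram_management_enabled") and getattr(module, "vram_management_enabled")):
        raise ValueError("VRAM Management is not enabled. LoRA hotloading is not supported.")
    updated_num = 0
    # Ignore the named_modules() key: wrapping may change the module path,
    # so use the name stored on the wrapper to index the LoRA state dict.
    for _, submodule in module.named_modules():
        if isinstance(submodule, AutoWrappedLinear):
            name = submodule.name
            lora_a_name = f"{name}.lora_A.weight"
            lora_b_name = f"{name}.lora_B.weight"
            if lora_a_name in lora and lora_b_name in lora:
                # Assumed merge step: fuse alpha * (B @ A) into the weight.
                delta = alpha * (lora[lora_b_name] @ lora[lora_a_name])
                submodule.weight.data += delta.to(submodule.weight)
                updated_num += 1
    return updated_num


# Usage: one wrapped linear whose LoRA pair is keyed by its recorded name.
model = torch.nn.Module()
model.vram_management_enabled = True
model.proj = AutoWrappedLinear(4, 4, name="proj")
lora = {"proj.lora_A.weight": torch.zeros(2, 4), "proj.lora_B.weight": torch.zeros(4, 2)}
print(load_lora_hotload(model, lora))  # -> 1

Note that the real diff keeps the loop variable named `module`, shadowing the function argument; that is safe because `named_modules()` is evaluated once before iteration, but the sketch uses `submodule` for readability.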