Mirror of https://github.com/modelscope/DiffSynth-Studio.git
@@ -352,6 +352,8 @@ class DiffusionTrainingModule(torch.nn.Module):
             if "lora_A.weight" in key or "lora_B.weight" in key:
                 new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
                 new_state_dict[new_key] = value
+            elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
+                new_state_dict[key] = value
         return new_state_dict
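Note: this hunk teaches the key-mapping step to pass through keys that already use PEFT's "default" adapter naming, instead of silently dropping them. A minimal sketch of the full function implied by the context lines follows; the enclosing loop and the standalone-function form are assumptions, since the diff only shows the method body from the first branch onward.

# Sketch of the mapping logic in this hunk; loop structure is assumed.
def mapping_lora_state_dict(state_dict):
    new_state_dict = {}
    for key, value in state_dict.items():
        # Bare PEFT keys get the "default" adapter name injected.
        if "lora_A.weight" in key or "lora_B.weight" in key:
            new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight")
            new_state_dict[new_key] = value
        # New in this patch: keys already carrying the adapter name pass through unchanged.
        elif "lora_A.default.weight" in key or "lora_B.default.weight" in key:
            new_state_dict[key] = value
    return new_state_dict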
@@ -45,6 +45,7 @@ class FluxTrainingModule(DiffusionTrainingModule):
         state_dict = load_state_dict(lora_checkpoint)
         state_dict = self.mapping_lora_state_dict(state_dict)
         load_result = model.load_state_dict(state_dict, strict=False)
+        print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
         if len(load_result[1]) > 0:
             print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
         setattr(self.pipe, lora_base_model, model)
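Note: `load_result[1]` works because torch's `Module.load_state_dict(..., strict=False)` returns a named tuple of `(missing_keys, unexpected_keys)`, so index 1 is the list of checkpoint keys the model did not expect. A self-contained toy demo (the layer and keys are illustrative, not from DiffSynth-Studio):

# Toy demo of the (missing_keys, unexpected_keys) return value.
import torch

model = torch.nn.Linear(4, 4)
state_dict = {"weight": torch.zeros(4, 4), "not_a_real_key": torch.zeros(1)}
load_result = model.load_state_dict(state_dict, strict=False)
print(load_result.missing_keys)     # ['bias']
print(load_result.unexpected_keys)  # ['not_a_real_key'] == load_result[1]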
@@ -49,6 +49,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
         state_dict = load_state_dict(lora_checkpoint)
         state_dict = self.mapping_lora_state_dict(state_dict)
         load_result = model.load_state_dict(state_dict, strict=False)
+        print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
         if len(load_result[1]) > 0:
             print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
         setattr(self.pipe, lora_base_model, model)
@@ -46,6 +46,7 @@ class WanTrainingModule(DiffusionTrainingModule):
         state_dict = load_state_dict(lora_checkpoint)
         state_dict = self.mapping_lora_state_dict(state_dict)
         load_result = model.load_state_dict(state_dict, strict=False)
+        print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
         if len(load_result[1]) > 0:
             print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
         setattr(self.pipe, lora_base_model, model)
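Note: the Flux, Qwen-Image, and Wan hunks add the same load-and-report block verbatim. It could plausibly be hoisted into a shared helper on `DiffusionTrainingModule`; a hypothetical refactor sketch, not part of this patch, with the method name `load_lora_into` invented for illustration:

# Hypothetical shared helper collapsing the three identical hunks.
def load_lora_into(self, model, lora_checkpoint, lora_base_model):
    state_dict = load_state_dict(lora_checkpoint)
    state_dict = self.mapping_lora_state_dict(state_dict)
    load_result = model.load_state_dict(state_dict, strict=False)
    print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys")
    if len(load_result[1]) > 0:
        print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}")
    setattr(self.pipe, lora_base_model, model)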