diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index 8358e55..065b687 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -352,6 +352,8 @@ class DiffusionTrainingModule(torch.nn.Module): if "lora_A.weight" in key or "lora_B.weight" in key: new_key = key.replace("lora_A.weight", "lora_A.default.weight").replace("lora_B.weight", "lora_B.default.weight") new_state_dict[new_key] = value + elif "lora_A.default.weight" in key or "lora_B.default.weight" in key: + new_state_dict[key] = value return new_state_dict diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index e1b66c8..46eac56 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -45,6 +45,7 @@ class FluxTrainingModule(DiffusionTrainingModule): state_dict = load_state_dict(lora_checkpoint) state_dict = self.mapping_lora_state_dict(state_dict) load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") if len(load_result[1]) > 0: print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") setattr(self.pipe, lora_base_model, model) diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index b0ba69e..31bbfda 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -49,6 +49,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): state_dict = load_state_dict(lora_checkpoint) state_dict = self.mapping_lora_state_dict(state_dict) load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") if len(load_result[1]) > 0: print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") setattr(self.pipe, lora_base_model, model) diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 726243c..f2f437e 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -46,6 +46,7 @@ class WanTrainingModule(DiffusionTrainingModule): state_dict = load_state_dict(lora_checkpoint) state_dict = self.mapping_lora_state_dict(state_dict) load_result = model.load_state_dict(state_dict, strict=False) + print(f"LoRA checkpoint loaded: {lora_checkpoint}, total {len(state_dict)} keys") if len(load_result[1]) > 0: print(f"Warning, LoRA key mismatch! Unexpected keys in LoRA checkpoint: {load_result[1]}") setattr(self.pipe, lora_base_model, model)