From eb4d5187d8f1001d260c3cb30433262779fee26e Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 3 Mar 2025 18:31:31 +0800
Subject: [PATCH] support resume training

---
 examples/wanvideo/train_wan_t2v.py | 27 +++++++++++++++++++++++----
 1 file changed, 23 insertions(+), 4 deletions(-)

diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index 39aa4c0..e16aaa0 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -3,7 +3,7 @@ from torchvision.transforms import v2
 from einops import rearrange
 import lightning as pl
 import pandas as pd
-from diffsynth import WanVideoPipeline, ModelManager
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
@@ -145,7 +145,7 @@ class TensorDataset(torch.utils.data.Dataset):
 
 
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
         model_manager.load_models([dit_path])
@@ -160,6 +160,7 @@ class LightningModelForTrain(pl.LightningModule):
                 lora_alpha=lora_alpha,
                 lora_target_modules=lora_target_modules,
                 init_lora_weights=init_lora_weights,
+                pretrained_lora_path=pretrained_lora_path,
             )
         else:
             self.pipe.denoising_model().requires_grad_(True)
@@ -175,7 +176,7 @@ class LightningModelForTrain(pl.LightningModule):
         self.pipe.denoising_model().train()
 
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
@@ -192,6 +193,17 @@ class LightningModelForTrain(pl.LightningModule):
             # Upcast LoRA parameters into fp32
             if param.requires_grad:
                 param.data = param.to(torch.float32)
+
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
 
 
     def training_step(self, batch, batch_idx):
@@ -405,6 +417,12 @@ def parse_args():
         choices=["lora", "full"],
         help="Model structure to train. LoRA training or full training.",
     )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
     args = parser.parse_args()
     return args
 
@@ -460,7 +478,8 @@ def train(args):
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
-        use_gradient_checkpointing=args.use_gradient_checkpointing
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        pretrained_lora_path=args.pretrained_lora_path,
     )
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
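
For reference, a minimal self-contained sketch of the resume mechanism this patch
adds: inject LoRA adapters, then restore a previously saved LoRA state dict with
strict=False so only the adapter weights are overwritten. The toy module, the
checkpoint filename, and the peft-only setup are illustrative assumptions; the
actual script loads the DiT through ModelManager/WanVideoPipeline and uses
diffsynth's load_state_dict.

    import torch
    from peft import LoraConfig, inject_adapter_in_model


    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.q = torch.nn.Linear(16, 16)
            self.v = torch.nn.Linear(16, 16)

        def forward(self, x):
            return self.v(self.q(x))


    model = ToyModel()
    lora_config = LoraConfig(r=4, lora_alpha=4, init_lora_weights=True, target_modules=["q", "v"])
    model = inject_adapter_in_model(lora_config, model)

    # Simulate the LoRA checkpoint written by an earlier run (hypothetical filename).
    lora_state_dict = {k: v for k, v in model.state_dict().items() if "lora" in k}
    torch.save(lora_state_dict, "lora_checkpoint.pt")

    # Resume: load only the LoRA weights; strict=False leaves the base weights untouched.
    state_dict = torch.load("lora_checkpoint.pt")
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    num_updated_keys = len([name for name, _ in model.named_parameters()]) - len(missing_keys)
    print(f"{num_updated_keys} parameters are loaded. {len(unexpected_keys)} parameters are unexpected.")

With the patch applied, resuming a run means passing --pretrained_lora_path pointing
at the LoRA checkpoint saved by a previous run, alongside the same lora_rank,
lora_alpha, and lora_target_modules settings used before.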