Merge pull request #393 from modelscope/wan-train-update

support resume training
Zhongjie Duan (committed by GitHub)
2025-03-03 18:45:17 +08:00


@@ -3,7 +3,7 @@ from torchvision.transforms import v2
 from einops import rearrange
 import lightning as pl
 import pandas as pd
-from diffsynth import WanVideoPipeline, ModelManager
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
@@ -145,7 +145,7 @@ class TensorDataset(torch.utils.data.Dataset):
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
         model_manager.load_models([dit_path])
@@ -160,6 +160,7 @@ class LightningModelForTrain(pl.LightningModule):
                 lora_alpha=lora_alpha,
                 lora_target_modules=lora_target_modules,
                 init_lora_weights=init_lora_weights,
+                pretrained_lora_path=pretrained_lora_path,
             )
         else:
             self.pipe.denoising_model().requires_grad_(True)
@@ -175,7 +176,7 @@ class LightningModelForTrain(pl.LightningModule):
         self.pipe.denoising_model().train()
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
@@ -193,6 +194,17 @@ class LightningModelForTrain(pl.LightningModule):
             if param.requires_grad:
                 param.data = param.to(torch.float32)
+
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
 
     def training_step(self, batch, batch_idx):
         # Data
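Note on the block added above: it loads the previously saved LoRA tensors into the freshly injected adapters with strict=False, so every base-model key absent from the checkpoint is merely reported in missing_keys instead of raising an error, and the printed count is the total parameter count minus those missing keys. A minimal standalone sketch of the same pattern, assuming only a torch module that already has LoRA layers injected and a placeholder checkpoint path (none of the names below are part of this PR):

    import torch
    from safetensors.torch import load_file

    def load_pretrained_lora(model: torch.nn.Module, path: str) -> None:
        # Read the LoRA-only state dict (placeholder path; .safetensors or a torch pickle).
        state_dict = load_file(path) if path.endswith(".safetensors") else torch.load(path, map_location="cpu")
        # strict=False: base-model keys missing from the checkpoint are reported, not fatal.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        num_params = len([name for name, _ in model.named_parameters()])
        print(f"{num_params - len(missing_keys)} parameters loaded from {path}, "
              f"{len(unexpected_keys)} unexpected.")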
@@ -405,6 +417,12 @@ def parse_args():
choices=["lora", "full"], choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.", help="Model structure to train. LoRA training or full training.",
) )
parser.add_argument(
"--pretrained_lora_path",
type=str,
default=None,
help="Pretrained LoRA path. Required if the training is resumed.",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
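The new argument defaults to None, so existing training commands keep working unchanged; resuming only means pointing the flag at a LoRA checkpoint written by an earlier run. A small guard one could place in train(args) before building the model, purely illustrative and not part of this PR:

    import os

    # Fail fast if the resume checkpoint does not exist (hypothetical check).
    if args.pretrained_lora_path is not None and not os.path.isfile(args.pretrained_lora_path):
        raise FileNotFoundError(f"--pretrained_lora_path not found: {args.pretrained_lora_path}")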
@@ -460,7 +478,8 @@ def train(args):
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
-        use_gradient_checkpointing=args.use_gradient_checkpointing
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        pretrained_lora_path=args.pretrained_lora_path,
     )
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
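Taken together, resuming is the original construction call with the new argument threaded through; a hedged end-to-end sketch with placeholder paths (the DiT weight file and LoRA checkpoint below are illustrative, not part of this PR):

    # Placeholder paths; substitute the files from the interrupted run.
    model = LightningModelForTrain(
        dit_path="path/to/wan_dit.safetensors",                    # base DiT weights (placeholder)
        learning_rate=1e-5,
        lora_rank=4,
        lora_alpha=4,
        train_architecture="lora",
        pretrained_lora_path="path/to/previous_lora.safetensors",  # saved by the earlier run (placeholder)
    )
    # Training then proceeds exactly as before, e.g. trainer.fit(model, dataloader).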