Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-24 18:28:10 +00:00
Merge pull request #393 from modelscope/wan-train-update
support resume training
@@ -3,7 +3,7 @@ from torchvision.transforms import v2
 from einops import rearrange
 import lightning as pl
 import pandas as pd
-from diffsynth import WanVideoPipeline, ModelManager
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
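The new load_state_dict import is what later reads the checkpoint named by --pretrained_lora_path. Its exact behavior lives in the diffsynth package and is not shown in this diff; as a rough, assumed sketch (not the library's code), it amounts to reading a checkpoint file into a plain name-to-tensor dict:

import torch
from safetensors.torch import load_file

def load_state_dict_sketch(path):
    # Assumed stand-in for diffsynth.load_state_dict, not the library's implementation:
    # read a checkpoint from disk into a plain {parameter name: tensor} dict.
    if path.endswith(".safetensors"):
        return load_file(path)
    # torch.load covers .ckpt / .pt / .bin style checkpoints.
    return torch.load(path, map_location="cpu")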
@@ -145,7 +145,7 @@ class TensorDataset(torch.utils.data.Dataset):
 
 
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
         model_manager.load_models([dit_path])
@@ -160,6 +160,7 @@ class LightningModelForTrain(pl.LightningModule):
                 lora_alpha=lora_alpha,
                 lora_target_modules=lora_target_modules,
                 init_lora_weights=init_lora_weights,
+                pretrained_lora_path=pretrained_lora_path,
             )
         else:
             self.pipe.denoising_model().requires_grad_(True)
@@ -175,7 +176,7 @@ class LightningModelForTrain(pl.LightningModule):
         self.pipe.denoising_model().train()
 
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
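add_lora_to_model also gains an optional state_dict_converter. The diff only requires it to be a callable that takes a state dict and returns a state dict; one plausible, hypothetical use is remapping checkpoint keys saved under a different prefix so they match the module names created by inject_adapter_in_model:

def strip_prefix_converter(state_dict, prefix="denoising_model."):
    # Hypothetical converter; the prefix is an example, not one the repository defines.
    # Keys that already lack the prefix pass through unchanged.
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }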
@@ -192,6 +193,17 @@ class LightningModelForTrain(pl.LightningModule):
             # Upcast LoRA parameters into fp32
             if param.requires_grad:
                 param.data = param.to(torch.float32)
 
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
+
 
     def training_step(self, batch, batch_idx):
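The heart of the resume path is the non-strict load_state_dict call: parameters missing from the checkpoint keep their freshly initialized values, and the printed summary counts how many were actually restored. A minimal, self-contained sketch of the same pattern on a toy module (the module and the partial checkpoint are illustrative only, not part of the repository):

import torch

class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(8, 8)
        self.k = torch.nn.Linear(8, 8)

model = ToyBlock()
# Pretend this partial checkpoint came from an earlier LoRA training run.
partial = {"q.weight": torch.zeros(8, 8), "q.bias": torch.zeros(8)}

# strict=False tolerates model parameters absent from the checkpoint (missing_keys)
# and checkpoint entries the model does not have (unexpected_keys).
missing_keys, unexpected_keys = model.load_state_dict(partial, strict=False)

all_keys = [name for name, _ in model.named_parameters()]
num_updated_keys = len(all_keys) - len(missing_keys)
print(f"{num_updated_keys} of {len(all_keys)} parameters restored, "
      f"{len(unexpected_keys)} unexpected.")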
@@ -405,6 +417,12 @@ def parse_args():
         choices=["lora", "full"],
         help="Model structure to train. LoRA training or full training.",
     )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
     args = parser.parse_args()
     return args
 
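Resuming is requested entirely through this new flag; left unset, the default of None keeps the previous behavior and the LoRA weights start from init_lora_weights. A minimal sketch of the flag in isolation (all other training arguments omitted; the path shown is a placeholder, not a file the repository defines):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--pretrained_lora_path",
    type=str,
    default=None,
    help="Pretrained LoRA path. Required if the training is resumed.",
)

fresh = parser.parse_args([])
print(fresh.pretrained_lora_path)    # None -> training starts from scratch

resumed = parser.parse_args(["--pretrained_lora_path", "path/to/previous_lora.safetensors"])
print(resumed.pretrained_lora_path)  # the value handed on to LightningModelForTrain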
@@ -460,7 +478,8 @@ def train(args):
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
-        use_gradient_checkpointing=args.use_gradient_checkpointing
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        pretrained_lora_path=args.pretrained_lora_path,
     )
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,