mirror of https://github.com/modelscope/DiffSynth-Studio.git

Commit: support wan i2v training

This commit extends the Wan video training script with image-to-video (i2v) training: the dataset can return the first frame of each clip, the data-process stage encodes it with an optional image encoder, and the trainer forwards the resulting image embeddings to the denoising model. It also adds multi-file DiT loading, gradient-checkpointing offload, resuming from a pretrained LoRA, and optional SwanLab logging.
@@ -3,15 +3,16 @@ from torchvision.transforms import v2
 from einops import rearrange
 import lightning as pl
 import pandas as pd
-from diffsynth import WanVideoPipeline, ModelManager
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
+import numpy as np



 class TextVideoDataset(torch.utils.data.Dataset):
-    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
+    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
         metadata = pd.read_csv(metadata_path)
         self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
         self.text = metadata["text"].to_list()
@@ -21,6 +22,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
         self.num_frames = num_frames
         self.height = height
         self.width = width
+        self.is_i2v = is_i2v

         self.frame_process = v2.Compose([
             v2.CenterCrop(size=(height, width)),
@@ -48,10 +50,13 @@ class TextVideoDataset(torch.utils.data.Dataset):
             return None

         frames = []
+        first_frame = None
         for frame_id in range(num_frames):
             frame = reader.get_data(start_frame_id + frame_id * interval)
             frame = Image.fromarray(frame)
             frame = self.crop_and_resize(frame)
+            if first_frame is None:
+                first_frame = np.array(frame)
             frame = frame_process(frame)
             frames.append(frame)
         reader.close()
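The first frame is captured as a raw uint8 array before frame_process normalizes it, because the data-process stage below rebuilds a PIL image from it for the image encoder. A minimal sketch of why the order matters (values are illustrative, not from the commit):

    import numpy as np
    from PIL import Image

    frame = Image.new("RGB", (832, 480))  # stand-in for a decoded, cropped video frame
    first_frame = np.array(frame)         # (480, 832, 3), dtype=uint8
    Image.fromarray(first_frame)          # round-trips cleanly for image encoding later
    # After normalization the frame is a float tensor in [-1, 1];
    # Image.fromarray would no longer accept it.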
@@ -59,7 +64,10 @@ class TextVideoDataset(torch.utils.data.Dataset):
         frames = torch.stack(frames, dim=0)
         frames = rearrange(frames, "T C H W -> C T H W")

-        return frames
+        if self.is_i2v:
+            return frames, first_frame
+        else:
+            return frames


     def load_video(self, file_path):
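For reference, the layout handed to the VAE: frames are stacked along time, then moved to channel-first video layout. A small shape check using the defaults above:

    import torch
    from einops import rearrange

    frames = torch.stack([torch.zeros(3, 480, 832) for _ in range(81)], dim=0)
    print(frames.shape)   # torch.Size([81, 3, 480, 832]) = T C H W
    frames = rearrange(frames, "T C H W -> C T H W")
    print(frames.shape)   # torch.Size([3, 81, 480, 832]) = C T H W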
@@ -70,7 +78,7 @@ class TextVideoDataset(torch.utils.data.Dataset):

     def is_image(self, file_path):
         file_ext_name = file_path.split(".")[-1]
-        if file_ext_name.lower() in ["jpg", "png", "webp"]:
+        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
             return True
         return False

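The extension check previously missed .jpeg files. An illustration of the changed behavior, assuming ds is a TextVideoDataset instance and the file names are hypothetical:

    ds.is_image("cover.jpeg")  # True after this change; False before
    ds.is_image("clip.mp4")    # False, still routed to load_video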
@@ -78,6 +86,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
     def load_image(self, file_path):
         frame = Image.open(file_path).convert("RGB")
         frame = self.crop_and_resize(frame)
+        first_frame = frame
         frame = self.frame_process(frame)
         frame = rearrange(frame, "C H W -> C 1 H W")
         return frame
@@ -87,10 +96,16 @@ class TextVideoDataset(torch.utils.data.Dataset):
         text = self.text[data_id]
         path = self.path[data_id]
         if self.is_image(path):
+            if self.is_i2v:
+                raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
             video = self.load_image(path)
         else:
             video = self.load_video(path)
-        data = {"text": text, "video": video, "path": path}
+        if self.is_i2v:
+            video, first_frame = video
+            data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
+        else:
+            data = {"text": text, "video": video, "path": path}
         return data


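With is_i2v=True, each sample now carries the raw first frame alongside the video tensor. A sketch of consuming one item (the dataset path and metadata file are placeholders):

    dataset = TextVideoDataset(
        "data/example_dataset", "data/example_dataset/metadata.csv",
        num_frames=81, height=480, width=832, is_i2v=True,
    )
    item = dataset[0]
    item["video"].shape         # (C, T, H, W), normalized frames for the VAE
    item["first_frame"].shape   # (H, W, 3), uint8 array for the image encoder
    item["text"], item["path"]  # caption and source file, unchanged from t2v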
@@ -100,21 +115,35 @@ class TextVideoDataset(torch.utils.data.Dataset):


 class LightningModelForDataProcess(pl.LightningModule):
-    def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         super().__init__()
+        model_path = [text_encoder_path, vae_path]
+        if image_encoder_path is not None:
+            model_path.append(image_encoder_path)
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-        model_manager.load_models([text_encoder_path, vae_path])
+        model_manager.load_models(model_path)
         self.pipe = WanVideoPipeline.from_model_manager(model_manager)

         self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}

     def test_step(self, batch, batch_idx):
         text, video, path = batch["text"][0], batch["video"], batch["path"][0]

         self.pipe.device = self.device
         if video is not None:
+            # prompt
             prompt_emb = self.pipe.encode_prompt(text)
+            # video
             video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
             latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
-            data = {"latents": latents, "prompt_emb": prompt_emb}
+            # image
+            if "first_frame" in batch:
+                first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
+                _, _, num_frames, height, width = video.shape
+                image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
+            else:
+                image_emb = {}
+            data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
             torch.save(data, path + ".tensors.pth")
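Each processed video still yields one <video>.tensors.pth cache file next to the source, now with an image_emb entry. A sketch of its contents (the file name is hypothetical; exact shapes depend on the VAE and encoders):

    import torch

    data = torch.load("data/example_dataset/train/video1.mp4.tensors.pth")
    data["latents"]     # VAE-encoded video latents, bfloat16
    data["prompt_emb"]  # {"context": ...} from the text encoder
    data["image_emb"]   # {} for t2v; first-frame conditioning tensors for i2v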
@@ -145,10 +174,21 @@ class TensorDataset(torch.utils.data.Dataset):


 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(
+        self,
+        dit_path,
+        learning_rate=1e-5,
+        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
+        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
+        pretrained_lora_path=None
+    ):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-        model_manager.load_models([dit_path])
+        if os.path.isfile(dit_path):
+            model_manager.load_models([dit_path])
+        else:
+            dit_path = dit_path.split(",")
+            model_manager.load_models([dit_path])

         self.pipe = WanVideoPipeline.from_model_manager(model_manager)
         self.pipe.scheduler.set_timesteps(1000, training=True)
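--dit_path may now name several files: a comma-separated value is split into a list and passed as a single load_models entry, which is how ModelManager groups the shards of one model. A sketch with hypothetical file names:

    import os

    # model_manager: a ModelManager instance as constructed above
    dit_path = "dit-00001-of-00002.safetensors,dit-00002-of-00002.safetensors"
    if os.path.isfile(dit_path):
        model_manager.load_models([dit_path])             # one single-file model
    else:
        model_manager.load_models([dit_path.split(",")])  # one multi-file model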
@@ -160,12 +200,14 @@ class LightningModelForTrain(pl.LightningModule):
                 lora_alpha=lora_alpha,
                 lora_target_modules=lora_target_modules,
                 init_lora_weights=init_lora_weights,
+                pretrained_lora_path=pretrained_lora_path,
             )
         else:
             self.pipe.denoising_model().requires_grad_(True)

         self.learning_rate = learning_rate
         self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload


     def freeze_parameters(self):
@@ -175,7 +217,7 @@ class LightningModelForTrain(pl.LightningModule):
         self.pipe.denoising_model().train()


-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
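For context, the injection that add_lora_to_model performs further down rests on the two peft imports at the top of the file. A minimal sketch with the defaults above; init_lora_weights=True stands in for the "kaiming" branch, whose exact mapping sits outside this hunk:

    from peft import LoraConfig, inject_adapter_in_model

    # model: the DiT to adapt (placeholder)
    lora_config = LoraConfig(
        r=4,
        lora_alpha=4,
        init_lora_weights=True,
        target_modules="q,k,v,o,ffn.0,ffn.2".split(","),
    )
    model = inject_adapter_in_model(lora_config, model)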
@@ -192,30 +234,47 @@ class LightningModelForTrain(pl.LightningModule):
             # Upcast LoRA parameters into fp32
             if param.requires_grad:
                 param.data = param.to(torch.float32)
+
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
+

     def training_step(self, batch, batch_idx):
         # Data
         latents = batch["latents"].to(self.device)
         prompt_emb = batch["prompt_emb"]
-        prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
+        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
+        image_emb = batch["image_emb"]
+        if "clip_feature" in image_emb:
+            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
+        if "y" in image_emb:
+            image_emb["y"] = image_emb["y"][0].to(self.device)

         # Loss
         self.pipe.device = self.device
         noise = torch.randn_like(latents)
         timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
-        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
+        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
         extra_input = self.pipe.prepare_extra_input(latents)
         noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
         training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

         # Compute loss
         with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
-            noise_pred = self.pipe.denoising_model()(
-                noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
-                use_gradient_checkpointing=self.use_gradient_checkpointing
-            )
-            loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
-            loss = loss * self.pipe.scheduler.training_weight(timestep)
+            noise_pred = self.pipe.denoising_model()(
+                noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
+                use_gradient_checkpointing=self.use_gradient_checkpointing,
+                use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
+            )
+            loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+            loss = loss * self.pipe.scheduler.training_weight(timestep)

         # Record log
         self.log("train_loss", loss, prog_bar=True)
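Resuming loads the LoRA weights with strict=False, so only matching adapter keys are filled in while mismatches are reported instead of crashing. The optional state_dict_converter hook lets a checkpoint with differently prefixed keys be remapped first; a hypothetical converter, not part of the commit:

    def strip_prefix(state_dict, prefix="module."):
        # e.g. undo a DistributedDataParallel wrapper prefix before loading
        return {k[len(prefix):] if k.startswith(prefix) else k: v
                for k, v in state_dict.items()}

Passing state_dict_converter=strip_prefix would then clean the keys before load_state_dict runs.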
@@ -270,6 +329,12 @@ def parse_args():
         default=None,
         help="Path of text encoder.",
     )
+    parser.add_argument(
+        "--image_encoder_path",
+        type=str,
+        default=None,
+        help="Path of image encoder.",
+    )
     parser.add_argument(
         "--vae_path",
         type=str,
@@ -398,6 +463,12 @@ def parse_args():
         action="store_true",
         help="Whether to use gradient checkpointing.",
     )
+    parser.add_argument(
+        "--use_gradient_checkpointing_offload",
+        default=False,
+        action="store_true",
+        help="Whether to use gradient checkpointing offload.",
+    )
     parser.add_argument(
         "--train_architecture",
         type=str,
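This flag is forwarded to the denoising model in training_step above. As a rough mental model (an illustrative PyTorch pattern, not DiffSynth's internal implementation), checkpointing with CPU offload keeps saved activations out of GPU memory between the forward and backward passes, trading speed for capacity:

    import torch

    # block: any nn.Module; x: its input tensor (placeholders)
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        y = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)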
@@ -405,6 +476,23 @@ def parse_args():
         choices=["lora", "full"],
         help="Model structure to train. LoRA training or full training.",
     )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
     args = parser.parse_args()
     return args

@@ -417,7 +505,8 @@ def data_process(args):
         frame_interval=1,
         num_frames=args.num_frames,
         height=args.height,
-        width=args.width
+        width=args.width,
+        is_i2v=args.image_encoder_path is not None
     )
     dataloader = torch.utils.data.DataLoader(
         dataset,
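Note that there is no dedicated i2v flag: supplying --image_encoder_path is the whole switch that flips the dataset, and downstream the cached image_emb, into i2v mode:

    is_i2v = args.image_encoder_path is not None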
@@ -427,6 +516,7 @@ def data_process(args):
     )
     model = LightningModelForDataProcess(
         text_encoder_path=args.text_encoder_path,
+        image_encoder_path=args.image_encoder_path,
         vae_path=args.vae_path,
         tiled=args.tiled,
         tile_size=(args.tile_size_height, args.tile_size_width),
@@ -460,16 +550,34 @@ def train(args):
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
-        use_gradient_checkpointing=args.use_gradient_checkpointing
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+        pretrained_lora_path=args.pretrained_lora_path,
     )
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="wan",
+            name="wan",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=os.path.join(args.output_path, "swanlog"),
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
         accelerator="gpu",
         devices="auto",
         precision="bf16",
         strategy=args.training_strategy,
         default_root_dir=args.output_path,
         accumulate_grad_batches=args.accumulate_grad_batches,
-        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
     )
     trainer.fit(model, dataloader)
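Putting it together, the script keeps its two-stage flow, and i2v rides along via the cached embeddings. A sketch of the sequence; the assumption that stage selection uses an args.task switch mirrors the companion t2v script and is not shown in this diff:

    args = parse_args()
    if args.task == "data_process":
        data_process(args)   # caches latents, prompt_emb and image_emb to *.tensors.pth
    elif args.task == "train":
        train(args)          # LoRA or full training, optionally resumed from a LoRA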