feat: support I2V training

This commit is contained in:
wang96
2025-02-26 19:50:59 +08:00
parent cf12723c89
commit 0dbb3d333f
2 changed files with 536 additions and 0 deletions

View File

@@ -134,6 +134,48 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
--use_gradient_checkpointing
```
Step 4-1: I2V LoRA-training
```shell
# cache latents
CUDA_VISIBLE_DEVICES="0" python train_wan_i2v.py \
--task data_process \
--dataset_path data/fps24_V6 \
--output_path ./output \
--text_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth" \
--vae_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/Wan2.1_VAE.pth" \
--image_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--tiled \
--num_frames 121 \
--height 309 \
--width 186
```
```shell
# run I2V training
CUDA_VISIBLE_DEVICES="0" python train_wan_i2v.py \
--task train \
--train_architecture lora \
--dataset_path data/kling_hips_fps24_V6 \
--output_path ./output \
--dit_path "[
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors\",
\"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors\"
]" \
--steps_per_epoch 500 \
--max_epochs 10 \
--learning_rate 1e-4 \
--lora_rank 4 \
--lora_alpha 4 \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--accumulate_grad_batches 1 \
--use_gradient_checkpointing
```
Step 5: Test
Test LoRA:

View File

@@ -0,0 +1,494 @@
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
import json
class I2VDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.max_num_frames = max_num_frames
self.frame_interval = frame_interval
self.num_frames = num_frames
self.height = height
self.width = width
self.frame_process = v2.Compose([
v2.Resize(size=(height, width), antialias=True),
v2.CenterCrop(size=(height, width)),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
def crop_and_resize(self, image):
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
image,
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
return image
def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
reader.close()
return None
frames = []
first_frame_image = None
for frame_id in range(num_frames):
frame = reader.get_data(start_frame_id + frame_id * interval)
frame = Image.fromarray(frame)
if first_frame_image is None:
first_frame_image = frame
frame = self.crop_and_resize(frame)
frame = frame_process(frame)
frames.append(frame)
reader.close()
frames = torch.stack(frames, dim=0)
frames = rearrange(frames, "T C H W -> C T H W")
return frames, first_frame_image
def load_video(self, file_path):
start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
frames, first_frame_image = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
return frames, first_frame_image
def load_text_video_raw_data(self, data_id):
text = self.path[data_id]
video = self.load_video(self.path[data_id])
data = {"text": text, "video": video}
return data
def __getitem__(self, data_id):
text = self.path[data_id]
path = self.path[data_id]
video, first_frame_image = self.load_video(path)
data = {"text": text, "video": video, "first_frame_image":np.array(first_frame_image), "path": path}
return data
def __len__(self):
return len(self.path)
class LightningModelForDataProcess(pl.LightningModule):
def __init__(self, text_encoder_path, image_encoder_path, vae_path, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
model_manager.load_models([text_encoder_path, image_encoder_path, vae_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
self.num_frames = num_frames
self.height = height
self.width = width
def test_step(self, batch, batch_idx):
text, video, first_frame_image_tensor, path = batch["text"][0], batch["video"], batch["first_frame_image"][0], batch["path"][0]
self.pipe.device = self.device
if video is not None:
prompt_emb = self.pipe.encode_prompt(text)
latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
first_frame_image = Image.fromarray(np.array(first_frame_image_tensor.cpu()))
cond_data_dict = self.pipe.encode_image(first_frame_image, num_frames=self.num_frames, height=self.height, width=self.width)
data = {"latents": latents, "prompt_emb": prompt_emb, "clip_fea": cond_data_dict["clip_fea"][0], "y": cond_data_dict["y"][0]}
torch.save(data, path + ".tensors.pth")
class TensorDataset(torch.utils.data.Dataset):
def __init__(self, base_path, metadata_path, steps_per_epoch):
metadata = pd.read_csv(metadata_path)
self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
print(len(self.path), "videos in metadata.")
self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
print(len(self.path), "tensors cached in metadata.")
assert len(self.path) > 0
self.steps_per_epoch = steps_per_epoch
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
path = self.path[data_id]
data = torch.load(path, weights_only=True, map_location="cpu")
return data
def __len__(self):
return self.steps_per_epoch
class LightningModelForTrain(pl.LightningModule):
def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
# 将 dit_path 从字符串解析为 Python 列表
dit_path = json.loads(dit_path)
model_manager.load_models([dit_path])
self.pipe = WanVideoPipeline.from_model_manager(model_manager)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
if train_architecture == "lora":
self.add_lora_to_model(
self.pipe.denoising_model(),
lora_rank=lora_rank,
lora_alpha=lora_alpha,
lora_target_modules=lora_target_modules,
init_lora_weights=init_lora_weights,
)
else:
self.pipe.denoising_model().requires_grad_(True)
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
# Add LoRA to UNet
self.lora_alpha = lora_alpha
if init_lora_weights == "kaiming":
init_lora_weights = True
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights=init_lora_weights,
target_modules=lora_target_modules.split(","),
)
model = inject_adapter_in_model(lora_config, model)
for param in model.parameters():
# Upcast LoRA parameters into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
def training_step(self, batch, batch_idx):
# Data
latents = batch["latents"].to(self.device)
prompt_emb = batch["prompt_emb"]
prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
clip_fea = batch["clip_fea"].to(self.device)
y = batch["y"].to(self.device)
# Loss
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
noise_pred = self.pipe.denoising_model()(
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
use_gradient_checkpointing=self.use_gradient_checkpointing,
clip_fea=clip_fea, y=y
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target[..., 1:].float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.denoising_model().state_dict()
lora_state_dict = {}
for name, param in state_dict.items():
if name in trainable_param_names:
lora_state_dict[name] = param
checkpoint.update(lora_state_dict)
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--task",
type=str,
default="data_process",
required=True,
choices=["data_process", "train"],
help="Task. `data_process` or `train`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--text_encoder_path",
type=str,
default=None,
help="Path of text encoder.",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
help="Path of image encoder.",
)
parser.add_argument(
"--vae_path",
type=str,
default=None,
help="Path of VAE.",
)
parser.add_argument(
"--dit_path",
type=str,
default=None,
help="Path of DiT.",
)
parser.add_argument(
"--tiled",
default=False,
action="store_true",
help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
)
parser.add_argument(
"--tile_size_height",
type=int,
default=34,
help="Tile size (height) in VAE.",
)
parser.add_argument(
"--tile_size_width",
type=int,
default=34,
help="Tile size (width) in VAE.",
)
parser.add_argument(
"--tile_stride_height",
type=int,
default=18,
help="Tile stride (height) in VAE.",
)
parser.add_argument(
"--tile_stride_width",
type=int,
default=16,
help="Tile stride (width) in VAE.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--num_frames",
type=int,
default=81,
help="Number of frames.",
)
parser.add_argument(
"--height",
type=int,
default=480,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=832,
help="Image width.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=1,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-5,
help="Learning rate.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
parser.add_argument(
"--lora_target_modules",
type=str,
default="q,k,v,o,ffn.0,ffn.2",
help="Layers with LoRA modules.",
)
parser.add_argument(
"--init_lora_weights",
type=str,
default="kaiming",
choices=["gaussian", "kaiming"],
help="The initializing method of LoRA weight.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--train_architecture",
type=str,
default="lora",
choices=["lora", "full"],
help="Model structure to train. LoRA training or full training.",
)
args = parser.parse_args()
return args
def data_process(args):
dataset = I2VDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"),
max_num_frames=args.num_frames,
frame_interval=1,
num_frames=args.num_frames,
height=args.height,
width=args.width
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForDataProcess(
text_encoder_path=args.text_encoder_path,
image_encoder_path=args.image_encoder_path,
vae_path=args.vae_path,
num_frames=args.num_frames,
height=args.height,
width=args.width,
tiled=args.tiled,
tile_size=(args.tile_size_height, args.tile_size_width),
tile_stride=(args.tile_stride_height, args.tile_stride_width)
)
trainer = pl.Trainer(
accelerator="gpu",
devices="auto",
default_root_dir=args.output_path,
)
trainer.test(model, dataloader)
def train(args):
dataset = TensorDataset(
args.dataset_path,
os.path.join(args.dataset_path, "metadata.csv"),
steps_per_epoch=args.steps_per_epoch,
)
dataloader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=1,
num_workers=args.dataloader_num_workers
)
model = LightningModelForTrain(
dit_path=args.dit_path,
learning_rate=args.learning_rate,
train_architecture=args.train_architecture,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
use_gradient_checkpointing=args.use_gradient_checkpointing
)
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
)
trainer.fit(model, dataloader)
if __name__ == '__main__':
args = parse_args()
if args.task == "data_process":
data_process(args)
elif args.task == "train":
train(args)