diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 656d7c9..58822c8 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -134,6 +134,48 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ --use_gradient_checkpointing ``` +Step 4.1: I2V LoRA training +```shell +# cache latents +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_i2v.py \ + --task data_process \ + --dataset_path data/fps24_V6 \ + --output_path ./output \ + --text_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth" \ + --vae_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/Wan2.1_VAE.pth" \ + --image_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --tiled \ + --num_frames 121 \ + --height 309 \ + --width 186 +``` + +```shell +# run I2V training (the --dataset_path must match the latent-cache step above) +CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_i2v.py \ + --task train \ + --train_architecture lora \ + --dataset_path data/fps24_V6 \ + --output_path ./output \ + --dit_path "[ + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors\", + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors\", + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors\", + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors\", + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors\", + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors\", + \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors\" + ]" \ + --steps_per_epoch 500 \ + --max_epochs 10 \ + --learning_rate 1e-4 \ + --lora_rank 4 \ + --lora_alpha 4 \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --accumulate_grad_batches 1 \ + --use_gradient_checkpointing +``` + Step 5: Test Test LoRA: diff --git
# Reconstructed from the diff: examples/wanvideo/train_wan_i2v.py (new file).
#
# Two-phase Wan2.1 I2V fine-tuning script:
#   1. `--task data_process`: encode videos/captions/first frames once and
#      cache the resulting tensors next to each video.
#   2. `--task train`: LoRA (or full) fine-tuning of the DiT on the cache.
import torch, os, imageio, argparse
from torchvision.transforms import v2
from einops import rearrange
import lightning as pl
import pandas as pd
from diffsynth import WanVideoPipeline, ModelManager
from peft import LoraConfig, inject_adapter_in_model
import torchvision
from PIL import Image
import numpy as np
import json


class I2VDataset(torch.utils.data.Dataset):
    """Yields (caption, video clip, raw first frame, path) samples.

    Expects a CSV at `metadata_path` with `file_name` and `text` columns;
    video files live under `<base_path>/train/`.
    """

    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()

        self.max_num_frames = max_num_frames
        self.frame_interval = frame_interval
        self.num_frames = num_frames
        self.height = height
        self.width = width

        # Per-frame pipeline: resize, center-crop to the target size, then
        # normalize into [-1, 1] (the VAE's expected input range).
        self.frame_process = v2.Compose([
            v2.Resize(size=(height, width), antialias=True),
            v2.CenterCrop(size=(height, width)),
            v2.ToTensor(),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def crop_and_resize(self, image):
        # Upscale just enough that a (self.height, self.width) center crop is
        # fully covered, preserving aspect ratio.
        width, height = image.size
        scale = max(self.width / width, self.height / height)
        image = torchvision.transforms.functional.resize(
            image,
            (round(height * scale), round(width * scale)),
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
        )
        return image

    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
        """Decode `num_frames` frames starting at `start_frame_id`.

        Returns (frames, first_frame_image) where `frames` is a (C, T, H, W)
        tensor in [-1, 1] and `first_frame_image` is the raw, un-cropped first
        PIL frame (used later as the I2V image condition), or None when the
        video is too short for the requested sampling window.
        """
        reader = imageio.get_reader(file_path)
        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
            reader.close()
            return None

        frames = []
        first_frame_image = None
        for frame_id in range(num_frames):
            frame = reader.get_data(start_frame_id + frame_id * interval)
            frame = Image.fromarray(frame)
            if first_frame_image is None:
                first_frame_image = frame
            frame = self.crop_and_resize(frame)
            frame = frame_process(frame)
            frames.append(frame)
        reader.close()

        frames = torch.stack(frames, dim=0)
        frames = rearrange(frames, "T C H W -> C T H W")

        return frames, first_frame_image

    def load_video(self, file_path):
        """Sample a random start frame and decode one clip from `file_path`."""
        start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0]
        result = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process)
        if result is None:
            # Fix: the original unpacked the bare None return into two
            # variables, raising TypeError for too-short videos.
            return None, None
        return result

    def load_text_video_raw_data(self, data_id):
        # Fix: the caption comes from metadata["text"]; the original returned
        # the file path as the caption.
        text = self.text[data_id]
        video = self.load_video(self.path[data_id])
        data = {"text": text, "video": video}
        return data

    def __getitem__(self, data_id):
        # Fix: the original assigned self.path[data_id] to `text`, so the
        # model would have been trained with file paths as prompts.
        text = self.text[data_id]
        path = self.path[data_id]
        video, first_frame_image = self.load_video(path)
        # NOTE(review): when the video is too short, `video` and
        # `first_frame_image` are None and the default collate will fail;
        # such videos are expected to be filtered out of the metadata.
        data = {"text": text, "video": video, "first_frame_image": np.array(first_frame_image), "path": path}
        return data

    def __len__(self):
        return len(self.path)


class LightningModelForDataProcess(pl.LightningModule):
    """Preprocessing "model": encodes each sample once and caches tensors.

    `test_step` writes `<video path>.tensors.pth` containing the VAE latents,
    the text embedding, and the CLIP / VAE image-condition tensors consumed
    by the I2V DiT during training.
    """

    def __init__(self, text_encoder_path, image_encoder_path, vae_path, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        model_manager.load_models([text_encoder_path, image_encoder_path, vae_path])
        self.pipe = WanVideoPipeline.from_model_manager(model_manager)

        # Tiled VAE encoding options (reduce VRAM for large resolutions).
        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
        self.num_frames = num_frames
        self.height = height
        self.width = width

    def test_step(self, batch, batch_idx):
        text, video, first_frame_image_tensor, path = batch["text"][0], batch["video"], batch["first_frame_image"][0], batch["path"][0]
        self.pipe.device = self.device
        if video is not None:
            prompt_emb = self.pipe.encode_prompt(text)
            # Drop the batch dimension: tensors are cached per sample.
            latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
            first_frame_image = Image.fromarray(np.array(first_frame_image_tensor.cpu()))
            cond_data_dict = self.pipe.encode_image(first_frame_image, num_frames=self.num_frames, height=self.height, width=self.width)
            data = {"latents": latents, "prompt_emb": prompt_emb, "clip_fea": cond_data_dict["clip_fea"][0], "y": cond_data_dict["y"][0]}
            torch.save(data, path + ".tensors.pth")


class TensorDataset(torch.utils.data.Dataset):
    """Serves the cached `.tensors.pth` files produced by the data_process task."""

    def __init__(self, base_path, metadata_path, steps_per_epoch):
        metadata = pd.read_csv(metadata_path)
        self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
        print(len(self.path), "videos in metadata.")
        # Keep only entries whose cache file actually exists.
        self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
        print(len(self.path), "tensors cached in metadata.")
        assert len(self.path) > 0

        self.steps_per_epoch = steps_per_epoch

    def __getitem__(self, index):
        data_id = torch.randint(0, len(self.path), (1,))[0]
        data_id = (data_id + index) % len(self.path)  # For fixed seed.
        path = self.path[data_id]
        data = torch.load(path, weights_only=True, map_location="cpu")
        return data

    def __len__(self):
        # Virtual length: one "epoch" is steps_per_epoch random draws,
        # independent of how many tensors are cached.
        return self.steps_per_epoch


class LightningModelForTrain(pl.LightningModule):
    """LoRA / full fine-tuning of the Wan I2V DiT on cached latents."""

    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
        # Parse dit_path from a JSON string into a Python list so that a
        # sharded checkpoint (several .safetensors files) can be passed as
        # a single CLI argument.
        dit_path = json.loads(dit_path)
        model_manager.load_models([dit_path])

        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
        self.pipe.scheduler.set_timesteps(1000, training=True)
        self.freeze_parameters()
        if train_architecture == "lora":
            self.add_lora_to_model(
                self.pipe.denoising_model(),
                lora_rank=lora_rank,
                lora_alpha=lora_alpha,
                lora_target_modules=lora_target_modules,
                init_lora_weights=init_lora_weights,
            )
        else:
            self.pipe.denoising_model().requires_grad_(True)

        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing

    def freeze_parameters(self):
        # Freeze the entire pipeline; only the DiT (plus any injected LoRA
        # adapters) is put back into train mode.
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()

    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
        # Inject LoRA adapters into the DiT (in place via peft).
        self.lora_alpha = lora_alpha
        if init_lora_weights == "kaiming":
            # peft's default init is Kaiming-uniform on lora_A.
            init_lora_weights = True

        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights=init_lora_weights,
            target_modules=lora_target_modules.split(","),
        )
        model = inject_adapter_in_model(lora_config, model)
        for param in model.parameters():
            # Upcast LoRA parameters into fp32 for stable optimization.
            if param.requires_grad:
                param.data = param.to(torch.float32)

    def training_step(self, batch, batch_idx):
        # Data (everything was pre-encoded by the data_process task).
        latents = batch["latents"].to(self.device)
        prompt_emb = batch["prompt_emb"]
        prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)]
        clip_fea = batch["clip_fea"].to(self.device)
        y = batch["y"].to(self.device)

        # Draw one random training timestep and build the noisy input/target.
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
            noise_pred = self.pipe.denoising_model()(
                noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
                use_gradient_checkpointing=self.use_gradient_checkpointing,
                clip_fea=clip_fea, y=y
            )
            # Fix: the original used training_target[..., 1:], which slices
            # the latent *width* axis and makes the target shape incompatible
            # with noise_pred; prediction and target must cover the same
            # tensor, matching the sibling T2V training script.
            loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
            loss = loss * self.pipe.scheduler.training_weight(timestep)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        # Optimize only parameters left trainable by freeze/LoRA injection.
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        # Strip the Lightning checkpoint down to the trainable (LoRA or
        # unfrozen) weights so saved files stay small.
        checkpoint.clear()
        trainable_param_names = set(
            name for name, param in self.pipe.denoising_model().named_parameters() if param.requires_grad
        )
        state_dict = self.pipe.denoising_model().state_dict()
        lora_state_dict = {name: param for name, param in state_dict.items() if name in trainable_param_names}
        checkpoint.update(lora_state_dict)


def parse_args():
    """Parse CLI arguments shared by the data_process and train tasks."""
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--task",
        type=str,
        default="data_process",
        required=True,
        choices=["data_process", "train"],
        help="Task. `data_process` or `train`.",
    )
    parser.add_argument(
        "--dataset_path",
        type=str,
        default=None,
        required=True,
        help="The path of the Dataset.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./",
        help="Path to save the model.",
    )
    parser.add_argument(
        "--text_encoder_path",
        type=str,
        default=None,
        help="Path of text encoder.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        help="Path of image encoder.",
    )
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        help="Path of VAE.",
    )
    parser.add_argument(
        "--dit_path",
        type=str,
        default=None,
        help="Path of DiT.",
    )
    parser.add_argument(
        "--tiled",
        default=False,
        action="store_true",
        help="Whether enable tile encode in VAE. This option can reduce VRAM required.",
    )
    parser.add_argument(
        "--tile_size_height",
        type=int,
        default=34,
        help="Tile size (height) in VAE.",
    )
    parser.add_argument(
        "--tile_size_width",
        type=int,
        default=34,
        help="Tile size (width) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_height",
        type=int,
        default=18,
        help="Tile stride (height) in VAE.",
    )
    parser.add_argument(
        "--tile_stride_width",
        type=int,
        default=16,
        help="Tile stride (width) in VAE.",
    )
    parser.add_argument(
        "--steps_per_epoch",
        type=int,
        default=500,
        help="Number of steps per epoch.",
    )
    parser.add_argument(
        "--num_frames",
        type=int,
        default=81,
        help="Number of frames.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=480,
        help="Image height.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=832,
        help="Image width.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-5,
        help="Learning rate.",
    )
    parser.add_argument(
        "--accumulate_grad_batches",
        type=int,
        default=1,
        help="The number of batches in gradient accumulation.",
    )
    parser.add_argument(
        "--max_epochs",
        type=int,
        default=1,
        help="Number of epochs.",
    )
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q,k,v,o,ffn.0,ffn.2",
        help="Layers with LoRA modules.",
    )
    parser.add_argument(
        "--init_lora_weights",
        type=str,
        default="kaiming",
        choices=["gaussian", "kaiming"],
        help="The initializing method of LoRA weight.",
    )
    parser.add_argument(
        "--training_strategy",
        type=str,
        default="auto",
        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
        help="Training strategy",
    )
    parser.add_argument(
        "--lora_rank",
        type=int,
        default=4,
        help="The dimension of the LoRA update matrices.",
    )
    parser.add_argument(
        "--lora_alpha",
        type=float,
        default=4.0,
        help="The weight of the LoRA update matrices.",
    )
    parser.add_argument(
        "--use_gradient_checkpointing",
        default=False,
        action="store_true",
        help="Whether to use gradient checkpointing.",
    )
    parser.add_argument(
        "--train_architecture",
        type=str,
        default="lora",
        choices=["lora", "full"],
        help="Model structure to train. LoRA training or full training.",
    )
    args = parser.parse_args()
    return args


def data_process(args):
    """Encode every video in the dataset and cache the tensors to disk."""
    dataset = I2VDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        max_num_frames=args.num_frames,
        frame_interval=1,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=False,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )
    model = LightningModelForDataProcess(
        text_encoder_path=args.text_encoder_path,
        image_encoder_path=args.image_encoder_path,
        vae_path=args.vae_path,
        num_frames=args.num_frames,
        height=args.height,
        width=args.width,
        tiled=args.tiled,
        tile_size=(args.tile_size_height, args.tile_size_width),
        tile_stride=(args.tile_stride_height, args.tile_stride_width)
    )
    # Lightning's test loop is (ab)used as a one-pass inference driver.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        default_root_dir=args.output_path,
    )
    trainer.test(model, dataloader)


def train(args):
    """Run LoRA/full fine-tuning on the cached tensors."""
    dataset = TensorDataset(
        args.dataset_path,
        os.path.join(args.dataset_path, "metadata.csv"),
        steps_per_epoch=args.steps_per_epoch,
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=1,
        num_workers=args.dataloader_num_workers
    )
    model = LightningModelForTrain(
        dit_path=args.dit_path,
        learning_rate=args.learning_rate,
        train_architecture=args.train_architecture,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
        init_lora_weights=args.init_lora_weights,
        use_gradient_checkpointing=args.use_gradient_checkpointing
    )
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
    )
    trainer.fit(model, dataloader)


if __name__ == '__main__':
    args = parse_args()
    if args.task == "data_process":
        data_process(args)
    elif args.task == "train":
        train(args)