diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 7c7b144..22c5d98 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -10,6 +10,13 @@ cd DiffSynth-Studio pip install -e . ``` +Wan-Video supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. + +* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) +* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) +* [Sage Attention](https://github.com/thu-ml/SageAttention) +* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.) + ## Inference ### Wan-Video-1.3B-T2V @@ -44,13 +51,17 @@ https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f Wan-Video-14B-I2V adds the functionality of image-to-video based on Wan-Video-14B-T2V. The model size remains the same, therefore the speed and VRAM requirements are also consistent. See [`./wan_14b_image_to_video.py`](./wan_14b_image_to_video.py). +**In the sample code, we use the same settings as the T2V 14B model, with FP8 quantization enabled by default. However, we found that this model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.** + ![Image](https://github.com/user-attachments/assets/adf8047f-7943-4aaa-a555-2b32dc415f39) https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 ## Train -We support Wan-Video LoRA training and full training. Here is a tutorial. +We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA: + +https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9 Step 1: Install additional packages @@ -67,7 +78,7 @@ data/example_dataset/ ├── metadata.csv └── train ├── video_00001.mp4 - └── video_00002.mp4 + └── image_00002.jpg ``` `metadata.csv`: @@ -75,9 +86,11 @@ data/example_dataset/ ``` file_name,text video_00001.mp4,"video description" -video_00001.mp4,"video description" +image_00002.jpg,"video description" ``` +We support both images and videos. An image is treated as a single frame of video. 
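If you want to double-check the layout before moving on, the following minimal sketch (a hypothetical helper, not part of the repository) reads `metadata.csv` and verifies that every referenced file exists under `train/`, mirroring how the training script resolves paths:

```python
# Hypothetical sanity check for the dataset layout described above.
# Paths are resolved the same way as in train_wan_t2v.py:
# os.path.join(dataset_path, "train", file_name).
import os
import pandas as pd

IMAGE_EXTS = {"jpg", "jpeg", "png", "webp"}  # images are treated as single-frame videos

def check_dataset(dataset_path="data/example_dataset"):
    metadata = pd.read_csv(os.path.join(dataset_path, "metadata.csv"))
    assert {"file_name", "text"} <= set(metadata.columns), "metadata.csv needs file_name,text columns"
    for file_name in metadata["file_name"]:
        path = os.path.join(dataset_path, "train", file_name)
        if not os.path.exists(path):
            print(f"missing file: {path}")
            continue
        kind = "image" if file_name.split(".")[-1].lower() in IMAGE_EXTS else "video"
        print(f"{path}: treated as {kind}")

if __name__ == "__main__":
    check_dataset()
```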
+ Step 3: Data process ```shell @@ -119,8 +132,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ --steps_per_epoch 500 \ --max_epochs 10 \ --learning_rate 1e-4 \ - --lora_rank 4 \ - --lora_alpha 4 \ + --lora_rank 16 \ + --lora_alpha 16 \ --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ --accumulate_grad_batches 1 \ --use_gradient_checkpointing @@ -142,48 +155,12 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ --use_gradient_checkpointing ``` -Step 4-1: I2V LoRA-training -```shell -# cache latents -CUDA_VISIBLE_DEVICES="0" python train_wan_i2v.py \ - --task data_process \ - --dataset_path data/fps24_V6 \ - --output_path ./output \ - --text_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth" \ - --vae_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/Wan2.1_VAE.pth" \ - --image_encoder_path "./models/Wan-AI/Wan2.1-I2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ - --tiled \ - --num_frames 121 \ - --height 309 \ - --width 186 -``` +If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`. + +If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`. + +For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`. 
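Internally, `train_wan_t2v.py` accepts `--dit_path` either as a single checkpoint file or as this comma-separated list; a condensed sketch of that logic (simplified from the script) is:

```python
# Simplified from train_wan_t2v.py: a single existing file (the 1.3B checkpoint)
# is loaded directly, while a comma-separated string (the sharded 14B
# safetensors) is split and passed to the model manager as one grouped entry,
# so the shards are loaded together as a single model.
import os

def resolve_dit_path(dit_path: str):
    if os.path.isfile(dit_path):
        return [dit_path]
    return [dit_path.split(",")]

# Usage inside the script (simplified):
# model_manager.load_models(resolve_dit_path(args.dit_path))
```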
-```shell -# run I2V training -CUDA_VISIBLE_DEVICES="0" python train_wan_i2v.py \ - --task train \ - --train_architecture lora \ - --dataset_path data/kling_hips_fps24_V6 \ - --output_path ./output \ - --dit_path "[ - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors\", - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors\", - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors\", - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors\", - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors\", - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors\", - \"./models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors\" - ]" \ - --steps_per_epoch 500 \ - --max_epochs 10 \ - --learning_rate 1e-4 \ - --lora_rank 4 \ - --lora_alpha 4 \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --accumulate_grad_batches 1 \ - --use_gradient_checkpointing -``` - Step 5: Test Test LoRA: diff --git a/examples/wanvideo/train_wan_i2v.py b/examples/wanvideo/train_wan_i2v.py deleted file mode 100644 index 9a49b34..0000000 --- a/examples/wanvideo/train_wan_i2v.py +++ /dev/null @@ -1,494 +0,0 @@ -import torch, os, imageio, argparse -from torchvision.transforms import v2 -from einops import rearrange -import lightning as pl -import pandas as pd -from diffsynth import WanVideoPipeline, ModelManager -from peft import LoraConfig, inject_adapter_in_model -import torchvision -from PIL import Image -import numpy as np -import json - - -class I2VDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - self.text = metadata["text"].to_list() - - self.max_num_frames = max_num_frames - self.frame_interval = frame_interval - self.num_frames = num_frames - self.height = height - self.width = width - - self.frame_process = v2.Compose([ - v2.CenterCrop(size=(height, width)), - v2.Resize(size=(height, width), antialias=True), - v2.ToTensor(), - v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - - - def crop_and_resize(self, image): - width, height = image.size - scale = max(self.width / width, self.height / height) - image = torchvision.transforms.functional.resize( - image, - (round(height*scale), round(width*scale)), - interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - return image - - - def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): - reader = imageio.get_reader(file_path) - if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: - reader.close() - return None - - frames = [] - first_frame_image = None - for frame_id in range(num_frames): - frame = reader.get_data(start_frame_id + frame_id * interval) - frame = Image.fromarray(frame) - if first_frame_image is None: - first_frame_image = frame - frame = self.crop_and_resize(frame) - frame = frame_process(frame) - frames.append(frame) - reader.close() - - frames = torch.stack(frames, dim=0) - frames = rearrange(frames, "T C H W -> C T H W") - - return frames, first_frame_image - - - def load_video(self, file_path): - start_frame_id = 
torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] - frames, first_frame_image = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) - return frames, first_frame_image - - - def load_text_video_raw_data(self, data_id): - text = self.path[data_id] - video = self.load_video(self.path[data_id]) - data = {"text": text, "video": video} - return data - - - def __getitem__(self, data_id): - text = self.text[data_id] - path = self.path[data_id] - video, first_frame_image = self.load_video(path) - data = {"text": text, "video": video, "first_frame_image":np.array(first_frame_image), "path": path} - return data - - def __len__(self): - return len(self.path) - - - -class LightningModelForDataProcess(pl.LightningModule): - def __init__(self, text_encoder_path, image_encoder_path, vae_path, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - super().__init__() - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models([text_encoder_path, image_encoder_path, vae_path]) - self.pipe = WanVideoPipeline.from_model_manager(model_manager) - - self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} - self.num_frames = num_frames - self.height = height - self.width = width - - def test_step(self, batch, batch_idx): - text, video, first_frame_image_tensor, path = batch["text"][0], batch["video"], batch["first_frame_image"][0], batch["path"][0] - self.pipe.device = self.device - if video is not None: - prompt_emb = self.pipe.encode_prompt(text) - latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] - first_frame_image = Image.fromarray(np.array(first_frame_image_tensor.cpu())) - cond_data_dict = self.pipe.encode_image(first_frame_image, num_frames=self.num_frames, height=self.height, width=self.width) - data = {"latents": latents, "prompt_emb": prompt_emb, "clip_fea": cond_data_dict["clip_fea"][0], "y": cond_data_dict["y"][0]} - torch.save(data, path + ".tensors.pth") - - -class TensorDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, steps_per_epoch): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - print(len(self.path), "videos in metadata.") - self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] - print(len(self.path), "tensors cached in metadata.") - assert len(self.path) > 0 - - self.steps_per_epoch = steps_per_epoch - - - def __getitem__(self, index): - data_id = torch.randint(0, len(self.path), (1,))[0] - data_id = (data_id + index) % len(self.path) # For fixed seed. 
- path = self.path[data_id] - data = torch.load(path, weights_only=True, map_location="cpu") - return data - - - def __len__(self): - return self.steps_per_epoch - - - -class LightningModelForTrain(pl.LightningModule): - def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True): - super().__init__() - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - # 将 dit_path 从字符串解析为 Python 列表 - dit_path = json.loads(dit_path) - model_manager.load_models([dit_path]) - - self.pipe = WanVideoPipeline.from_model_manager(model_manager) - self.pipe.scheduler.set_timesteps(1000, training=True) - self.freeze_parameters() - if train_architecture == "lora": - self.add_lora_to_model( - self.pipe.denoising_model(), - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_target_modules=lora_target_modules, - init_lora_weights=init_lora_weights, - ) - else: - self.pipe.denoising_model().requires_grad_(True) - - self.learning_rate = learning_rate - self.use_gradient_checkpointing = use_gradient_checkpointing - - - def freeze_parameters(self): - # Freeze parameters - self.pipe.requires_grad_(False) - self.pipe.eval() - self.pipe.denoising_model().train() - - - def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"): - # Add LoRA to UNet - self.lora_alpha = lora_alpha - if init_lora_weights == "kaiming": - init_lora_weights = True - - lora_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_alpha, - init_lora_weights=init_lora_weights, - target_modules=lora_target_modules.split(","), - ) - model = inject_adapter_in_model(lora_config, model) - for param in model.parameters(): - # Upcast LoRA parameters into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) - - - def training_step(self, batch, batch_idx): - # Data - latents = batch["latents"].to(self.device) - - prompt_emb = batch["prompt_emb"] - prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)] - clip_fea = batch["clip_fea"].to(self.device) - y = batch["y"].to(self.device) - - # Loss - noise = torch.randn_like(latents) - timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) - extra_input = self.pipe.prepare_extra_input(latents) - noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) - training_target = self.pipe.scheduler.training_target(latents, noise, timestep) - - # Compute loss - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - noise_pred = self.pipe.denoising_model()( - noisy_latents, timestep=timestep, **prompt_emb, **extra_input, - use_gradient_checkpointing=self.use_gradient_checkpointing, - clip_fea=clip_fea, y=y - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target[..., 1:].float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) - - # Record log - self.log("train_loss", loss, prog_bar=True) - return loss - - - def configure_optimizers(self): - trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) - optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) - return optimizer - - - def on_save_checkpoint(self, checkpoint): - checkpoint.clear() - trainable_param_names = list(filter(lambda named_param: 
named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) - trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) - state_dict = self.pipe.denoising_model().state_dict() - lora_state_dict = {} - for name, param in state_dict.items(): - if name in trainable_param_names: - lora_state_dict[name] = param - checkpoint.update(lora_state_dict) - - - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--task", - type=str, - default="data_process", - required=True, - choices=["data_process", "train"], - help="Task. `data_process` or `train`.", - ) - parser.add_argument( - "--dataset_path", - type=str, - default=None, - required=True, - help="The path of the Dataset.", - ) - parser.add_argument( - "--output_path", - type=str, - default="./", - help="Path to save the model.", - ) - parser.add_argument( - "--text_encoder_path", - type=str, - default=None, - help="Path of text encoder.", - ) - parser.add_argument( - "--image_encoder_path", - type=str, - default=None, - help="Path of image encoder.", - ) - parser.add_argument( - "--vae_path", - type=str, - default=None, - help="Path of VAE.", - ) - parser.add_argument( - "--dit_path", - type=str, - default=None, - help="Path of DiT.", - ) - parser.add_argument( - "--tiled", - default=False, - action="store_true", - help="Whether enable tile encode in VAE. This option can reduce VRAM required.", - ) - parser.add_argument( - "--tile_size_height", - type=int, - default=34, - help="Tile size (height) in VAE.", - ) - parser.add_argument( - "--tile_size_width", - type=int, - default=34, - help="Tile size (width) in VAE.", - ) - parser.add_argument( - "--tile_stride_height", - type=int, - default=18, - help="Tile stride (height) in VAE.", - ) - parser.add_argument( - "--tile_stride_width", - type=int, - default=16, - help="Tile stride (width) in VAE.", - ) - parser.add_argument( - "--steps_per_epoch", - type=int, - default=500, - help="Number of steps per epoch.", - ) - parser.add_argument( - "--num_frames", - type=int, - default=81, - help="Number of frames.", - ) - parser.add_argument( - "--height", - type=int, - default=480, - help="Image height.", - ) - parser.add_argument( - "--width", - type=int, - default=832, - help="Image width.", - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=1, - help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-5, - help="Learning rate.", - ) - parser.add_argument( - "--accumulate_grad_batches", - type=int, - default=1, - help="The number of batches in gradient accumulation.", - ) - parser.add_argument( - "--max_epochs", - type=int, - default=1, - help="Number of epochs.", - ) - parser.add_argument( - "--lora_target_modules", - type=str, - default="q,k,v,o,ffn.0,ffn.2", - help="Layers with LoRA modules.", - ) - parser.add_argument( - "--init_lora_weights", - type=str, - default="kaiming", - choices=["gaussian", "kaiming"], - help="The initializing method of LoRA weight.", - ) - parser.add_argument( - "--training_strategy", - type=str, - default="auto", - choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], - help="Training strategy", - ) - parser.add_argument( - "--lora_rank", - type=int, - default=4, - help="The dimension of the LoRA update matrices.", - ) - parser.add_argument( - "--lora_alpha", - type=float, - default=4.0, - help="The weight of the LoRA update matrices.", - ) - parser.add_argument( - "--use_gradient_checkpointing", - default=False, - action="store_true", - help="Whether to use gradient checkpointing.", - ) - parser.add_argument( - "--train_architecture", - type=str, - default="lora", - choices=["lora", "full"], - help="Model structure to train. LoRA training or full training.", - ) - args = parser.parse_args() - return args - - -def data_process(args): - dataset = I2VDataset( - args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), - max_num_frames=args.num_frames, - frame_interval=1, - num_frames=args.num_frames, - height=args.height, - width=args.width - ) - dataloader = torch.utils.data.DataLoader( - dataset, - shuffle=False, - batch_size=1, - num_workers=args.dataloader_num_workers - ) - model = LightningModelForDataProcess( - text_encoder_path=args.text_encoder_path, - image_encoder_path=args.image_encoder_path, - vae_path=args.vae_path, - num_frames=args.num_frames, - height=args.height, - width=args.width, - tiled=args.tiled, - tile_size=(args.tile_size_height, args.tile_size_width), - tile_stride=(args.tile_stride_height, args.tile_stride_width) - ) - trainer = pl.Trainer( - accelerator="gpu", - devices="auto", - default_root_dir=args.output_path, - ) - trainer.test(model, dataloader) - - -def train(args): - dataset = TensorDataset( - args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), - steps_per_epoch=args.steps_per_epoch, - ) - dataloader = torch.utils.data.DataLoader( - dataset, - shuffle=True, - batch_size=1, - num_workers=args.dataloader_num_workers - ) - model = LightningModelForTrain( - dit_path=args.dit_path, - learning_rate=args.learning_rate, - train_architecture=args.train_architecture, - lora_rank=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_target_modules=args.lora_target_modules, - init_lora_weights=args.init_lora_weights, - use_gradient_checkpointing=args.use_gradient_checkpointing - ) - trainer = pl.Trainer( - max_epochs=args.max_epochs, - accelerator="gpu", - devices="auto", - strategy=args.training_strategy, - default_root_dir=args.output_path, - accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] - ) - trainer.fit(model, dataloader) - - -if __name__ == '__main__': - args = parse_args() - if args.task == "data_process": - data_process(args) - elif args.task == "train": - train(args) 
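With the standalone I2V script removed, image-to-video training is handled by `train_wan_t2v.py` below (enabled via `--image_encoder_path`). Conceptually, the data-process step caches one tensor file per sample; the sketch below is illustrative, with field names taken from the updated script:

```python
# Illustrative sketch of the per-sample cache written by the data-process step
# in the updated train_wan_t2v.py. For T2V data "image_emb" is an empty dict;
# for I2V data it carries the CLIP feature and conditioning tensor computed
# from the first frame.
import torch

def save_cached_sample(path, latents, prompt_emb, image_emb=None):
    data = {
        "latents": latents,            # VAE-encoded video latents
        "prompt_emb": prompt_emb,      # text-encoder output
        "image_emb": image_emb or {},  # {"clip_feature": ..., "y": ...} for I2V
    }
    torch.save(data, path + ".tensors.pth")
```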
diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index 817fd5c..8e8b370 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -3,15 +3,16 @@ from torchvision.transforms import v2 from einops import rearrange import lightning as pl import pandas as pd -from diffsynth import WanVideoPipeline, ModelManager +from diffsynth import WanVideoPipeline, ModelManager, load_state_dict from peft import LoraConfig, inject_adapter_in_model import torchvision from PIL import Image +import numpy as np class TextVideoDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832): + def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): metadata = pd.read_csv(metadata_path) self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] self.text = metadata["text"].to_list() @@ -21,6 +22,7 @@ class TextVideoDataset(torch.utils.data.Dataset): self.num_frames = num_frames self.height = height self.width = width + self.is_i2v = is_i2v self.frame_process = v2.Compose([ v2.CenterCrop(size=(height, width)), @@ -48,10 +50,13 @@ class TextVideoDataset(torch.utils.data.Dataset): return None frames = [] + first_frame = None for frame_id in range(num_frames): frame = reader.get_data(start_frame_id + frame_id * interval) frame = Image.fromarray(frame) frame = self.crop_and_resize(frame) + if first_frame is None: + first_frame = np.array(frame) frame = frame_process(frame) frames.append(frame) reader.close() @@ -59,7 +64,10 @@ class TextVideoDataset(torch.utils.data.Dataset): frames = torch.stack(frames, dim=0) frames = rearrange(frames, "T C H W -> C T H W") - return frames + if self.is_i2v: + return frames, first_frame + else: + return frames def load_video(self, file_path): @@ -70,7 +78,7 @@ class TextVideoDataset(torch.utils.data.Dataset): def is_image(self, file_path): file_ext_name = file_path.split(".")[-1] - if file_ext_name.lower() in ["jpg", "png", "webp"]: + if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: return True return False @@ -78,6 +86,7 @@ class TextVideoDataset(torch.utils.data.Dataset): def load_image(self, file_path): frame = Image.open(file_path).convert("RGB") frame = self.crop_and_resize(frame) + first_frame = frame frame = self.frame_process(frame) frame = rearrange(frame, "C H W -> C 1 H W") return frame @@ -87,10 +96,16 @@ class TextVideoDataset(torch.utils.data.Dataset): text = self.text[data_id] path = self.path[data_id] if self.is_image(path): + if self.is_i2v: + raise ValueError(f"{path} is not a video. 
I2V model doesn't support image-to-image training.") video = self.load_image(path) else: video = self.load_video(path) - data = {"text": text, "video": video, "path": path} + if self.is_i2v: + video, first_frame = video + data = {"text": text, "video": video, "path": path, "first_frame": first_frame} + else: + data = {"text": text, "video": video, "path": path} return data @@ -100,21 +115,35 @@ class TextVideoDataset(torch.utils.data.Dataset): class LightningModelForDataProcess(pl.LightningModule): - def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): + def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): super().__init__() + model_path = [text_encoder_path, vae_path] + if image_encoder_path is not None: + model_path.append(image_encoder_path) model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models([text_encoder_path, vae_path]) + model_manager.load_models(model_path) self.pipe = WanVideoPipeline.from_model_manager(model_manager) self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} def test_step(self, batch, batch_idx): text, video, path = batch["text"][0], batch["video"], batch["path"][0] + self.pipe.device = self.device if video is not None: + # prompt prompt_emb = self.pipe.encode_prompt(text) + # video + video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] - data = {"latents": latents, "prompt_emb": prompt_emb} + # image + if "first_frame" in batch: + first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) + _, _, num_frames, height, width = video.shape + image_emb = self.pipe.encode_image(first_frame, num_frames, height, width) + else: + image_emb = {} + data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} torch.save(data, path + ".tensors.pth") @@ -145,10 +174,21 @@ class TensorDataset(torch.utils.data.Dataset): class LightningModelForTrain(pl.LightningModule): - def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True): + def __init__( + self, + dit_path, + learning_rate=1e-5, + lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", + use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, + pretrained_lora_path=None + ): super().__init__() model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models([dit_path]) + if os.path.isfile(dit_path): + model_manager.load_models([dit_path]) + else: + dit_path = dit_path.split(",") + model_manager.load_models([dit_path]) self.pipe = WanVideoPipeline.from_model_manager(model_manager) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -160,12 +200,14 @@ class LightningModelForTrain(pl.LightningModule): lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights, + pretrained_lora_path=pretrained_lora_path, ) else: self.pipe.denoising_model().requires_grad_(True) self.learning_rate = learning_rate self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload def freeze_parameters(self): @@ -175,7 +217,7 @@ class 
LightningModelForTrain(pl.LightningModule): self.pipe.denoising_model().train() - def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"): + def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None): # Add LoRA to UNet self.lora_alpha = lora_alpha if init_lora_weights == "kaiming": @@ -192,30 +234,47 @@ class LightningModelForTrain(pl.LightningModule): # Upcast LoRA parameters into fp32 if param.requires_grad: param.data = param.to(torch.float32) + + # Lora pretrained lora weights + if pretrained_lora_path is not None: + state_dict = load_state_dict(pretrained_lora_path) + if state_dict_converter is not None: + state_dict = state_dict_converter(state_dict) + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + all_keys = [i for i, _ in model.named_parameters()] + num_updated_keys = len(all_keys) - len(missing_keys) + num_unexpected_keys = len(unexpected_keys) + print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.") def training_step(self, batch, batch_idx): # Data latents = batch["latents"].to(self.device) prompt_emb = batch["prompt_emb"] - prompt_emb["context"] = [prompt_emb["context"][0][0].to(self.device)] - + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + # Loss + self.pipe.device = self.device noise = torch.randn_like(latents) timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) extra_input = self.pipe.prepare_extra_input(latents) noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) training_target = self.pipe.scheduler.training_target(latents, noise, timestep) # Compute loss - with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type): - noise_pred = self.pipe.denoising_model()( - noisy_latents, timestep=timestep, **prompt_emb, **extra_input, - use_gradient_checkpointing=self.use_gradient_checkpointing - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) + noise_pred = self.pipe.denoising_model()( + noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) # Record log self.log("train_loss", loss, prog_bar=True) @@ -270,6 +329,12 @@ def parse_args(): default=None, help="Path of text encoder.", ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + help="Path of image encoder.", + ) parser.add_argument( "--vae_path", type=str, @@ -398,6 +463,12 @@ def parse_args(): action="store_true", help="Whether to use gradient checkpointing.", ) + parser.add_argument( + 
"--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) parser.add_argument( "--train_architecture", type=str, @@ -405,6 +476,23 @@ def parse_args(): choices=["lora", "full"], help="Model structure to train. LoRA training or full training.", ) + parser.add_argument( + "--pretrained_lora_path", + type=str, + default=None, + help="Pretrained LoRA path. Required if the training is resumed.", + ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) args = parser.parse_args() return args @@ -417,7 +505,8 @@ def data_process(args): frame_interval=1, num_frames=args.num_frames, height=args.height, - width=args.width + width=args.width, + is_i2v=args.image_encoder_path is not None ) dataloader = torch.utils.data.DataLoader( dataset, @@ -427,6 +516,7 @@ def data_process(args): ) model = LightningModelForDataProcess( text_encoder_path=args.text_encoder_path, + image_encoder_path=args.image_encoder_path, vae_path=args.vae_path, tiled=args.tiled, tile_size=(args.tile_size_height, args.tile_size_width), @@ -460,16 +550,34 @@ def train(args): lora_alpha=args.lora_alpha, lora_target_modules=args.lora_target_modules, init_lora_weights=args.init_lora_weights, - use_gradient_checkpointing=args.use_gradient_checkpointing + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + pretrained_lora_path=args.pretrained_lora_path, ) + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project="wan", + name="wan", + config=swanlab_config, + mode=args.swanlab_mode, + logdir=os.path.join(args.output_path, "swanlog"), + ) + logger = [swanlab_logger] + else: + logger = None trainer = pl.Trainer( max_epochs=args.max_epochs, accelerator="gpu", devices="auto", + precision="bf16", strategy=args.training_strategy, default_root_dir=args.output_path, accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=logger, ) trainer.fit(model, dataloader)