diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index de3be03..22c5d98 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -157,6 +157,8 @@ CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \
 
 If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`.
 
+If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
+
 For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
 
 Step 5: Test
diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
index 90533f9..8e8b370 100644
--- a/examples/wanvideo/train_wan_t2v.py
+++ b/examples/wanvideo/train_wan_t2v.py
@@ -7,11 +7,12 @@ from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
+import numpy as np
 
 
 class TextVideoDataset(torch.utils.data.Dataset):
-    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832):
+    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False):
         metadata = pd.read_csv(metadata_path)
         self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
         self.text = metadata["text"].to_list()
@@ -21,6 +22,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
         self.num_frames = num_frames
         self.height = height
         self.width = width
+        self.is_i2v = is_i2v
 
         self.frame_process = v2.Compose([
             v2.CenterCrop(size=(height, width)),
@@ -48,10 +50,13 @@
             return None
 
         frames = []
+        first_frame = None
         for frame_id in range(num_frames):
             frame = reader.get_data(start_frame_id + frame_id * interval)
             frame = Image.fromarray(frame)
             frame = self.crop_and_resize(frame)
+            if first_frame is None:
+                first_frame = np.array(frame)
             frame = frame_process(frame)
             frames.append(frame)
         reader.close()
@@ -59,7 +64,10 @@
 
         frames = torch.stack(frames, dim=0)
         frames = rearrange(frames, "T C H W -> C T H W")
-        return frames
+        if self.is_i2v:
+            return frames, first_frame
+        else:
+            return frames
 
 
     def load_video(self, file_path):
@@ -78,6 +86,7 @@ class TextVideoDataset(torch.utils.data.Dataset):
     def load_image(self, file_path):
         frame = Image.open(file_path).convert("RGB")
         frame = self.crop_and_resize(frame)
+        first_frame = frame
         frame = self.frame_process(frame)
         frame = rearrange(frame, "C H W -> C 1 H W")
         return frame
@@ -87,10 +96,16 @@ class TextVideoDataset(torch.utils.data.Dataset):
         text = self.text[data_id]
         path = self.path[data_id]
         if self.is_image(path):
+            if self.is_i2v:
+                raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
             video = self.load_image(path)
         else:
             video = self.load_video(path)
-        data = {"text": text, "video": video, "path": path}
+        if self.is_i2v:
+            video, first_frame = video
+            data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
+        else:
+            data = {"text": text, "video": video, "path": path}
         return data
 
 
@@ -100,22 +115,35 @@ class TextVideoDataset(torch.utils.data.Dataset):
 
 
 class LightningModelForDataProcess(pl.LightningModule):
-    def __init__(self, text_encoder_path, vae_path, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         super().__init__()
+        model_path = [text_encoder_path, vae_path]
+        if image_encoder_path is not None:
+            model_path.append(image_encoder_path)
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-        model_manager.load_models([text_encoder_path, vae_path])
+        model_manager.load_models(model_path)
         self.pipe = WanVideoPipeline.from_model_manager(model_manager)
         self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
 
     def test_step(self, batch, batch_idx):
         text, video, path = batch["text"][0], batch["video"], batch["path"][0]
+        self.pipe.device = self.device
         if video is not None:
+            # prompt
             prompt_emb = self.pipe.encode_prompt(text)
+            # video
            video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
             latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
-            data = {"latents": latents, "prompt_emb": prompt_emb}
+            # image
+            if "first_frame" in batch:
+                first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
+                _, _, num_frames, height, width = video.shape
+                image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
+            else:
+                image_emb = {}
+            data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
             torch.save(data, path + ".tensors.pth")
@@ -224,7 +252,12 @@ class LightningModelForTrain(pl.LightningModule):
         latents = batch["latents"].to(self.device)
         prompt_emb = batch["prompt_emb"]
         prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
-
+        image_emb = batch["image_emb"]
+        if "clip_feature" in image_emb:
+            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
+        if "y" in image_emb:
+            image_emb["y"] = image_emb["y"][0].to(self.device)
+
         # Loss
         self.pipe.device = self.device
         noise = torch.randn_like(latents)
@@ -236,7 +269,7 @@
 
         # Compute loss
         noise_pred = self.pipe.denoising_model()(
-            noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
+            noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb,
             use_gradient_checkpointing=self.use_gradient_checkpointing,
             use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
         )
@@ -296,6 +329,12 @@ def parse_args():
         default=None,
         help="Path of text encoder.",
     )
+    parser.add_argument(
+        "--image_encoder_path",
+        type=str,
+        default=None,
+        help="Path of image encoder.",
+    )
     parser.add_argument(
         "--vae_path",
         type=str,
@@ -466,7 +505,8 @@
         frame_interval=1,
         num_frames=args.num_frames,
         height=args.height,
-        width=args.width
+        width=args.width,
+        is_i2v=args.image_encoder_path is not None
     )
     dataloader = torch.utils.data.DataLoader(
         dataset,
@@ -476,6 +516,7 @@
     )
     model = LightningModelForDataProcess(
         text_encoder_path=args.text_encoder_path,
+        image_encoder_path=args.image_encoder_path,
        vae_path=args.vae_path,
         tiled=args.tiled,
         tile_size=(args.tile_size_height, args.tile_size_width),
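For context, here is a minimal usage sketch of the dataset change above. It is not part of the patch; the dataset location and metadata layout are assumptions inferred from the loader code (a `train/` folder plus a CSV with `file_name` and `text` columns), and it assumes it is run from `examples/wanvideo/` so the training script is importable.

```python
# Hypothetical usage of TextVideoDataset in I2V mode (paths are placeholders).
from train_wan_t2v import TextVideoDataset  # the script modified by this patch

dataset = TextVideoDataset(
    base_path="data/example_dataset",                    # assumed layout: <base_path>/train/<file_name>
    metadata_path="data/example_dataset/metadata.csv",   # assumed CSV with "file_name" and "text" columns
    num_frames=81, height=480, width=832,
    is_i2v=True,  # data_process() sets this automatically when --image_encoder_path is given
)

sample = dataset[0]                  # must be a video file; images raise ValueError in I2V mode
print(sample["video"].shape)         # (C, T, H, W) tensor after rearrange
print(sample["first_frame"].shape)   # (H, W, 3) uint8 array of the first decoded frame
```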
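Downstream, the data-processing stage caches one `.tensors.pth` file per video, and the training step then looks for the `clip_feature` and `y` keys inside `image_emb`. The following is a small sketch for inspecting such a record; the file name is hypothetical, and only the key layout comes from `test_step` above.

```python
# Inspect a cached record written by LightningModelForDataProcess.test_step.
import torch

record = torch.load("data/example_dataset/train/video_00001.mp4.tensors.pth", map_location="cpu")
print(record["latents"].shape)      # VAE-encoded video latents
print(record["prompt_emb"].keys())  # text-encoder output, e.g. "context"
print(record["image_emb"].keys())   # {} for T2V; may contain "clip_feature" and "y" when --image_encoder_path was set
```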