From 105eaf0f4944e735816bc3df5c769ffcc7fb74e3 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 21 Mar 2025 11:09:56 +0800 Subject: [PATCH] controlnet --- diffsynth/models/wan_video_controlnet.py | 204 ++++++ diffsynth/pipelines/wan_video.py | 82 ++- examples/wanvideo/train_wan_t2v.py | 15 +- examples/wanvideo/train_wan_t2v_controlnet.py | 626 ++++++++++++++++++ 4 files changed, 915 insertions(+), 12 deletions(-) create mode 100644 diffsynth/models/wan_video_controlnet.py create mode 100644 examples/wanvideo/train_wan_t2v_controlnet.py diff --git a/diffsynth/models/wan_video_controlnet.py b/diffsynth/models/wan_video_controlnet.py new file mode 100644 index 0000000..9294b0e --- /dev/null +++ b/diffsynth/models/wan_video_controlnet.py @@ -0,0 +1,204 @@ +import torch +import torch.nn as nn +from typing import Tuple, Optional +from einops import rearrange +from .wan_video_dit import DiTBlock, precompute_freqs_cis_3d, MLP, sinusoidal_embedding_1d +from .utils import hash_state_dict_keys + + + +class WanControlNetModel(torch.nn.Module): + def __init__( + self, + dim: int, + in_dim: int, + ffn_dim: int, + out_dim: int, + text_dim: int, + freq_dim: int, + eps: float, + patch_size: Tuple[int, int, int], + num_heads: int, + num_layers: int, + has_image_input: bool, + ): + super().__init__() + self.dim = dim + self.freq_dim = freq_dim + self.has_image_input = has_image_input + self.patch_size = patch_size + + self.patch_embedding = nn.Conv3d( + in_dim, dim, kernel_size=patch_size, stride=patch_size) + self.text_embedding = nn.Sequential( + nn.Linear(text_dim, dim), + nn.GELU(approximate='tanh'), + nn.Linear(dim, dim) + ) + self.time_embedding = nn.Sequential( + nn.Linear(freq_dim, dim), + nn.SiLU(), + nn.Linear(dim, dim) + ) + self.time_projection = nn.Sequential( + nn.SiLU(), nn.Linear(dim, dim * 6)) + self.blocks = nn.ModuleList([ + DiTBlock(has_image_input, dim, num_heads, ffn_dim, eps) + for _ in range(num_layers) + ]) + head_dim = dim // num_heads + self.freqs = 
precompute_freqs_cis_3d(head_dim) + + if has_image_input: + self.img_emb = MLP(1280, dim) # clip_feature_dim = 1280 + + self.controlnet_conv_in = torch.nn.Conv3d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) + self.controlnet_blocks = torch.nn.ModuleList([ + torch.nn.Linear(dim, dim, bias=False) + for _ in range(num_layers) + ]) + + def patchify(self, x: torch.Tensor): + x = self.patch_embedding(x) + grid_size = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() + return x, grid_size # x, grid_size: (f, h, w) + + def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): + return rearrange( + x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)', + f=grid_size[0], h=grid_size[1], w=grid_size[2], + x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2] + ) + + def forward(self, + x: torch.Tensor, + timestep: torch.Tensor, + context: torch.Tensor, + clip_feature: Optional[torch.Tensor] = None, + y: Optional[torch.Tensor] = None, + controlnet_conditioning: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, + **kwargs, + ): + t = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, timestep)) + t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) + context = self.text_embedding(context) + + if self.has_image_input: + x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w) + clip_embdding = self.img_emb(clip_feature) + context = torch.cat([clip_embdding, context], dim=1) + + x = x + self.controlnet_conv_in(controlnet_conditioning) + x, (f, h, w) = self.patchify(x) + + freqs = torch.cat([ + self.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + self.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + self.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + + 
res_stack = []
+        for block in self.blocks:
+            if self.training and use_gradient_checkpointing:
+                if use_gradient_checkpointing_offload:
+                    with torch.autograd.graph.save_on_cpu():
+                        x = torch.utils.checkpoint.checkpoint(
+                            create_custom_forward(block),
+                            x, context, t_mod, freqs,
+                            use_reentrant=False,
+                        )
+                else:
+                    x = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        x, context, t_mod, freqs,
+                        use_reentrant=False,
+                    )
+            else:
+                x = block(x, context, t_mod, freqs)
+            res_stack.append(x)
+
+        controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
+        return controlnet_res_stack
+
+    @staticmethod
+    def state_dict_converter():
+        return WanControlNetModelStateDictConverter()
+
+
+class WanControlNetModelStateDictConverter:
+    def __init__(self):
+        pass
+
+    def from_diffusers(self, state_dict):
+        return state_dict
+
+    def from_civitai(self, state_dict):
+        return state_dict
+
+    def from_base_model(self, state_dict):
+        if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
+            config = {
+                "has_image_input": False,
+                "patch_size": [1, 2, 2],
+                "in_dim": 16,
+                "dim": 1536,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 12,
+                "num_layers": 30,
+                "eps": 1e-6
+            }
+        elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
+            config = {
+                "has_image_input": False,
+                "patch_size": [1, 2, 2],
+                "in_dim": 16,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
+        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
+            config = {
+                "has_image_input": True,
+                "patch_size": [1, 2, 2],
+                "in_dim": 36,
+                "dim": 5120,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "text_dim": 4096,
+                "out_dim": 16,
+                "num_heads": 40,
+                "num_layers": 40,
+                "eps": 1e-6
+            }
+        else:
+            raise ValueError(f"Unknown base model (state dict hash: {hash_state_dict_keys(state_dict)}); cannot infer ControlNet config.")
+        state_dict_ = {}
+        dtype, device = None, None
+        for name, param in 
state_dict.items(): + if name.startswith("head."): + continue + state_dict_[name] = param + dtype, device = param.dtype, param.device + for block_id in range(config["num_layers"]): + zeros = torch.zeros((config["dim"], config["dim"]), dtype=dtype, device=device) + state_dict_[f"controlnet_blocks.{block_id}.weight"] = zeros.clone() + state_dict_["controlnet_conv_in.weight"] = torch.zeros((config["in_dim"], config["in_dim"], 1, 1, 1), dtype=dtype, device=device) + state_dict_["controlnet_conv_in.bias"] = torch.zeros((config["in_dim"],), dtype=dtype, device=device) + return state_dict_, config \ No newline at end of file diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 439d311..2c6f640 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -17,6 +17,7 @@ from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWra from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample +from ..models.wan_video_controlnet import WanControlNetModel @@ -30,7 +31,8 @@ class WanVideoPipeline(BasePipeline): self.image_encoder: WanImageEncoder = None self.dit: WanModel = None self.vae: WanVideoVAE = None - self.model_names = ['text_encoder', 'dit', 'vae'] + self.controlnet: WanControlNetModel = None + self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet'] self.height_division_factor = 16 self.width_division_factor = 16 @@ -189,6 +191,11 @@ class WanVideoPipeline(BasePipeline): def decode_video(self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): frames = self.vae.decode(latents, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return frames + + + def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)): + controlnet_conditioning = 
self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device) + return {"controlnet_conditioning": controlnet_conditioning} @torch.no_grad() @@ -212,6 +219,7 @@ class WanVideoPipeline(BasePipeline): tile_stride=(15, 26), tea_cache_l1_thresh=None, tea_cache_model_id="", + controlnet_frames=None, progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -252,6 +260,15 @@ class WanVideoPipeline(BasePipeline): else: image_emb = {} + # ControlNet + if self.controlnet is not None and controlnet_frames is not None: + self.load_models_to_device(['vae', 'controlnet']) + controlnet_frames = self.preprocess_images(controlnet_frames) + controlnet_frames = torch.stack(controlnet_frames, dim=2).to(dtype=self.torch_dtype, device=self.device) + controlnet_kwargs = self.prepare_controlnet(controlnet_frames) + else: + controlnet_kwargs = {} + # Extra input extra_input = self.prepare_extra_input(latents) @@ -260,14 +277,24 @@ class WanVideoPipeline(BasePipeline): tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} # Denoise - self.load_models_to_device(["dit"]) + self.load_models_to_device(["dit", "controlnet"]) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference - noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi) + noise_pred_posi = model_fn_wan_video( + self.dit, controlnet=self.controlnet, + x=latents, timestep=timestep, + **prompt_emb_posi, **image_emb, **extra_input, + **tea_cache_posi, **controlnet_kwargs + ) if cfg_scale != 1.0: - noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega) + 
noise_pred_nega = model_fn_wan_video( + self.dit, controlnet=self.controlnet, + x=latents, timestep=timestep, + **prompt_emb_nega, **image_emb, **extra_input, + **tea_cache_nega, **controlnet_kwargs + ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi @@ -340,14 +367,29 @@ class TeaCache: def model_fn_wan_video( dit: WanModel, - x: torch.Tensor, - timestep: torch.Tensor, - context: torch.Tensor, + controlnet: WanControlNetModel = None, + x: torch.Tensor = None, + timestep: torch.Tensor = None, + context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, tea_cache: TeaCache = None, + controlnet_conditioning: Optional[torch.Tensor] = None, + use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, **kwargs, ): + # ControlNet + if controlnet is not None and controlnet_conditioning is not None: + controlnet_res_stack = controlnet( + x, timestep=timestep, context=context, clip_feature=clip_feature, y=y, + controlnet_conditioning=controlnet_conditioning, + use_gradient_checkpointing=use_gradient_checkpointing, + use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, + ) + else: + controlnet_res_stack = None + t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep)) t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim)) context = dit.text_embedding(context) @@ -370,13 +412,35 @@ def model_fn_wan_video( tea_cache_update = tea_cache.check(dit, x, t_mod) else: tea_cache_update = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward if tea_cache_update: x = tea_cache.update(x) else: # blocks - for block in dit.blocks: - x = block(x, context, t_mod, freqs) + for block_id, block in enumerate(dit.blocks): + if dit.training and use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with 
torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + else: + x = block(x, context, t_mod, freqs) + if controlnet_res_stack is not None: + x = x + controlnet_res_stack[block_id] if tea_cache is not None: tea_cache.store(x) diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py index e7b1495..ed58890 100644 --- a/examples/wanvideo/train_wan_t2v.py +++ b/examples/wanvideo/train_wan_t2v.py @@ -14,7 +14,10 @@ import numpy as np class TextVideoDataset(torch.utils.data.Dataset): def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None): metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + if os.path.exists(os.path.join(base_path, "train")): + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + else: + self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]] self.text = metadata["text"].to_list() self.max_num_frames = max_num_frames @@ -359,6 +362,12 @@ def parse_args(): default="./", help="Path to save the model.", ) + parser.add_argument( + "--metadata_path", + type=str, + default=None, + help="Path to metadata.csv.", + ) parser.add_argument( "--redirected_tensor_path", type=str, @@ -548,7 +557,7 @@ def parse_args(): def data_process(args): dataset = TextVideoDataset( args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), + os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path, max_num_frames=args.num_frames, frame_interval=1, num_frames=args.num_frames, @@ -584,7 +593,7 @@ def 
data_process(args): def train(args): dataset = TensorDataset( args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), + os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path, steps_per_epoch=args.steps_per_epoch, redirected_tensor_path=args.redirected_tensor_path, ) diff --git a/examples/wanvideo/train_wan_t2v_controlnet.py b/examples/wanvideo/train_wan_t2v_controlnet.py new file mode 100644 index 0000000..b8eb903 --- /dev/null +++ b/examples/wanvideo/train_wan_t2v_controlnet.py @@ -0,0 +1,626 @@ +import torch, os, imageio, argparse +from torchvision.transforms import v2 +from einops import rearrange +import lightning as pl +import pandas as pd +from diffsynth import WanVideoPipeline, ModelManager, load_state_dict +from peft import LoraConfig, inject_adapter_in_model +import torchvision +from PIL import Image +import numpy as np +from diffsynth.models.wan_video_controlnet import WanControlNetModel +from diffsynth.pipelines.wan_video import model_fn_wan_video + + + +class TextVideoDataset(torch.utils.data.Dataset): + def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None): + metadata = pd.read_csv(metadata_path) + self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] + self.controlnet_path = [os.path.join(base_path, file_name) for file_name in metadata["controlnet_file_name"]] + self.text = metadata["text"].to_list() + + self.max_num_frames = max_num_frames + self.frame_interval = frame_interval + self.num_frames = num_frames + self.height = height + self.width = width + self.is_i2v = is_i2v + self.target_fps = target_fps + + self.frame_process = v2.Compose([ + v2.CenterCrop(size=(height, width)), + v2.Resize(size=(height, width), antialias=True), + v2.ToTensor(), + v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + + + def crop_and_resize(self, image): + 
width, height = image.size + scale = max(self.width / width, self.height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + return image + + + def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): + reader = imageio.get_reader(file_path) + if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: + reader.close() + return None + + frames = [] + first_frame = None + for frame_id in range(num_frames): + frame = reader.get_data(start_frame_id + frame_id * interval) + frame = Image.fromarray(frame) + frame = self.crop_and_resize(frame) + if first_frame is None: + first_frame = np.array(frame) + frame = frame_process(frame) + frames.append(frame) + reader.close() + + frames = torch.stack(frames, dim=0) + frames = rearrange(frames, "T C H W -> C T H W") + + if self.is_i2v: + return frames, first_frame + else: + return frames + + + def load_video(self, file_path): + start_frame_id = 0 + if self.target_fps is None: + frame_interval = self.frame_interval + else: + reader = imageio.get_reader(file_path) + fps = reader.get_meta_data()["fps"] + reader.close() + frame_interval = max(round(fps / self.target_fps), 1) + frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process) + return frames + + + def is_image(self, file_path): + file_ext_name = file_path.split(".")[-1] + if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: + return True + return False + + + def load_image(self, file_path): + frame = Image.open(file_path).convert("RGB") + frame = self.crop_and_resize(frame) + frame = self.frame_process(frame) + frame = rearrange(frame, "C H W -> C 1 H W") + return frame + + + def __getitem__(self, data_id): + text = self.text[data_id] 
+        path = self.path[data_id]
+        controlnet_path = self.controlnet_path[data_id]
+        try:
+            if self.is_image(path):
+                if self.is_i2v:
+                    raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
+                video = self.load_image(path)
+            else:
+                video = self.load_video(path)
+                controlnet_frames = self.load_video(controlnet_path)
+            if self.is_i2v:
+                video, first_frame = video
+                data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
+            else:
+                data = {"text": text, "video": video, "path": path, "controlnet_frames": controlnet_frames}
+        except Exception:
+            data = None
+        return data
+
+
+    def __len__(self):
+        return len(self.path)
+
+
+
+class LightningModelForDataProcess(pl.LightningModule):
+    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
+        super().__init__()
+        model_path = [text_encoder_path, vae_path]
+        if image_encoder_path is not None:
+            model_path.append(image_encoder_path)
+        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+        model_manager.load_models(model_path)
+        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
+
+        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+        self.redirected_tensor_path = redirected_tensor_path
+
+    def test_step(self, batch, batch_idx):
+        data = batch[0]
+        if data is None or data["video"] is None:
+            return
+        text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
+        controlnet_frames = data["controlnet_frames"].unsqueeze(0)
+
+        self.pipe.device = self.device
+        if video is not None:
+            # prompt
+            prompt_emb = self.pipe.encode_prompt(text)
+            # video
+            video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
+            latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
+            # ControlNet video
+            controlnet_frames = controlnet_frames.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
+            controlnet_kwargs = self.pipe.prepare_controlnet(controlnet_frames, **self.tiler_kwargs)
+            controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"][0]
+            # image ("first_frame" lives on the sample dict, not the batch list)
+            if "first_frame" in data:
+                first_frame = Image.fromarray(data["first_frame"])
+                _, _, num_frames, height, width = video.shape
+                image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
+            else:
+                image_emb = {}
+            data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb, "controlnet_kwargs": controlnet_kwargs}
+            if self.redirected_tensor_path is not None:
+                path = path.replace("/", "_").replace("\\", "_")
+                path = os.path.join(self.redirected_tensor_path, path)
+                torch.save(data, path + ".tensors.pth")
+
+
+
+class TensorDataset(torch.utils.data.Dataset):
+    def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
+        if os.path.exists(metadata_path):
+            metadata = pd.read_csv(metadata_path)
+            self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
+            print(len(self.path), "videos in metadata.")
+            if redirected_tensor_path is None:
+                self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
+            else:
+                cached_path = []
+                for path in self.path:
+                    path = path.replace("/", "_").replace("\\", "_")
+                    path = os.path.join(redirected_tensor_path, path)
+                    if os.path.exists(path + ".tensors.pth"):
+                        cached_path.append(path + ".tensors.pth")
+                self.path = cached_path
+        else:
+            print("Cannot find metadata.csv. Trying to search for tensor files.")
+            self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
+            print(len(self.path), "tensors cached in metadata.")
+        assert len(self.path) > 0
+
+        self.steps_per_epoch = steps_per_epoch
+        self.redirected_tensor_path = redirected_tensor_path
+
+
+    def __getitem__(self, index):
+        while True:
+            try:
+                data_id = torch.randint(0, len(self.path), (1,))[0]
+                data_id = (data_id + index) % len(self.path) # For fixed seed.
+                path = self.path[data_id]
+                data = torch.load(path, weights_only=True, map_location="cpu")
+                return data
+            except Exception:  # skip unreadable tensors; NOTE(review): loops forever if every path fails
+                continue
+
+
+    def __len__(self):
+        return self.steps_per_epoch
+
+
+
+class LightningModelForTrain(pl.LightningModule):
+    def __init__(
+        self,
+        dit_path,
+        learning_rate=1e-5,
+        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
+        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
+        pretrained_lora_path=None
+    ):
+        super().__init__()
+        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+        if os.path.isfile(dit_path):
+            model_manager.load_models([dit_path])
+        else:
+            dit_path = dit_path.split(",")
+            model_manager.load_models([dit_path])
+
+        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
+        self.pipe.scheduler.set_timesteps(1000, training=True)
+        self.freeze_parameters()
+
+        state_dict = load_state_dict(dit_path, torch_dtype=torch.bfloat16)  # NOTE(review): dit_path may be a list here after split(",") — confirm load_state_dict accepts it
+        state_dict, config = WanControlNetModel.state_dict_converter().from_base_model(state_dict)
+        self.pipe.controlnet = WanControlNetModel(**config).to(torch.bfloat16)
+        self.pipe.controlnet.load_state_dict(state_dict)
+        self.pipe.controlnet.train()
+        self.pipe.controlnet.requires_grad_(True)
+
+        self.learning_rate = learning_rate
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+
+
+    def 
freeze_parameters(self): + # Freeze parameters + self.pipe.requires_grad_(False) + self.pipe.eval() + self.pipe.denoising_model().train() + + + def training_step(self, batch, batch_idx): + # Data + latents = batch["latents"].to(self.device) + controlnet_kwargs = batch["controlnet_kwargs"] + controlnet_kwargs["controlnet_conditioning"] = controlnet_kwargs["controlnet_conditioning"].to(self.device) + prompt_emb = batch["prompt_emb"] + prompt_emb["context"] = prompt_emb["context"][0].to(self.device) + image_emb = batch["image_emb"] + if "clip_feature" in image_emb: + image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) + if "y" in image_emb: + image_emb["y"] = image_emb["y"][0].to(self.device) + + # Loss + self.pipe.device = self.device + noise = torch.randn_like(latents) + timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) + timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) + extra_input = self.pipe.prepare_extra_input(latents) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + training_target = self.pipe.scheduler.training_target(latents, noise, timestep) + + # Compute loss + noise_pred = model_fn_wan_video( + dit=self.pipe.dit, controlnet=self.pipe.controlnet, + x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **controlnet_kwargs, + use_gradient_checkpointing=self.use_gradient_checkpointing, + use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload + ) + loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) + loss = loss * self.pipe.scheduler.training_weight(timestep) + + # Record log + self.log("train_loss", loss, prog_bar=True) + return loss + + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.controlnet.parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer 
+ + + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.controlnet.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + state_dict = self.pipe.controlnet.state_dict() + lora_state_dict = {} + for name, param in state_dict.items(): + if name in trainable_param_names: + lora_state_dict[name] = param + checkpoint.update(lora_state_dict) + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--task", + type=str, + default="data_process", + required=True, + choices=["data_process", "train"], + help="Task. `data_process` or `train`.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default=None, + required=True, + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--metadata_path", + type=str, + default=None, + help="Path to metadata.csv.", + ) + parser.add_argument( + "--redirected_tensor_path", + type=str, + default=None, + help="Path to save cached tensors.", + ) + parser.add_argument( + "--text_encoder_path", + type=str, + default=None, + help="Path of text encoder.", + ) + parser.add_argument( + "--image_encoder_path", + type=str, + default=None, + help="Path of image encoder.", + ) + parser.add_argument( + "--vae_path", + type=str, + default=None, + help="Path of VAE.", + ) + parser.add_argument( + "--dit_path", + type=str, + default=None, + help="Path of DiT.", + ) + parser.add_argument( + "--tiled", + default=False, + action="store_true", + help="Whether enable tile encode in VAE. 
This option can reduce VRAM required.", + ) + parser.add_argument( + "--tile_size_height", + type=int, + default=34, + help="Tile size (height) in VAE.", + ) + parser.add_argument( + "--tile_size_width", + type=int, + default=34, + help="Tile size (width) in VAE.", + ) + parser.add_argument( + "--tile_stride_height", + type=int, + default=18, + help="Tile stride (height) in VAE.", + ) + parser.add_argument( + "--tile_stride_width", + type=int, + default=16, + help="Tile stride (width) in VAE.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=500, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=81, + help="Number of frames.", + ) + parser.add_argument( + "--target_fps", + type=int, + default=None, + help="Expected FPS for sampling frames.", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=832, + help="Image width.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=1, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Learning rate.", + ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + default=1, + help="The number of batches in gradient accumulation.", + ) + parser.add_argument( + "--max_epochs", + type=int, + default=1, + help="Number of epochs.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default="q,k,v,o,ffn.0,ffn.2", + help="Layers with LoRA modules.", + ) + parser.add_argument( + "--init_lora_weights", + type=str, + default="kaiming", + choices=["gaussian", "kaiming"], + help="The initializing method of LoRA weight.", + ) + parser.add_argument( + "--training_strategy", + type=str, + default="auto", + choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], + help="Training strategy", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="The dimension of the LoRA update matrices.", + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=4.0, + help="The weight of the LoRA update matrices.", + ) + parser.add_argument( + "--use_gradient_checkpointing", + default=False, + action="store_true", + help="Whether to use gradient checkpointing.", + ) + parser.add_argument( + "--use_gradient_checkpointing_offload", + default=False, + action="store_true", + help="Whether to use gradient checkpointing offload.", + ) + parser.add_argument( + "--train_architecture", + type=str, + default="lora", + choices=["lora", "full"], + help="Model structure to train. LoRA training or full training.", + ) + parser.add_argument( + "--pretrained_lora_path", + type=str, + default=None, + help="Pretrained LoRA path. 
Required if the training is resumed.", + ) + parser.add_argument( + "--use_swanlab", + default=False, + action="store_true", + help="Whether to use SwanLab logger.", + ) + parser.add_argument( + "--swanlab_mode", + default=None, + help="SwanLab mode (cloud or local).", + ) + args = parser.parse_args() + return args + + +def data_process(args): + dataset = TextVideoDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path, + max_num_frames=args.num_frames, + frame_interval=1, + num_frames=args.num_frames, + height=args.height, + width=args.width, + is_i2v=args.image_encoder_path is not None, + target_fps=args.target_fps, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=1, + num_workers=args.dataloader_num_workers, + collate_fn=lambda x: x, + ) + model = LightningModelForDataProcess( + text_encoder_path=args.text_encoder_path, + image_encoder_path=args.image_encoder_path, + vae_path=args.vae_path, + tiled=args.tiled, + tile_size=(args.tile_size_height, args.tile_size_width), + tile_stride=(args.tile_stride_height, args.tile_stride_width), + redirected_tensor_path=args.redirected_tensor_path, + ) + trainer = pl.Trainer( + accelerator="gpu", + devices="auto", + default_root_dir=args.output_path, + ) + trainer.test(model, dataloader) + + +def train(args): + dataset = TensorDataset( + args.dataset_path, + os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path, + steps_per_epoch=args.steps_per_epoch, + redirected_tensor_path=args.redirected_tensor_path, + ) + dataloader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=1, + num_workers=args.dataloader_num_workers + ) + model = LightningModelForTrain( + dit_path=args.dit_path, + learning_rate=args.learning_rate, + train_architecture=args.train_architecture, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + 
lora_target_modules=args.lora_target_modules, + init_lora_weights=args.init_lora_weights, + use_gradient_checkpointing=args.use_gradient_checkpointing, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + pretrained_lora_path=args.pretrained_lora_path, + ) + if args.use_swanlab: + from swanlab.integration.pytorch_lightning import SwanLabLogger + swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} + swanlab_config.update(vars(args)) + swanlab_logger = SwanLabLogger( + project="wan", + name="wan", + config=swanlab_config, + mode=args.swanlab_mode, + logdir=os.path.join(args.output_path, "swanlog"), + ) + logger = [swanlab_logger] + else: + logger = None + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision="bf16", + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], + logger=logger, + ) + trainer.fit(model, dataloader) + + +if __name__ == '__main__': + args = parse_args() + if args.task == "data_process": + data_process(args) + elif args.task == "train": + train(args)