diff --git a/diffsynth/models/wan_video_motion_controller.py b/diffsynth/models/wan_video_motion_controller.py
new file mode 100644
index 0000000..2743719
--- /dev/null
+++ b/diffsynth/models/wan_video_motion_controller.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+from .wan_video_dit import sinusoidal_embedding_1d
+
+
+
+class WanMotionControllerModel(torch.nn.Module):
+    def __init__(self, freq_dim=256, dim=1536):
+        super().__init__()
+        self.freq_dim = freq_dim
+        self.linear = nn.Sequential(
+            nn.Linear(freq_dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim),
+            nn.SiLU(),
+            nn.Linear(dim, dim * 6),
+        )
+
+    def forward(self, motion_bucket_id):
+        emb = sinusoidal_embedding_1d(self.freq_dim, motion_bucket_id * 10)
+        emb = self.linear(emb)
+        return emb
+
+    def init(self):
+        # Zero the last layer so the controller starts as a no-op and training
+        # begins from the behavior of the unmodified base model.
+        state_dict = self.linear[-1].state_dict()
+        state_dict = {i: state_dict[i] * 0 for i in state_dict}
+        self.linear[-1].load_state_dict(state_dict)
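The controller is a small MLP that maps a scalar motion bucket id to a modulation offset of shape (batch, 6 * dim). A minimal sketch of the shapes involved (standalone, not part of the diff; it assumes only the two imports shown):

    import torch
    from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel

    controller = WanMotionControllerModel(freq_dim=256, dim=1536)
    controller.init()  # zero the last layer: the controller starts as a no-op

    emb = controller(torch.tensor([50.0]))
    print(emb.shape)        # torch.Size([1, 9216]); later unflattened to (1, 6, 1536)
    print(emb.abs().max())  # tensor(0.) right after init()

In the pipeline change below, this output is added to t_mod, the DiT's time-modulation tensor, so a zero-initialized controller leaves generation unchanged until it has been trained.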
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index 2c6f640..7f175f4 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -18,6 +18,7 @@ from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
 from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
 from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
 from ..models.wan_video_controlnet import WanControlNetModel
+from ..models.wan_video_motion_controller import WanMotionControllerModel
 
 
@@ -32,7 +33,8 @@ class WanVideoPipeline(BasePipeline):
         self.dit: WanModel = None
         self.vae: WanVideoVAE = None
         self.controlnet: WanControlNetModel = None
-        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet']
+        self.motion_controller: WanMotionControllerModel = None
+        self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'controlnet', 'motion_controller']
         self.height_division_factor = 16
         self.width_division_factor = 16
 
@@ -196,6 +198,11 @@
     def prepare_controlnet(self, controlnet_frames, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
         controlnet_conditioning = self.encode_video(controlnet_frames, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
         return {"controlnet_conditioning": controlnet_conditioning}
+
+
+    def prepare_motion_bucket_id(self, motion_bucket_id):
+        motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
+        return {"motion_bucket_id": motion_bucket_id}
 
 
     @torch.no_grad()
@@ -214,6 +221,7 @@
         cfg_scale=5.0,
         num_inference_steps=50,
         sigma_shift=5.0,
+        motion_bucket_id=None,
         tiled=True,
         tile_size=(30, 52),
         tile_stride=(15, 26),
@@ -269,6 +277,12 @@
         else:
             controlnet_kwargs = {}
 
+        # Motion Controller
+        if self.motion_controller is not None and motion_bucket_id is not None:
+            motion_kwargs = self.prepare_motion_bucket_id(motion_bucket_id)
+        else:
+            motion_kwargs = {}
+
         # Extra input
         extra_input = self.prepare_extra_input(latents)
 
@@ -277,23 +291,23 @@
         tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
 
         # Denoise
-        self.load_models_to_device(["dit", "controlnet"])
+        self.load_models_to_device(["dit", "controlnet", "motion_controller"])
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
 
             # Inference
             noise_pred_posi = model_fn_wan_video(
-                self.dit, controlnet=self.controlnet,
+                self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
                 x=latents, timestep=timestep,
                 **prompt_emb_posi, **image_emb, **extra_input,
-                **tea_cache_posi, **controlnet_kwargs
+                **tea_cache_posi, **controlnet_kwargs, **motion_kwargs,
             )
             if cfg_scale != 1.0:
                 noise_pred_nega = model_fn_wan_video(
-                    self.dit, controlnet=self.controlnet,
+                    self.dit, controlnet=self.controlnet, motion_controller=self.motion_controller,
                     x=latents, timestep=timestep,
                     **prompt_emb_nega, **image_emb, **extra_input,
-                    **tea_cache_nega, **controlnet_kwargs
+                    **tea_cache_nega, **controlnet_kwargs, **motion_kwargs,
                 )
                 noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
             else:
@@ -368,6 +382,7 @@ class TeaCache:
 def model_fn_wan_video(
     dit: WanModel,
     controlnet: WanControlNetModel = None,
+    motion_controller: WanMotionControllerModel = None,
     x: torch.Tensor = None,
     timestep: torch.Tensor = None,
     context: torch.Tensor = None,
@@ -375,6 +390,7 @@
     y: Optional[torch.Tensor] = None,
     tea_cache: TeaCache = None,
     controlnet_conditioning: Optional[torch.Tensor] = None,
+    motion_bucket_id: Optional[torch.Tensor] = None,
     use_gradient_checkpointing: bool = False,
     use_gradient_checkpointing_offload: bool = False,
     **kwargs,
@@ -392,6 +408,8 @@
     t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
     t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
+    if motion_bucket_id is not None and motion_controller is not None:
+        t_mod = t_mod + motion_controller(motion_bucket_id).unflatten(1, (6, dit.dim))
     context = dit.text_embedding(context)
 
     if dit.has_image_input:
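At inference time the new argument threads from `__call__` through `model_fn_wan_video`. A hedged usage sketch follows; the model paths are placeholders, and `from_model_manager` / `save_video` are the existing DiffSynth APIs. Note that `motion_bucket_id` is silently ignored unless `pipe.motion_controller` is set, e.g. with weights produced by the training script below:

    import torch
    from diffsynth import ModelManager, WanVideoPipeline, save_video

    model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
    model_manager.load_models([
        "models/text_encoder.pth",   # placeholder paths
        "models/dit.safetensors",
        "models/vae.pth",
    ])
    pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
    # pipe.motion_controller = ...  (see the checkpoint-loading sketch at the end)

    video = pipe(
        prompt="a cat running on the grass",
        num_inference_steps=50,
        cfg_scale=5.0,
        motion_bucket_id=80,  # higher bucket id requests stronger motion
    )
    save_video(video, "video.mp4", fps=15, quality=5)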
diff --git a/examples/wanvideo/train_wan_t2v_motion.py b/examples/wanvideo/train_wan_t2v_motion.py
new file mode 100644
index 0000000..b2f90bf
--- /dev/null
+++ b/examples/wanvideo/train_wan_t2v_motion.py
@@ -0,0 +1,691 @@
+import torch, os, imageio, argparse
+from torchvision.transforms import v2
+from einops import rearrange
+import lightning as pl
+import pandas as pd
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
+from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel
+from diffsynth.pipelines.wan_video import model_fn_wan_video
+from peft import LoraConfig, inject_adapter_in_model
+import torchvision
+from PIL import Image
+import numpy as np
+from tqdm import tqdm
+
+
+
+class TextVideoDataset(torch.utils.data.Dataset):
+    def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False, target_fps=None):
+        metadata = pd.read_csv(metadata_path)
+        self.path = [os.path.join(base_path, file_name) for file_name in metadata["file_name"]]
+        self.text = metadata["text"].to_list()
+
+        self.max_num_frames = max_num_frames
+        self.frame_interval = frame_interval
+        self.num_frames = num_frames
+        self.height = height
+        self.width = width
+        self.is_i2v = is_i2v
+        self.target_fps = target_fps
+
+        self.frame_process = v2.Compose([
+            v2.CenterCrop(size=(height, width)),
+            v2.Resize(size=(height, width), antialias=True),
+            v2.ToTensor(),
+            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+
+
+    def crop_and_resize(self, image):
+        width, height = image.size
+        scale = max(self.width / width, self.height / height)
+        image = torchvision.transforms.functional.resize(
+            image,
+            (round(height*scale), round(width*scale)),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+        )
+        return image
+
+
+    def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
+        reader = imageio.get_reader(file_path)
+        if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
+            reader.close()
+            return None
+
+        frames = []
+        first_frame = None
+        for frame_id in range(num_frames):
+            frame = reader.get_data(start_frame_id + frame_id * interval)
+            frame = Image.fromarray(frame)
+            frame = self.crop_and_resize(frame)
+            if first_frame is None:
+                first_frame = np.array(frame)
+            frame = frame_process(frame)
+            frames.append(frame)
+        reader.close()
+
+        frames = torch.stack(frames, dim=0)
+        frames = rearrange(frames, "T C H W -> C T H W")
+
+        if self.is_i2v:
+            return frames, first_frame
+        else:
+            return frames
+
+
+    def load_video(self, file_path):
+        start_frame_id = 0
+        if self.target_fps is None:
+            frame_interval = self.frame_interval
+        else:
+            reader = imageio.get_reader(file_path)
+            fps = reader.get_meta_data()["fps"]
+            reader.close()
+            frame_interval = max(round(fps / self.target_fps), 1)
+        frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, frame_interval, self.num_frames, self.frame_process)
+        return frames
+
+
+    def is_image(self, file_path):
+        file_ext_name = file_path.split(".")[-1]
+        if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]:
+            return True
+        return False
+
+
+    def load_image(self, file_path):
+        frame = Image.open(file_path).convert("RGB")
+        frame = self.crop_and_resize(frame)
+        first_frame = frame
+        frame = self.frame_process(frame)
+        frame = rearrange(frame, "C H W -> C 1 H W")
+        return frame
+
+
+    def __getitem__(self, data_id):
+        text = self.text[data_id]
+        path = self.path[data_id]
+        try:
+            if self.is_image(path):
+                if self.is_i2v:
+                    raise ValueError(f"{path} is not a video. I2V model doesn't support image-to-image training.")
+                video = self.load_image(path)
+            else:
+                video = self.load_video(path)
+            if self.is_i2v:
+                video, first_frame = video
+                data = {"text": text, "video": video, "path": path, "first_frame": first_frame}
+            else:
+                data = {"text": text, "video": video, "path": path}
+        except:
+            data = None
+        return data
+
+
+    def __len__(self):
+        return len(self.path)
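+
+# Note on frame sampling: when target_fps is set, load_video derives
+# frame_interval = max(round(source_fps / target_fps), 1), so a 30 fps clip
+# sampled with target_fps=15 keeps every 2nd frame. Presumably this makes the
+# motion statistic comparable across videos with different source frame rates.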
+
+
+class LightningModelForDataProcess(pl.LightningModule):
+    def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16), redirected_tensor_path=None):
+        super().__init__()
+        model_path = [text_encoder_path, vae_path]
+        if image_encoder_path is not None:
+            model_path.append(image_encoder_path)
+        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+        model_manager.load_models(model_path)
+        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
+
+        self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
+        self.redirected_tensor_path = redirected_tensor_path
+
+    def test_step(self, batch, batch_idx):
+        data = batch[0]
+        if data is None or data["video"] is None:
+            return
+        text, video, path = data["text"], data["video"].unsqueeze(0), data["path"]
+
+        self.pipe.device = self.device
+        if video is not None:
+            # prompt
+            prompt_emb = self.pipe.encode_prompt(text)
+            # video
+            video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
+            latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0]
+            # image
+            if "first_frame" in data:
+                first_frame = Image.fromarray(data["first_frame"])
+                _, _, num_frames, height, width = video.shape
+                image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
+            else:
+                image_emb = {}
+            data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}
+            if self.redirected_tensor_path is not None:
+                path = path.replace("/", "_").replace("\\", "_")
+                path = os.path.join(self.redirected_tensor_path, path)
+            torch.save(data, path + ".tensors.pth")
+
+
+
+class TensorDataset(torch.utils.data.Dataset):
+    def __init__(self, base_path, metadata_path=None, steps_per_epoch=1000, redirected_tensor_path=None):
+        if os.path.exists(metadata_path):
+            metadata = pd.read_csv(metadata_path)
+            self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]]
+            print(len(self.path), "videos in metadata.")
+            if redirected_tensor_path is None:
+                self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")]
+            else:
+                cached_path = []
+                for path in self.path:
+                    path = path.replace("/", "_").replace("\\", "_")
+                    path = os.path.join(redirected_tensor_path, path)
+                    if os.path.exists(path + ".tensors.pth"):
+                        cached_path.append(path + ".tensors.pth")
+                self.path = cached_path
+        else:
+            print("Cannot find metadata.csv. Trying to search for tensor files.")
+            self.path = [os.path.join(base_path, i) for i in os.listdir(base_path) if i.endswith(".tensors.pth")]
+        print(len(self.path), "cached tensor files found.")
+        assert len(self.path) > 0
+
+        self.steps_per_epoch = steps_per_epoch
+        self.redirected_tensor_path = redirected_tensor_path
+
+
+    def __getitem__(self, index):
+        while True:
+            try:
+                data_id = torch.randint(0, len(self.path), (1,))[0]
+                data_id = (data_id + index) % len(self.path) # For fixed seed.
+                path = self.path[data_id]
+                data = torch.load(path, weights_only=True, map_location="cpu")
+                return data
+            except:
+                continue
+
+
+    def __len__(self):
+        return self.steps_per_epoch
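+
+# Workflow note: the `data_process` task runs LightningModelForDataProcess once
+# over the raw videos and caches VAE latents plus text/image embeddings as
+# .tensors.pth files (next to the videos, or under --redirected_tensor_path).
+# The `train` task then reads only these cached tensors via TensorDataset, so
+# the text encoder and VAE are never loaded during training.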
+
+
+class LightningModelForTrain(pl.LightningModule):
+    def __init__(
+        self,
+        dit_path,
+        learning_rate=1e-5,
+        lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming",
+        use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False,
+        pretrained_lora_path=None
+    ):
+        super().__init__()
+        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
+        if os.path.isfile(dit_path):
+            model_manager.load_models([dit_path])
+        else:
+            dit_path = dit_path.split(",")
+            model_manager.load_models([dit_path])
+
+        self.pipe = WanVideoPipeline.from_model_manager(model_manager)
+        self.pipe.scheduler.set_timesteps(1000, training=True)
+        self.freeze_parameters()
+
+        self.pipe.motion_controller = WanMotionControllerModel().to(torch.bfloat16)
+        self.pipe.motion_controller.init()
+        self.pipe.motion_controller.requires_grad_(True)
+        self.pipe.motion_controller.train()
+        self.motion_bucket_manager = MotionBucketManager()
+
+        self.learning_rate = learning_rate
+        self.use_gradient_checkpointing = use_gradient_checkpointing
+        self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
+
+
+    def freeze_parameters(self):
+        # Freeze parameters
+        self.pipe.requires_grad_(False)
+        self.pipe.eval()
+        self.pipe.dit.train()
+
+
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
+        # Add LoRA to the DiT
+        self.lora_alpha = lora_alpha
+        if init_lora_weights == "kaiming":
+            init_lora_weights = True
+
+        lora_config = LoraConfig(
+            r=lora_rank,
+            lora_alpha=lora_alpha,
+            init_lora_weights=init_lora_weights,
+            target_modules=lora_target_modules.split(","),
+        )
+        model = inject_adapter_in_model(lora_config, model)
+        for param in model.parameters():
+            # Upcast LoRA parameters to fp32
+            if param.requires_grad:
+                param.data = param.to(torch.float32)
+
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
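+
+    # Note: add_lora_to_model appears to be kept for parity with the other
+    # training scripts but is not called anywhere in this file; the whole
+    # pipeline is frozen and only the motion controller is optimized
+    # (see configure_optimizers below).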
+
+
+    def training_step(self, batch, batch_idx):
+        # Data
+        latents = batch["latents"].to(self.device)
+        prompt_emb = batch["prompt_emb"]
+        prompt_emb["context"] = prompt_emb["context"][0].to(self.device)
+        image_emb = batch["image_emb"]
+        if "clip_feature" in image_emb:
+            image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device)
+        if "y" in image_emb:
+            image_emb["y"] = image_emb["y"][0].to(self.device)
+
+        # Loss
+        self.pipe.device = self.device
+        noise = torch.randn_like(latents)
+        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
+        timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device)
+        extra_input = self.pipe.prepare_extra_input(latents)
+        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
+        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
+        motion_bucket_id = self.motion_bucket_manager(latents)
+        motion_bucket_kwargs = self.pipe.prepare_motion_bucket_id(motion_bucket_id)
+
+        # Compute loss
+        noise_pred = model_fn_wan_video(
+            dit=self.pipe.dit, motion_controller=self.pipe.motion_controller,
+            x=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, **motion_bucket_kwargs,
+            use_gradient_checkpointing=self.use_gradient_checkpointing,
+            use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload
+        )
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+        loss = loss * self.pipe.scheduler.training_weight(timestep)
+
+        # Record log
+        self.log("train_loss", loss, prog_bar=True)
+        return loss
+
+
+    def configure_optimizers(self):
+        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.motion_controller.parameters())
+        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
+        return optimizer
+
+
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint.clear()
+        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.motion_controller.named_parameters()))
+        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
+        state_dict = self.pipe.motion_controller.state_dict()
+        trainable_state_dict = {}
+        for name, param in state_dict.items():
+            if name in trainable_param_names:
+                trainable_state_dict[name] = param
+        checkpoint.update(trainable_state_dict)
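+
+# Checkpoint note: on_save_checkpoint empties Lightning's checkpoint dict and
+# stores only the motion controller's parameters, so each saved .ckpt file is a
+# bare state dict that WanMotionControllerModel().load_state_dict(...) accepts.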
+
+
+class MotionBucketManager:
+    def __init__(self):
+        # 100 ascending thresholds of the motion statistic, apparently produced
+        # by get_motion_thresholds below on the training set.
+        self.thresholds = [
+            0.093750000, 0.094726562, 0.100585938, 0.100585938, 0.108886719, 0.109375000, 0.118652344, 0.127929688, 0.127929688, 0.130859375,
+            0.133789062, 0.137695312, 0.138671875, 0.138671875, 0.139648438, 0.143554688, 0.143554688, 0.147460938, 0.149414062, 0.149414062,
+            0.152343750, 0.153320312, 0.154296875, 0.154296875, 0.157226562, 0.163085938, 0.163085938, 0.164062500, 0.165039062, 0.166992188,
+            0.173828125, 0.179687500, 0.180664062, 0.184570312, 0.187500000, 0.188476562, 0.188476562, 0.189453125, 0.189453125, 0.202148438,
+            0.206054688, 0.210937500, 0.210937500, 0.211914062, 0.214843750, 0.214843750, 0.216796875, 0.216796875, 0.216796875, 0.218750000,
+            0.218750000, 0.221679688, 0.222656250, 0.227539062, 0.229492188, 0.230468750, 0.236328125, 0.243164062, 0.243164062, 0.245117188,
+            0.253906250, 0.253906250, 0.255859375, 0.259765625, 0.275390625, 0.275390625, 0.277343750, 0.279296875, 0.279296875, 0.279296875,
+            0.292968750, 0.292968750, 0.302734375, 0.306640625, 0.312500000, 0.312500000, 0.326171875, 0.330078125, 0.332031250, 0.332031250,
+            0.337890625, 0.343750000, 0.343750000, 0.351562500, 0.355468750, 0.357421875, 0.361328125, 0.367187500, 0.382812500, 0.388671875,
+            0.392578125, 0.392578125, 0.392578125, 0.404296875, 0.404296875, 0.425781250, 0.433593750, 0.507812500, 0.519531250, 0.539062500,
+        ]
+
+    def get_motion_score(self, frames):
+        score = frames[:, :, 1:, :, :].std(dim=2).mean().tolist()
+        return score
+
+    def get_bucket_id(self, motion_score):
+        for bucket_id in range(len(self.thresholds) - 1):
+            if self.thresholds[bucket_id + 1] > motion_score:
+                return bucket_id
+        # Clamp to the last valid bucket (0-99) instead of overflowing to 100.
+        return len(self.thresholds) - 1
+
+    def __call__(self, frames):
+        score = self.get_motion_score(frames)
+        bucket_id = self.get_bucket_id(score)
+        return bucket_id
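+
+# Example (illustrative): for cached latents shaped (B, C, T, H, W), the manager
+# maps the mean temporal standard deviation to a bucket id in [0, 99]:
+#     manager = MotionBucketManager()
+#     manager(torch.randn(1, 16, 21, 60, 104))  # ~unit std -> highest bucket, 99
+# Training recomputes the bucket id per sample from its own latents.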
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
+    parser.add_argument(
+        "--task",
+        type=str,
+        default="data_process",
+        required=True,
+        choices=["data_process", "train"],
+        help="Task: `data_process` or `train`.",
+    )
+    parser.add_argument(
+        "--dataset_path",
+        type=str,
+        default=None,
+        required=True,
+        help="The path of the Dataset.",
+    )
+    parser.add_argument(
+        "--output_path",
+        type=str,
+        default="./",
+        help="Path to save the model.",
+    )
+    parser.add_argument(
+        "--metadata_path",
+        type=str,
+        default=None,
+        help="Path to metadata.csv.",
+    )
+    parser.add_argument(
+        "--redirected_tensor_path",
+        type=str,
+        default=None,
+        help="Path to save cached tensors.",
+    )
+    parser.add_argument(
+        "--text_encoder_path",
+        type=str,
+        default=None,
+        help="Path of text encoder.",
+    )
+    parser.add_argument(
+        "--image_encoder_path",
+        type=str,
+        default=None,
+        help="Path of image encoder.",
+    )
+    parser.add_argument(
+        "--vae_path",
+        type=str,
+        default=None,
+        help="Path of VAE.",
+    )
+    parser.add_argument(
+        "--dit_path",
+        type=str,
+        default=None,
+        help="Path of DiT.",
+    )
+    parser.add_argument(
+        "--tiled",
+        default=False,
+        action="store_true",
+        help="Whether to enable tiled encoding in the VAE. This option reduces the required VRAM.",
+    )
+    parser.add_argument(
+        "--tile_size_height",
+        type=int,
+        default=34,
+        help="Tile size (height) in VAE.",
+    )
+    parser.add_argument(
+        "--tile_size_width",
+        type=int,
+        default=34,
+        help="Tile size (width) in VAE.",
+    )
+    parser.add_argument(
+        "--tile_stride_height",
+        type=int,
+        default=18,
+        help="Tile stride (height) in VAE.",
+    )
+    parser.add_argument(
+        "--tile_stride_width",
+        type=int,
+        default=16,
+        help="Tile stride (width) in VAE.",
+    )
+    parser.add_argument(
+        "--steps_per_epoch",
+        type=int,
+        default=500,
+        help="Number of steps per epoch.",
+    )
+    parser.add_argument(
+        "--num_frames",
+        type=int,
+        default=81,
+        help="Number of frames.",
+    )
+    parser.add_argument(
+        "--target_fps",
+        type=int,
+        default=None,
+        help="Expected FPS for sampling frames.",
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=480,
+        help="Image height.",
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=832,
+        help="Image width.",
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=1,
+        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
+    )
+    parser.add_argument(
+        "--learning_rate",
+        type=float,
+        default=1e-5,
+        help="Learning rate.",
+    )
+    parser.add_argument(
+        "--accumulate_grad_batches",
+        type=int,
+        default=1,
+        help="The number of batches in gradient accumulation.",
+    )
+    parser.add_argument(
+        "--max_epochs",
+        type=int,
+        default=1,
+        help="Number of epochs.",
+    )
+    parser.add_argument(
+        "--lora_target_modules",
+        type=str,
+        default="q,k,v,o,ffn.0,ffn.2",
+        help="Layers with LoRA modules.",
+    )
+    parser.add_argument(
+        "--init_lora_weights",
+        type=str,
+        default="kaiming",
+        choices=["gaussian", "kaiming"],
+        help="The initialization method of the LoRA weights.",
+    )
+    parser.add_argument(
+        "--training_strategy",
+        type=str,
+        default="auto",
+        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
+        help="Training strategy.",
+    )
+    parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=4,
+        help="The dimension of the LoRA update matrices.",
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=float,
+        default=4.0,
+        help="The scaling weight (alpha) of the LoRA update matrices.",
+    )
+    parser.add_argument(
+        "--use_gradient_checkpointing",
+        default=False,
+        action="store_true",
+        help="Whether to use gradient checkpointing.",
+    )
+    parser.add_argument(
+        "--use_gradient_checkpointing_offload",
+        default=False,
+        action="store_true",
+        help="Whether to offload gradient checkpointing to CPU memory.",
+    )
+    parser.add_argument(
+        "--train_architecture",
+        type=str,
+        default="lora",
+        choices=["lora", "full"],
+        help="Model structure to train. LoRA training or full training.",
+    )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
+    parser.add_argument(
+        "--use_swanlab",
+        default=False,
+        action="store_true",
+        help="Whether to use SwanLab logger.",
+    )
+    parser.add_argument(
+        "--swanlab_mode",
+        default=None,
+        help="SwanLab mode (cloud or local).",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def data_process(args):
+    dataset = TextVideoDataset(
+        args.dataset_path,
+        os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
+        max_num_frames=args.num_frames,
+        frame_interval=1,
+        num_frames=args.num_frames,
+        height=args.height,
+        width=args.width,
+        is_i2v=args.image_encoder_path is not None,
+        target_fps=args.target_fps,
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        shuffle=False,
+        batch_size=1,
+        num_workers=args.dataloader_num_workers,
+        collate_fn=lambda x: x,
+    )
+    model = LightningModelForDataProcess(
+        text_encoder_path=args.text_encoder_path,
+        image_encoder_path=args.image_encoder_path,
+        vae_path=args.vae_path,
+        tiled=args.tiled,
+        tile_size=(args.tile_size_height, args.tile_size_width),
+        tile_stride=(args.tile_stride_height, args.tile_stride_width),
+        redirected_tensor_path=args.redirected_tensor_path,
+    )
+    trainer = pl.Trainer(
+        accelerator="gpu",
+        devices="auto",
+        default_root_dir=args.output_path,
+    )
+    trainer.test(model, dataloader)
+
+
+def get_motion_thresholds(dataloader):
+    scores = []
+    for data in tqdm(dataloader):
+        scores.append(data["latents"][:, :, 1:, :, :].std(dim=2).mean().tolist())
+    scores = sorted(scores)
+    for i in range(100):
+        s = scores[int(i/100 * len(scores))]
+        print("%.9f" % s, end=", ")
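+
+# get_motion_thresholds is an offline helper: run it over a dataloader of cached
+# tensors to print 100 ascending percentiles of the motion statistic; the
+# thresholds table in MotionBucketManager above appears to be its output.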
+
+
+def train(args):
+    dataset = TensorDataset(
+        args.dataset_path,
+        os.path.join(args.dataset_path, "metadata.csv") if args.metadata_path is None else args.metadata_path,
+        steps_per_epoch=args.steps_per_epoch,
+        redirected_tensor_path=args.redirected_tensor_path,
+    )
+    dataloader = torch.utils.data.DataLoader(
+        dataset,
+        shuffle=True,
+        batch_size=1,
+        num_workers=args.dataloader_num_workers
+    )
+    model = LightningModelForTrain(
+        dit_path=args.dit_path,
+        learning_rate=args.learning_rate,
+        train_architecture=args.train_architecture,
+        lora_rank=args.lora_rank,
+        lora_alpha=args.lora_alpha,
+        lora_target_modules=args.lora_target_modules,
+        init_lora_weights=args.init_lora_weights,
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
+        pretrained_lora_path=args.pretrained_lora_path,
+    )
+    if args.use_swanlab:
+        from swanlab.integration.pytorch_lightning import SwanLabLogger
+        swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"}
+        swanlab_config.update(vars(args))
+        swanlab_logger = SwanLabLogger(
+            project="wan",
+            name="wan",
+            config=swanlab_config,
+            mode=args.swanlab_mode,
+            logdir=os.path.join(args.output_path, "swanlog"),
+        )
+        logger = [swanlab_logger]
+    else:
+        logger = None
+    trainer = pl.Trainer(
+        max_epochs=args.max_epochs,
+        accelerator="gpu",
+        devices="auto",
+        precision="bf16",
+        strategy=args.training_strategy,
+        default_root_dir=args.output_path,
+        accumulate_grad_batches=args.accumulate_grad_batches,
+        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
+        logger=logger,
+    )
+    trainer.fit(model, dataloader)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    if args.task == "data_process":
+        data_process(args)
+    elif args.task == "train":
+        train(args)
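After training, each checkpoint written by on_save_checkpoint is a bare state dict of the motion controller, so it can be attached to the inference pipeline built in the earlier sketch. A hedged example (the checkpoint path is a placeholder; `pipe` is the WanVideoPipeline from the sketch above):

    import torch
    from diffsynth.models.wan_video_motion_controller import WanMotionControllerModel

    # torch.load returns exactly the controller's weights; path is a placeholder.
    state_dict = torch.load("path/to/epoch=0-step=500.ckpt", map_location="cpu")
    pipe.motion_controller = WanMotionControllerModel().to(dtype=torch.bfloat16, device=pipe.device)
    pipe.motion_controller.load_state_dict(state_dict)

    video = pipe(prompt="a cat running on the grass", motion_bucket_id=80)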