From 62f6ca2b8a820040d79b4512924904fec59e132b Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 6 Jun 2025 14:58:41 +0800 Subject: [PATCH] new wan trainer --- diffsynth/lora/__init__.py | 45 ++ diffsynth/models/utils.py | 14 +- diffsynth/pipelines/wan_video_new.py | 23 +- diffsynth/trainers/utils.py | 34 +- diffsynth/vram_management/layers.py | 14 +- examples/wanvideo/README.md | 302 +-------- examples/wanvideo/README_zh.md | 313 +++++++++ ...trol.py => Wan2.1-1.3b-speedcontrol-v1.py} | 0 ..._14b_flf2v.py => Wan2.1-FLF2V-14B-720P.py} | 0 ..._control.py => Wan2.1-Fun-1.3B-Control.py} | 0 ...fun_1.3b_InP.py => Wan2.1-Fun-1.3B-InP.py} | 0 ...b_control.py => Wan2.1-Fun-14B-Control.py} | 0 ...n_fun_14b_InP.py => Wan2.1-Fun-14B-InP.py} | 0 ...rol.py => Wan2.1-Fun-V1.1-1.3B-Control.py} | 0 ...trol.py => Wan2.1-Fun-V1.1-14B-Control.py} | 0 ...o_video_480p.py => Wan2.1-I2V-14B-480P.py} | 0 ...o_video_720p.py => Wan2.1-I2V-14B-720P.py} | 0 ...3b_text_to_video.py => Wan2.1-T2V-1.3B.py} | 0 ...14b_text_to_video.py => Wan2.1-T2V-14B.py} | 0 ...3b_vace.py => Wan2.1-VACE-1.3B-Preview.py} | 0 .../full/Wan2.1-1.3b-speedcontrol-v1.sh | 13 + .../full/Wan2.1-FLF2V-14B-720P.sh | 14 + .../full/Wan2.1-Fun-1.3B-Control.sh | 14 + .../full/Wan2.1-Fun-1.3B-InP.sh | 14 + .../full/Wan2.1-Fun-14B-Control.sh | 14 + .../model_training/full/Wan2.1-Fun-14B-InP.sh | 14 + .../full/Wan2.1-Fun-V1.1-1.3B-Control.sh | 15 + .../full/Wan2.1-Fun-V1.1-14B-Control.sh | 15 + .../full/Wan2.1-I2V-14B-480P.sh | 13 + .../full/Wan2.1-I2V-14B-720P.sh | 13 + .../model_training/full/Wan2.1-T2V-1.3B.sh | 12 + .../model_training/full/Wan2.1-T2V-14B.sh | 12 + .../full/accelerate_config_14B.yaml | 22 + .../wanvideo/model_training/full/run_test.py | 38 ++ .../lora/Wan2.1-1.3b-speedcontrol-v1.sh | 15 + .../lora/Wan2.1-FLF2V-14B-720P.sh | 16 + .../lora/Wan2.1-Fun-1.3B-Control.sh | 16 + .../lora/Wan2.1-Fun-1.3B-InP.sh | 16 + .../lora/Wan2.1-Fun-14B-Control.sh | 16 + .../model_training/lora/Wan2.1-Fun-14B-InP.sh | 16 + .../lora/Wan2.1-Fun-V1.1-1.3B-Control.sh | 17 + .../lora/Wan2.1-Fun-V1.1-14B-Control.sh | 17 + .../lora/Wan2.1-I2V-14B-480P.sh | 15 + .../lora/Wan2.1-I2V-14B-720P.sh | 15 + .../model_training/lora/Wan2.1-T2V-1.3B.sh | 14 + .../model_training/lora/Wan2.1-T2V-14B.sh | 14 + .../wanvideo/model_training/lora/run_test.py | 25 + examples/wanvideo/model_training/train.py | 129 ++++ examples/wanvideo/model_training/train_i2v.py | 54 -- examples/wanvideo/model_training/train_t2v.py | 53 -- .../Wan2.1-1.3b-speedcontrol-v1.py | 28 + .../validate_full/Wan2.1-FLF2V-14B-720P.py | 33 + .../validate_full/Wan2.1-Fun-1.3B-Control.py | 32 + .../validate_full/Wan2.1-Fun-1.3B-InP.py | 31 + .../validate_full/Wan2.1-Fun-14B-Control.py | 32 + .../validate_full/Wan2.1-Fun-14B-InP.py | 31 + .../Wan2.1-Fun-V1.1-1.3B-Control.py | 33 + .../Wan2.1-Fun-V1.1-14B-Control.py | 33 + .../validate_full/Wan2.1-I2V-14B-480P.py | 30 + .../validate_full/Wan2.1-I2V-14B-720P.py | 30 + .../validate_full/Wan2.1-T2V-1.3B.py | 25 + .../validate_full/Wan2.1-T2V-14B.py | 25 + .../model_training/validate_full/run_test.py | 25 + .../Wan2.1-1.3b-speedcontrol-v1.py | 27 + .../validate_lora/Wan2.1-FLF2V-14B-720P.py | 32 + .../validate_lora/Wan2.1-Fun-1.3B-Control.py | 31 + .../validate_lora/Wan2.1-Fun-1.3B-InP.py | 30 + .../validate_lora/Wan2.1-Fun-14B-Control.py | 31 + .../validate_lora/Wan2.1-Fun-14B-InP.py | 30 + .../Wan2.1-Fun-V1.1-1.3B-Control.py | 32 + .../Wan2.1-Fun-V1.1-14B-Control.py | 32 + .../validate_lora/Wan2.1-I2V-14B-480P.py | 29 + .../validate_lora/Wan2.1-I2V-14B-720P.py 
| 29 + .../validate_lora/Wan2.1-T2V-1.3B.py | 24 + .../validate_lora/Wan2.1-T2V-14B.py | 24 + .../model_training/validate_lora/run_test.py | 25 + examples/wanvideo/train_wan_t2v.py | 593 ------------------ .../wanvideo/wan_1.3b_motion_controller.py | 41 -- examples/wanvideo/wan_1.3b_text_to_video.py | 40 -- examples/wanvideo/wan_1.3b_vace.py | 63 -- examples/wanvideo/wan_14B_flf2v.py | 52 -- examples/wanvideo/wan_14b_image_to_video.py | 51 -- examples/wanvideo/wan_14b_text_to_video.py | 36 -- .../wan_14b_text_to_video_tensor_parallel.py | 149 ----- examples/wanvideo/wan_fun_InP.py | 42 -- examples/wanvideo/wan_fun_control.py | 40 -- .../wanvideo/wan_fun_reference_control.py | 35 -- 87 files changed, 1779 insertions(+), 1543 deletions(-) create mode 100644 diffsynth/lora/__init__.py create mode 100644 examples/wanvideo/README_zh.md rename examples/wanvideo/model_inference/{wan_1.3b_speed_control.py => Wan2.1-1.3b-speedcontrol-v1.py} (100%) rename examples/wanvideo/model_inference/{wan_14b_flf2v.py => Wan2.1-FLF2V-14B-720P.py} (100%) rename examples/wanvideo/model_inference/{wan_fun_1.3b_control.py => Wan2.1-Fun-1.3B-Control.py} (100%) rename examples/wanvideo/model_inference/{wan_fun_1.3b_InP.py => Wan2.1-Fun-1.3B-InP.py} (100%) rename examples/wanvideo/model_inference/{wan_fun_14b_control.py => Wan2.1-Fun-14B-Control.py} (100%) rename examples/wanvideo/model_inference/{wan_fun_14b_InP.py => Wan2.1-Fun-14B-InP.py} (100%) rename examples/wanvideo/model_inference/{wan_fun_v1.1_1.3b_reference_control.py => Wan2.1-Fun-V1.1-1.3B-Control.py} (100%) rename examples/wanvideo/model_inference/{wan_fun_v1.1_14b_reference_control.py => Wan2.1-Fun-V1.1-14B-Control.py} (100%) rename examples/wanvideo/model_inference/{wan_14b_image_to_video_480p.py => Wan2.1-I2V-14B-480P.py} (100%) rename examples/wanvideo/model_inference/{wan_14b_image_to_video_720p.py => Wan2.1-I2V-14B-720P.py} (100%) rename examples/wanvideo/model_inference/{wan_1.3b_text_to_video.py => Wan2.1-T2V-1.3B.py} (100%) rename examples/wanvideo/model_inference/{wan_14b_text_to_video.py => Wan2.1-T2V-14B.py} (100%) rename examples/wanvideo/model_inference/{wan_1.3b_vace.py => Wan2.1-VACE-1.3B-Preview.py} (100%) create mode 100644 examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh create mode 100644 examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh create mode 100644 examples/wanvideo/model_training/full/accelerate_config_14B.yaml create mode 100644 examples/wanvideo/model_training/full/run_test.py create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh create mode 100644 
examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh create mode 100644 examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh create mode 100644 examples/wanvideo/model_training/lora/run_test.py create mode 100644 examples/wanvideo/model_training/train.py delete mode 100644 examples/wanvideo/model_training/train_i2v.py delete mode 100644 examples/wanvideo/model_training/train_t2v.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py create mode 100644 examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py create mode 100644 examples/wanvideo/model_training/validate_full/run_test.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py create mode 100644 examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py create mode 100644 examples/wanvideo/model_training/validate_lora/run_test.py delete mode 100644 examples/wanvideo/train_wan_t2v.py delete mode 100644 examples/wanvideo/wan_1.3b_motion_controller.py delete mode 100644 
examples/wanvideo/wan_1.3b_text_to_video.py delete mode 100644 examples/wanvideo/wan_1.3b_vace.py delete mode 100644 examples/wanvideo/wan_14B_flf2v.py delete mode 100644 examples/wanvideo/wan_14b_image_to_video.py delete mode 100644 examples/wanvideo/wan_14b_text_to_video.py delete mode 100644 examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py delete mode 100644 examples/wanvideo/wan_fun_InP.py delete mode 100644 examples/wanvideo/wan_fun_control.py delete mode 100644 examples/wanvideo/wan_fun_reference_control.py diff --git a/diffsynth/lora/__init__.py b/diffsynth/lora/__init__.py new file mode 100644 index 0000000..33bd89c --- /dev/null +++ b/diffsynth/lora/__init__.py @@ -0,0 +1,45 @@ +import torch + + + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in lora_state_dict: + if ".lora_B." not in key: + continue + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) + return lora_name_dict + + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + updated_num = 0 + lora_name_dict = self.get_name_dict(state_dict_lora) + for name, module in model.named_modules(): + if name in lora_name_dict: + weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict = module.state_dict() + state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict) + updated_num += 1 + print(f"{updated_num} tensors are updated by LoRA.") diff --git a/diffsynth/models/utils.py b/diffsynth/models/utils.py index 99f5dee..0d58e4e 100644 --- a/diffsynth/models/utils.py +++ b/diffsynth/models/utils.py @@ -62,16 +62,16 @@ def load_state_dict_from_folder(file_path, torch_dtype=None): return state_dict -def load_state_dict(file_path, torch_dtype=None): +def load_state_dict(file_path, torch_dtype=None, device="cpu"): if file_path.endswith(".safetensors"): - return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) + return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device) else: - return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) + return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device) -def load_state_dict_from_safetensors(file_path, torch_dtype=None): +def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): state_dict = {} - with safe_open(file_path, framework="pt", device="cpu") as f: + with safe_open(file_path, framework="pt", device=device) as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) if torch_dtype is not None: @@ -79,8 +79,8 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None): return state_dict -def 
load_state_dict_from_bin(file_path, torch_dtype=None): - state_dict = torch.load(file_path, map_location="cpu", weights_only=True) +def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): + state_dict = torch.load(file_path, map_location=device, weights_only=True) if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 02c0fec..eb9dfba 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -11,7 +11,7 @@ from PIL import Image from tqdm import tqdm from typing import Optional -from ..models import ModelManager +from ..models import ModelManager, load_state_dict from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d from ..models.wan_video_text_encoder import WanTextEncoder, T5RelativeEmbedding, T5LayerNorm from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample @@ -21,6 +21,7 @@ from ..models.wan_video_motion_controller import WanMotionControllerModel from ..schedulers.flow_match import FlowMatchScheduler from ..prompters import WanPrompter from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm +from ..lora import GeneralLoRALoader @@ -137,7 +138,8 @@ class BasePipeline(torch.nn.Module): def enable_cpu_offload(self): - warnings.warn("`enable_cpu_offload` is deprecated. Please use `enable_vram_management`.") + warnings.warn("`enable_cpu_offload` will be deprecated. Please use `enable_vram_management`.") + self.vram_management_enabled = True def get_free_vram(self): @@ -183,7 +185,6 @@ class ModelConfig: self.path = self.path[0] - class WanVideoPipeline(BasePipeline): def __init__(self, device="cuda", torch_dtype=torch.bfloat16, tokenizer_path=None): @@ -216,6 +217,12 @@ class WanVideoPipeline(BasePipeline): ] self.model_fn = model_fn_wan_video + + def load_lora(self, module, path, alpha=1): + loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) + lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) + loader.load(module, lora, alpha=alpha) + def training_loss(self, **inputs): timestep_id = torch.randint(0, self.scheduler.num_train_timesteps, (1,)) @@ -946,6 +953,7 @@ def model_fn_wan_video( sliding_window_stride: Optional[int] = None, cfg_merge: bool = False, use_gradient_checkpointing: bool = False, + use_gradient_checkpointing_offload: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -1036,7 +1044,14 @@ def model_fn_wan_video( return custom_forward for block_id, block in enumerate(dit.blocks): - if use_gradient_checkpointing: + if use_gradient_checkpointing_offload: + with torch.autograd.graph.save_on_cpu(): + x = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + x, context, t_mod, freqs, + use_reentrant=False, + ) + elif use_gradient_checkpointing: x = torch.utils.checkpoint.checkpoint( create_custom_forward(block), x, context, t_mod, freqs, diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index d306049..0a056a6 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -25,6 +25,7 @@ class VideoDataset(torch.utils.data.Dataset): metadata_path = args.dataset_metadata_path height = args.height width = args.width + num_frames = args.num_frames data_file_keys = args.data_file_keys.split(",") repeat = args.dataset_repeat @@ -205,27 +206,52 
@@ def launch_training_task(model: DiffusionTrainingModule, dataset, learning_rate= accelerator.wait_for_everyone() if accelerator.is_main_process: state_dict = accelerator.get_state_dict(model) - state_dict = model.export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=remove_prefix_in_ckpt) + os.makedirs(output_path, exist_ok=True) path = os.path.join(output_path, f"epoch-{epoch}.safetensors") accelerator.save(state_dict, path, safe_serialization=True) +def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): + dataloader = torch.utils.data.DataLoader(dataset, shuffle=False, collate_fn=lambda x: x[0]) + accelerator = Accelerator() + model, dataloader = accelerator.prepare(model, dataloader) + os.makedirs(os.path.join(output_path, "data_cache"), exist_ok=True) + for data_id, data in enumerate(tqdm(dataloader)): + with torch.no_grad(): + inputs = model.forward_preprocess(data) + inputs = {key: inputs[key] for key in model.model_input_keys if key in inputs} + torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth")) + + + def wan_parser(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.") parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.") parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.") parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.") + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in each video. The frames are sampled from the prefix.") parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.") parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.") - parser.add_argument("--model_paths", type=str, default="", help="Model paths to be loaded. JSON format.") + parser.add_argument("--model_paths", type=str, default=None, help="Model paths to be loaded. JSON format.") + parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. 
Separated by commas.") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.") parser.add_argument("--output_path", type=str, default="./models", help="Save path.") parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.") - parser.add_argument("--task", type=str, default="train_lora", choices=["train_lora", "train_full"], help="Task.") - parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Layers with LoRA modules.") + parser.add_argument("--trainable_models", type=str, default=None, help="Trainable models, e.g., dit, vae, text_encoder.") + parser.add_argument("--lora_base_model", type=str, default=None, help="Add LoRA on which model.") + parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Add LoRA on which layer.") parser.add_argument("--lora_rank", type=int, default=32, help="LoRA rank.") + parser.add_argument("--input_contains_input_image", default=False, action="store_true", help="Model input contains 'input_image'.") + parser.add_argument("--input_contains_end_image", default=False, action="store_true", help="Model input contains 'end_image'.") + parser.add_argument("--input_contains_control_video", default=False, action="store_true", help="Model input contains 'control_video'.") + parser.add_argument("--input_contains_reference_image", default=False, action="store_true", help="Model input contains 'reference_image'.") + parser.add_argument("--input_contains_vace_video", default=False, action="store_true", help="Model input contains 'vace_video'.") + parser.add_argument("--input_contains_vace_reference_image", default=False, action="store_true", help="Model input contains 'vace_reference_image'.") + parser.add_argument("--input_contains_motion_bucket_id", default=False, action="store_true", help="Model input contains 'motion_bucket_id'.") + parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Offload gradient checkpointing to RAM.") return parser diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py index 45e7433..dd4a245 100644 --- a/diffsynth/vram_management/layers.py +++ b/diffsynth/vram_management/layers.py @@ -33,7 +33,7 @@ class AutoTorchModule(torch.nn.Module): class AutoWrappedModule(AutoTorchModule): - def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit): + def __init__(self, module: torch.nn.Module, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs): super().__init__() self.module = module.to(dtype=offload_dtype, device=offload_device) self.offload_dtype = offload_dtype @@ -60,7 +60,7 @@ class AutoWrappedModule(AutoTorchModule): class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule): - def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit): + def __init__(self, module: torch.nn.LayerNorm, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, **kwargs): with init_weights_on_device(device=torch.device("meta")): super().__init__(module.normalized_shape, eps=module.eps, elementwise_affine=module.elementwise_affine, 
bias=module.bias is not None, dtype=offload_dtype, device=offload_device) self.weight = module.weight @@ -92,7 +92,7 @@ class WanAutoCastLayerNorm(torch.nn.LayerNorm, AutoTorchModule): class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): - def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit): + def __init__(self, module: torch.nn.Linear, offload_dtype, offload_device, onload_dtype, onload_device, computation_dtype, computation_device, vram_limit, name="", **kwargs): with init_weights_on_device(device=torch.device("meta")): super().__init__(in_features=module.in_features, out_features=module.out_features, bias=module.bias is not None, dtype=offload_dtype, device=offload_device) self.weight = module.weight @@ -105,6 +105,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): self.computation_device = computation_device self.vram_limit = vram_limit self.state = 0 + self.name = name def forward(self, x, *args, **kwargs): if self.state == 2: @@ -121,8 +122,9 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): return torch.nn.functional.linear(x, weight, bias) -def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None): +def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0, vram_limit=None, name_prefix=""): for name, module in model.named_children(): + layer_name = name if name_prefix == "" else name_prefix + "." + name for source_module, target_module in module_map.items(): if isinstance(module, source_module): num_param = sum(p.numel() for p in module.parameters()) @@ -130,12 +132,12 @@ def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config_ = overflow_module_config else: module_config_ = module_config - module_ = target_module(module, **module_config_, vram_limit=vram_limit) + module_ = target_module(module, **module_config_, vram_limit=vram_limit, name=layer_name) setattr(model, name, module_) total_num_param += num_param break else: - total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit) + total_num_param = enable_vram_management_recursively(module, module_map, module_config, max_num_param, overflow_module_config, total_num_param, vram_limit=vram_limit, name_prefix=layer_name) return total_num_param diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 92c3c59..46c9670 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -1,276 +1,32 @@ -# Wan-Video -Wan-Video is a collection of video synthesis models open-sourced by Alibaba. -Before using this model, please install DiffSynth-Studio from **source code**. +* dataset + * `--dataset_base_path`: Base path of the Dataset. + * `--dataset_metadata_path`: Metadata path of the Dataset. + * `--height`: Image or video height. Leave `height` and `width` None to enable dynamic resolution. + * `--width`: Image or video width. Leave `height` and `width` None to enable dynamic resolution. + * `--num_frames`: Number of frames in each video. The frames are sampled from the prefix. + * `--data_file_keys`: Data file keys in metadata. Separated by commas. 
+  * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
+* Model
+  * `--model_paths`: Model paths to be loaded. JSON format.
+  * `--model_id_with_origin_paths`: Model ID with original path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas.
+* Training
+  * `--learning_rate`: Learning rate.
+  * `--num_epochs`: Number of epochs.
+  * `--output_path`: Save path.
+  * `--remove_prefix_in_ckpt`: Prefix to remove in the checkpoint.
+* Trainable modules
+  * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
+  * `--lora_base_model`: Which model to add LoRA to.
+  * `--lora_target_modules`: Which layers to add LoRA to.
+  * `--lora_rank`: LoRA rank.
+* Extra model input
+  * `--input_contains_input_image`: Model input contains `input_image`.
+  * `--input_contains_end_image`: Model input contains `end_image`.
+  * `--input_contains_control_video`: Model input contains `control_video`.
+  * `--input_contains_reference_image`: Model input contains `reference_image`.
+  * `--input_contains_vace_video`: Model input contains `vace_video`.
+  * `--input_contains_vace_reference_image`: Model input contains `vace_reference_image`.
+  * `--input_contains_motion_bucket_id`: Model input contains `motion_bucket_id`.
+* VRAM management
+  * `--use_gradient_checkpointing_offload`: Offload gradient checkpointing to RAM.
-```shell
-git clone https://github.com/modelscope/DiffSynth-Studio.git
-cd DiffSynth-Studio
-pip install -e .
-```
-
-## Model Zoo
-
-|Developer|Name|Link|Scripts|
-|-|-|-|-|
-|Wan Team|1.3B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|[wan_1.3b_text_to_video.py](./wan_1.3b_text_to_video.py)|
-|Wan Team|14B text-to-video|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|[wan_14b_text_to_video.py](./wan_14b_text_to_video.py)|
-|Wan Team|14B image-to-video 480P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
-|Wan Team|14B image-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|[wan_14b_image_to_video.py](./wan_14b_image_to_video.py)|
-|Wan Team|14B first-last-frame-to-video 720P|[Link](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|[wan_14B_flf2v.py](./wan_14B_flf2v.py)|
-|DiffSynth-Studio Team|1.3B aesthetics LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-aesthetics-v1).|
-|DiffSynth-Studio Team|1.3B Highres-fix LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-highresfix-v1).|
-|DiffSynth-Studio Team|1.3B ExVideo LoRA|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1)|Please see the [model card](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-lora-exvideo-v1).|
-|DiffSynth-Studio Team|1.3B Speed Control adapter|[Link](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|[wan_1.3b_motion_controller.py](./wan_1.3b_motion_controller.py)|
-|PAI Team|1.3B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
-|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
-|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
-|PAI Team|14B 
Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)| -|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)| - -Base model features - -||Text-to-video|Image-to-video|End frame|Control|Reference image| -|-|-|-|-|-|-| -|1.3B text-to-video|✅||||| -|14B text-to-video|✅||||| -|14B image-to-video 480P||✅|||| -|14B image-to-video 720P||✅|||| -|14B first-last-frame-to-video 720P||✅|✅||| -|1.3B InP||✅|✅||| -|14B InP||✅|✅||| -|1.3B Control||||✅|| -|14B Control||||✅|| -|1.3B VACE||||✅|✅| - -Adapter model compatibility - -||1.3B text-to-video|1.3B InP|1.3B VACE| -|-|-|-|-| -|1.3B aesthetics LoRA|✅||✅| -|1.3B Highres-fix LoRA|✅||✅| -|1.3B ExVideo LoRA|✅||✅| -|1.3B Speed Control adapter|✅|✅|✅| - -## VRAM Usage - -* Fine-grained offload: We recommend that users adjust the `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py). - -* FP8 Quantization: You only need to adjust the `torch_dtype` in the `ModelManager` (not the pipeline!). - -We present a detailed table here. The model (14B text-to-video) is tested on a single A100. - -|`torch_dtype`|`num_persistent_param_in_dit`|Speed|Required VRAM|Default Setting| -|-|-|-|-|-| -|torch.bfloat16|None (unlimited)|18.5s/it|48G|| -|torch.bfloat16|7*10**9 (7B)|20.8s/it|24G|| -|torch.bfloat16|0|23.4s/it|10G|| -|torch.float8_e4m3fn|None (unlimited)|18.3s/it|24G|yes| -|torch.float8_e4m3fn|0|24.0s/it|10G|| - -**We found that 14B image-to-video model is more sensitive to precision, so when the generated video content experiences issues such as artifacts, please switch to bfloat16 precision and use the `num_persistent_param_in_dit` parameter to control VRAM usage.** - -## Efficient Attention Implementation - -DiffSynth-Studio supports multiple Attention implementations. If you have installed any of the following Attention implementations, they will be enabled based on priority. However, we recommend to use the default torch SDPA. - -* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention) -* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention) -* [Sage Attention](https://github.com/thu-ml/SageAttention) -* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default. `torch>=2.5.0` is recommended.) - -## Acceleration - -We support multiple acceleration solutions: -* [TeaCache](https://github.com/ali-vilab/TeaCache): See [wan_1.3b_text_to_video_accelerate.py](./wan_1.3b_text_to_video_accelerate.py). - -* [Unified Sequence Parallel](https://github.com/xdit-project/xDiT): See [wan_14b_text_to_video_usp.py](./wan_14b_text_to_video_usp.py) - -```bash -pip install xfuser>=0.4.3 -torchrun --standalone --nproc_per_node=8 examples/wanvideo/wan_14b_text_to_video_usp.py -``` - -* Tensor Parallel: See [wan_14b_text_to_video_tensor_parallel.py](./wan_14b_text_to_video_tensor_parallel.py). - -## Gallery - -1.3B text-to-video. - -https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8 - -Put sunglasses on the dog. - -https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb - -14B text-to-video. - -https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f - -14B image-to-video. 
- -https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75 - -14B first-last-frame-to-video - -|First frame|Last frame|Video| -|-|-|-| -|![Image](https://github.com/user-attachments/assets/b0d8225b-aee0-4129-b8e5-58c8523221a6)|![Image](https://github.com/user-attachments/assets/2f0c9bc5-07e2-45fa-8320-53d63a4fd203)|https://github.com/user-attachments/assets/2a6a2681-622c-4512-b852-5f22e73830b1| - -## Train - -We support Wan-Video LoRA training and full training. Here is a tutorial. This is an experimental feature. Below is a video sample generated from the character Keqing LoRA: - -https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9 - -Step 1: Install additional packages - -``` -pip install peft lightning pandas -``` - -Step 2: Prepare your dataset - -You need to manage the training videos as follows: - -``` -data/example_dataset/ -├── metadata.csv -└── train - ├── video_00001.mp4 - └── image_00002.jpg -``` - -`metadata.csv`: - -``` -file_name,text -video_00001.mp4,"video description" -image_00002.jpg,"video description" -``` - -We support both images and videos. An image is treated as a single frame of video. - -Step 3: Data process - -```shell -CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ - --task data_process \ - --dataset_path data/example_dataset \ - --output_path ./models \ - --text_encoder_path "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth" \ - --vae_path "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth" \ - --tiled \ - --num_frames 81 \ - --height 480 \ - --width 832 -``` - -After that, some cached files will be stored in the dataset folder. - -``` -data/example_dataset/ -├── metadata.csv -└── train - ├── video_00001.mp4 - ├── video_00001.mp4.tensors.pth - ├── video_00002.mp4 - └── video_00002.mp4.tensors.pth -``` - -Step 4: Train - -LoRA training: - -```shell -CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ - --task train \ - --train_architecture lora \ - --dataset_path data/example_dataset \ - --output_path ./models \ - --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ - --steps_per_epoch 500 \ - --max_epochs 10 \ - --learning_rate 1e-4 \ - --lora_rank 16 \ - --lora_alpha 16 \ - --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ - --accumulate_grad_batches 1 \ - --use_gradient_checkpointing -``` - -Full training: - -```shell -CUDA_VISIBLE_DEVICES="0" python examples/wanvideo/train_wan_t2v.py \ - --task train \ - --train_architecture full \ - --dataset_path data/example_dataset \ - --output_path ./models \ - --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" \ - --steps_per_epoch 500 \ - --max_epochs 10 \ - --learning_rate 1e-4 \ - --accumulate_grad_batches 1 \ - --use_gradient_checkpointing -``` - -If you wish to train the 14B model, please separate the safetensor files with a comma. For example: `models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors,models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors`. 
-
-If you wish to train the image-to-video model, please add an extra parameter `--image_encoder_path "models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"`.
-
-For LoRA training, the Wan-1.3B-T2V model requires 16G of VRAM for processing 81 frames at 480P, while the Wan-14B-T2V model requires 60G of VRAM for the same configuration. To further reduce VRAM requirements by 20%-30%, you can include the parameter `--use_gradient_checkpointing_offload`.
-
-Step 5: Test
-
-Test LoRA:
-
-```python
-import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-
-
-model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-model_manager.load_models([
-    "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
-    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
-    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
-])
-model_manager.load_lora("models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
-pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None)
-
-video = pipe(
-    prompt="...",
-    negative_prompt="...",
-    num_inference_steps=50,
-    seed=0, tiled=True
-)
-save_video(video, "video.mp4", fps=30, quality=5)
-```
-
-Test fine-tuned base model:
-
-```python
-import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-
-
-model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
-model_manager.load_models([
-    "models/lightning_logs/version_1/checkpoints/epoch=0-step=500.ckpt",
-    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
-    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
-])
-pipe = WanVideoPipeline.from_model_manager(model_manager, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None)
-
-video = pipe(
-    prompt="...",
-    negative_prompt="...",
-    num_inference_steps=50,
-    seed=0, tiled=True
-)
-save_video(video, "video.mp4", fps=30, quality=5)
-```
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md
new file mode 100644
index 0000000..4504f8f
--- /dev/null
+++ b/examples/wanvideo/README_zh.md
@@ -0,0 +1,313 @@
+# Tongyi Wanxiang 2.1 (Wan 2.1)
+
+|Model ID|Type|Extra Inputs|Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
+|-|-|-|-|-|-|-|-|
+|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|Base model||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
+|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|Base model||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
+|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|Base model|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
+|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|Base model|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
+|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|Base model|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
+|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|Base model|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
+|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|Base model|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
+|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|Base model|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
+|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|Base model|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|Base model|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|Base model|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
+|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|Base model|`input_image`, `end_image`||||||
+|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|Base model|`input_image`, `end_image`||||||
+|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|Base model|||||||
+|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|Base model|||||||
+|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|Adapter|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/VACE-Wan2.1-1.3B-Preview.sh)|[code](./model_training/validate_full/VACE-Wan2.1-1.3B-Preview.py)|[code](./model_training/lora/VACE-Wan2.1-1.3B-Preview.sh)|[code](./model_training/validate_lora/VACE-Wan2.1-1.3B-Preview.py)|
+|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|Adapter|`vace_control_video`, `vace_reference_image`||||||
+|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|Adapter|`vace_control_video`, `vace_reference_image`||||||
+|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|Adapter|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
+
+## Model Inference
+
+### Loading Models
+
+Models are loaded via `from_pretrained`:
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+)
+```
+
+Here `torch_dtype` and `device` are the computation precision and the computation device. `model_configs` supports several ways to configure model paths:
+
+* Download the model from [ModelScope](https://modelscope.cn/) and load it. In this case, fill in `model_id` and `origin_file_pattern`, e.g.
+
+```python
+ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
+```
+
+* Load the model from a local file path. In this case, fill in `path`, e.g.
+
+```python
+ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
+```
+
+For a single model loaded from multiple files, use a list, e.g.
+
+```python
+ModelConfig(path=[
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
+])
+```
+
+`from_pretrained` also provides additional parameters that control model loading behavior:
+
+* `tokenizer_config`: Tokenizer path of the Wan model. Default: `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
+* `local_model_path`: Path for saving downloaded models. Default: `"./models"`.
+* `skip_download`: Whether to skip downloading. Default: `False`. If your network cannot reach [ModelScope](https://modelscope.cn/), download the required files manually and set this to `True`.
+* `redirect_common_files`: Whether to redirect duplicated model files. Default: `True`. Because the Wan series contains multiple base models whose text encoder and other modules are identical, we redirect their model paths to avoid repeated downloads.
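+
+This patch also adds a `load_lora` method to the pipeline (see `diffsynth/pipelines/wan_video_new.py` above), which merges a LoRA checkpoint into a loaded module via `GeneralLoRALoader`. A minimal sketch; the checkpoint path is illustrative:
+
+```python
+# Merge a trained LoRA into the DiT; `alpha` scales the LoRA update
+# that is added to the linear layers (see GeneralLoRALoader.load above).
+pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-1.3B_lora/epoch-1.safetensors", alpha=1)
+```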
+
+### VRAM Management
+
+DiffSynth-Studio provides fine-grained VRAM management for the Wan models, which makes inference possible on low-VRAM devices. Offloading can be enabled with the following code, moving some modules to system memory on devices with limited VRAM.
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+pipe.enable_vram_management()
+```
+
+FP8 quantization is also supported:
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn),
+    ],
+)
+pipe.enable_vram_management()
+```
+
+FP8 quantization and offloading can be enabled together:
+
+```python
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+    ],
+)
+pipe.enable_vram_management()
+```
+
+FP8 quantization greatly reduces VRAM usage, but it does not speed up inference. Under FP8 quantization, some models show precision problems that lead to blurry, torn, or distorted frames, so use FP8 quantization with caution.
+
+The `enable_vram_management` function provides the following parameters for controlling VRAM usage:
+
+* `vram_limit`: VRAM usage limit in GB. By default, it uses the free VRAM on the device. Note that this is not a hard limit: if the configured amount is insufficient for inference but more VRAM is actually available, inference runs while minimizing VRAM usage.
+* `vram_buffer`: VRAM buffer size in GB. Default: 0.5 GB. A buffer is necessary because some large layers take up more VRAM than expected during the onload phase; the theoretical optimum is the VRAM occupied by the largest layer in the model.
+* `num_persistent_param_in_dit`: Number of DiT parameters kept resident in VRAM. Unlimited by default. We will remove this parameter in the future; do not rely on it.
+
+### Input Parameters
+
+The pipeline accepts the following inputs during inference:
+
+* `prompt`: The prompt, describing what appears in the video.
+* `negative_prompt`: The negative prompt, describing what should not appear in the video. Default: `""`.
+* `input_image`: Input image, for image-to-video models such as [`Wan-AI/Wan2.1-I2V-14B-480P`](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) and [`PAI/Wan2.1-Fun-1.3B-InP`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP), as well as first-last-frame models such as [`Wan-AI/Wan2.1-FLF2V-14B-720P`](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P).
+* `end_image`: The ending frame, for first-last-frame models such as [`Wan-AI/Wan2.1-FLF2V-14B-720P`](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P).
+* `input_video`: Input video, for video-to-video generation. Applicable to any Wan-series model; used together with `denoising_strength`.
+* `denoising_strength`: Denoising strength, in [0, 1]. Smaller values produce videos closer to `input_video`.
+* `control_video`: Control video, for Wan models with control capability, such as [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control).
+* `reference_image`: Reference image, for Wan models with reference-image capability, such as [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control).
+* `vace_video`: Input video of the VACE model, for VACE-series models such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
+* `vace_video_mask`: Mask video of the VACE model, for VACE-series models such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
+* `vace_reference_image`: Reference image of the VACE model, for VACE-series models such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
+* `vace_scale`: How strongly the VACE model influences the base model. Default: 1. Larger values strengthen the control but increase the risk of degraded frames.
+* `seed`: The random seed. Default: `None`, i.e., fully random.
+* `rand_device`: Device on which the random Gaussian noise is generated. Default: `"cpu"`. Setting it to `cuda` yields different results on different GPUs.
+* `height`: Frame height. Default: 480. Must be a multiple of 16; otherwise it is rounded up.
+* `width`: Frame width. Default: 832. Must be a multiple of 16; otherwise it is rounded up.
+* `num_frames`: Number of frames. Default: 81. Must be a multiple of 4 plus 1, with a minimum of 1; otherwise it is rounded up.
+* `cfg_scale`: Scale of classifier-free guidance. Default: 5. Larger values strengthen prompt adherence but increase the risk of degraded frames.
+* `cfg_merge`: Whether to merge both sides of classifier-free guidance into a single forward pass. Default: `False`. Currently only effective for the base text-to-video and image-to-video models.
+* `num_inference_steps`: Number of inference steps. Default: 50.
+* `sigma_shift`: Parameter from Rectified Flow theory. Default: 5. Larger values make the model spend more steps at the beginning of denoising. Increasing it slightly can improve quality, but because generation then deviates from the training process, the output may drift away from the training data.
+* `motion_bucket_id`: Motion magnitude, in [0, 100]. Used with the speed-control module, e.g. [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values produce stronger motion.
+* `tiled`: Whether to enable tiled VAE inference. Default: `False`. Setting it to `True` significantly reduces VRAM usage during VAE encoding/decoding, at the cost of small errors and slightly longer inference time.
+* `tile_size`: Tile size for VAE encoding/decoding. Default: (30, 52). Only effective when `tiled=True`.
+* `tile_stride`: Tile stride for VAE encoding/decoding. Default: (15, 26). Only effective when `tiled=True`; must be less than or equal to `tile_size`.
+* `sliding_window_size`: Sliding window size for the DiT part. Experimental; results are unstable.
+* `sliding_window_stride`: Sliding window stride for the DiT part. Experimental; results are unstable.
+* `tea_cache_l1_thresh`: Threshold of TeaCache. Larger values are faster but lower quality. Note that with TeaCache enabled, inference speed is not uniform, so the remaining time shown on the progress bar becomes inaccurate.
+* `tea_cache_model_id`: Parameter template of TeaCache; one of `"Wan2.1-T2V-1.3B"`, `"Wan2.1-T2V-14B"`, `"Wan2.1-I2V-14B-480P"`, `"Wan2.1-I2V-14B-720P"`.
+* `progress_bar_cmd`: The progress bar. Default: `tqdm.tqdm`. Set it to `lambda x: x` to disable the progress bar.
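+
+Putting the loading code and these parameters together gives the following end-to-end sketch. The prompt, seed, frame rate, and output file name are illustrative; the import paths follow `diffsynth/pipelines/wan_video_new.py` from this patch:
+
+```python
+import torch
+from diffsynth import save_video
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+# Load the 1.3B text-to-video model (see "Loading Models" above).
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
+    ],
+)
+pipe.enable_vram_management()  # Optional; see "VRAM Management" above.
+
+# Default-sized output (480*832, 81 frames); tiled VAE inference saves VRAM.
+video = pipe(
+    prompt="a dog is running",
+    negative_prompt="",
+    seed=0,
+    height=480, width=832, num_frames=81,
+    cfg_scale=5, num_inference_steps=50,
+    tiled=True,
+)
+save_video(video, "video.mp4", fps=15, quality=5)
+```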
+
+## Model Training
+
+All Wan-series models are trained through the unified script [`./model_training/train.py`](./model_training/train.py).
+
+The script takes the following arguments:
+
+* Dataset
+  * `--dataset_base_path`: Base path of the dataset.
+  * `--dataset_metadata_path`: Path of the metadata file of the dataset.
+  * `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
+  * `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
+  * `--num_frames`: Number of frames per video. Frames are sampled from the beginning (prefix) of the video.
+  * `--data_file_keys`: Data file keys in the metadata, separated by commas.
+  * `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
+* Model
+  * `--model_paths`: Paths of models to load, in JSON format.
+  * `--model_id_with_origin_paths`: Model IDs with original file patterns, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors, separated by commas.
+* Training
+  * `--learning_rate`: Learning rate.
+  * `--num_epochs`: Number of epochs.
+  * `--output_path`: Save path.
+  * `--remove_prefix_in_ckpt`: Prefix to remove in the checkpoint.
+* Trainable modules
+  * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
+  * `--lora_base_model`: Which model to add LoRA to.
+  * `--lora_target_modules`: Which layers to add LoRA to.
+  * `--lora_rank`: Rank of the LoRA.
+* Extra model inputs
+  * `--input_contains_input_image`: The model input contains `input_image`.
+  * `--input_contains_end_image`: The model input contains `end_image`.
+  * `--input_contains_control_video`: The model input contains `control_video`.
+  * `--input_contains_reference_image`: The model input contains `reference_image`.
+  * `--input_contains_vace_video`: The model input contains `vace_video`.
+  * `--input_contains_vace_reference_image`: The model input contains `vace_reference_image`.
+  * `--input_contains_motion_bucket_id`: The model input contains `motion_bucket_id`.
+* VRAM management
+  * `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to system memory.
+
+### Step 1: Prepare the Dataset
+
+A dataset consists of a collection of files. We recommend organizing the dataset files like this:
+
+```
+data/example_video_dataset/
+├── metadata.csv
+├── video1.mp4
+└── video2.mp4
+```
+
+Here `video1.mp4` and `video2.mp4` are the training videos, and `metadata.csv` is the metadata list, for example
+
+```
+video,prompt
+video1.mp4,"from sunset to night, a small town, light, house, river"
+video2.mp4,"a dog is running"
+```
+
+Datasets support mixed training on videos and images. Supported video formats are `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`; supported image formats are `"jpg", "jpeg", "png", "webp"`.
+
+The video size is controlled by the script arguments `--height`, `--width`, and `--num_frames`. In each video, the first `num_frames` frames are used for training, so an error is raised when a video is shorter than `num_frames` frames; image files are treated as single-frame videos. When `--height` and `--width` are left empty, dynamic resolution is enabled, and training uses the actual width and height of each video or image in the dataset.
+
+**We strongly recommend training at a fixed resolution and avoiding mixed image-video training, because of load-balancing problems in multi-GPU training.**
+
+When a model requires extra inputs, such as the `control_video` required by the control-capable model [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control), add the corresponding columns to the dataset, for example:
+
+```
+video,prompt,control_video
+video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4
+```
+
+If the extra inputs include video or image files, the column names to parse must be specified in the `--data_file_keys` argument. Its default value is `"image,video"`, i.e., the columns named `image` and `video` are parsed. Extend it as needed for extra inputs, e.g., `--data_file_keys "image,video,control_video"`, and enable `--input_contains_control_video` at the same time.
+
+### Step 2: Load Models
+
+As with model loading at inference time, the models to load can be configured directly through model IDs. For example, at inference time we load models with the following settings
+
+```python
+model_configs=[
+    ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+    ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+    ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
+]
+```
+
+Then at training time, fill in the following argument to load the corresponding models.
+
+```shell
+--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth"
+```
+
+If you want to load models from local files, e.g., at inference time
+
+```python
+model_configs=[
+    ModelConfig(path=[
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
+    ]),
+    ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"),
+    ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"),
+]
+```
+
+then at training time, set
+
+```shell
+--model_paths '[
+    [
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
+        "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors"
+    ],
+    "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
+    "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"
+]' \
+```
+
+### Configure Trainable Modules
+
+The training framework supports training the base model or a LoRA model. A few examples:
+
+* Full training of the DiT part: `--trainable_models dit`
+* Training a LoRA on the DiT part: `--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32`
+* Training a LoRA on the DiT part plus the motion controller (yes, such exotic combinations can be trained): `--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32`
+
+In addition, because the training script loads multiple modules (text encoder, dit, vae), prefixes must be removed when saving model files. For example, when fully training the DiT part or training a LoRA on the DiT part, set `--remove_prefix_in_ckpt pipe.dit.`
+
+### Launch Training
+
+We provide a sample video dataset for testing, which can be downloaded with the following command:
+
+```shell
+modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
+```
+
+We have written a training command for every model; please refer to the table at the beginning of this document. A reference command is also sketched below.
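+
+As a concrete reference, here is a LoRA training command for Wan2.1-T2V-1.3B assembled from the arguments above. This is a sketch: the hyperparameter values are illustrative, and the per-model scripts under `./model_training/lora/` are the tested versions.
+
+```shell
+accelerate launch examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/example_video_dataset \
+  --dataset_metadata_path data/example_video_dataset/metadata.csv \
+  --height 480 \
+  --width 832 \
+  --num_frames 81 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Wan2.1-T2V-1.3B_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+  --lora_rank 32
+```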
+Note that full training of the 14B models requires 8 GPUs, each with at least 80 GB of VRAM. Fully training these 14B models also requires `deepspeed` (`pip install deepspeed`); we provide a recommended [configuration file](./model_training/full/accelerate_config_14B.yaml), which is loaded by the corresponding training scripts. These scripts have been tested on 8*A100.
+
+The default video size in the training scripts is `480*832*81` (height, width, and number of frames). Raising the resolution may cause out-of-memory errors; in that case, add the argument `--use_gradient_checkpointing_offload` to reduce VRAM usage.
diff --git a/examples/wanvideo/model_inference/wan_1.3b_speed_control.py b/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_1.3b_speed_control.py
rename to examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py
diff --git a/examples/wanvideo/model_inference/wan_14b_flf2v.py b/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_14b_flf2v.py
rename to examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py
diff --git a/examples/wanvideo/model_inference/wan_fun_1.3b_control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_fun_1.3b_control.py
rename to examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py
diff --git a/examples/wanvideo/model_inference/wan_fun_1.3b_InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_fun_1.3b_InP.py
rename to examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py
diff --git a/examples/wanvideo/model_inference/wan_fun_14b_control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_fun_14b_control.py
rename to examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py
diff --git a/examples/wanvideo/model_inference/wan_fun_14b_InP.py b/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_fun_14b_InP.py
rename to examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py
diff --git a/examples/wanvideo/model_inference/wan_fun_v1.1_1.3b_reference_control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_fun_v1.1_1.3b_reference_control.py
rename to examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py
diff --git a/examples/wanvideo/model_inference/wan_fun_v1.1_14b_reference_control.py b/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_fun_v1.1_14b_reference_control.py
rename to examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py
diff --git a/examples/wanvideo/model_inference/wan_14b_image_to_video_480p.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_14b_image_to_video_480p.py
rename to examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py
diff --git a/examples/wanvideo/model_inference/wan_14b_image_to_video_720p.py b/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_14b_image_to_video_720p.py
rename to examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py
diff --git a/examples/wanvideo/model_inference/wan_1.3b_text_to_video.py b/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py
similarity index 100%
rename from examples/wanvideo/model_inference/wan_1.3b_text_to_video.py
rename to
examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py diff --git a/examples/wanvideo/model_inference/wan_14b_text_to_video.py b/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py similarity index 100% rename from examples/wanvideo/model_inference/wan_14b_text_to_video.py rename to examples/wanvideo/model_inference/Wan2.1-T2V-14B.py diff --git a/examples/wanvideo/model_inference/wan_1.3b_vace.py b/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py similarity index 100% rename from examples/wanvideo/model_inference/wan_1.3b_vace.py rename to examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py diff --git a/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh new file mode 100644 index 0000000..e70fd13 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh @@ -0,0 +1,13 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.motion_controller." \ + --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_full" \ + --trainable_models "motion_controller" \ + --input_contains_motion_bucket_id \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh new file mode 100644 index 0000000..c0591ca --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \ + --trainable_models "dit" \ + --input_contains_input_image \ + --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh new file mode 100644 index 0000000..499c787 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-Control_full" \ + --trainable_models "dit" \ + --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh new file mode 100644 index 0000000..1fec876 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-InP_full" \ + --trainable_models "dit" \ + --input_contains_input_image \ + --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh new file mode 100644 index 0000000..2d7272d --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-14B-Control_full" \ + --trainable_models "dit" \ + --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh new file mode 100644 index 0000000..3463670 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh @@ -0,0 +1,14 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-InP_full" \ + --trainable_models "dit" \ + --input_contains_input_image \ + --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh new file mode 100644 index 0000000..5acda18 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_full" \ + --trainable_models "dit" \ + --input_contains_control_video \ + --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh new file mode 100644 index 0000000..2a63311 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh @@ -0,0 +1,15 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_full" \ + --trainable_models "dit" \ + --input_contains_control_video \ + --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh new file mode 100644 index 0000000..5cea09b --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-480P_full" \ + --trainable_models "dit" \ + --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh new file mode 100644 index 0000000..4b0ed11 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh @@ -0,0 +1,13 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-I2V-14B-720P_full" \ + --trainable_models "dit" \ + --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh new file mode 100644 index 0000000..e0d6e84 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh @@ -0,0 +1,12 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-1.3B_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh new file mode 100644 index 0000000..ae804b0 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-T2V-14B_full" \ + --trainable_models "dit" \ No newline at end of file diff --git a/examples/wanvideo/model_training/full/accelerate_config_14B.yaml b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml new file mode 100644 index 0000000..3875a9d --- /dev/null +++ b/examples/wanvideo/model_training/full/accelerate_config_14B.yaml @@ -0,0 +1,22 @@ +compute_environment: LOCAL_MACHINE +debug: false +deepspeed_config: + gradient_accumulation_steps: 1 + offload_optimizer_device: cpu + offload_param_device: cpu + zero3_init_flag: false + zero_stage: 2 +distributed_type: DEEPSPEED +downcast_bf16: 'no' +enable_cpu_affinity: false +machine_rank: 0 +main_training_function: main +mixed_precision: bf16 +num_machines: 1 +num_processes: 8 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false diff --git a/examples/wanvideo/model_training/full/run_test.py b/examples/wanvideo/model_training/full/run_test.py new file mode 100644 index 0000000..093becd --- /dev/null +++ b/examples/wanvideo/model_training/full/run_test.py @@ -0,0 +1,38 @@ +import multiprocessing, os + + +def run_task(scripts, thread_id, thread_num): + for script_id, script in enumerate(scripts): + if script_id % thread_num == thread_id: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"CUDA_VISIBLE_DEVICES={thread_id} bash {script} > data/log/{log_file_name} 2>&1" + os.makedirs("data/log", exist_ok=True) + print(cmd, flush=True) + os.system(cmd) + + +if __name__ == "__main__": + # 1.3B + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/full"): + if file_name != "run_test.py" and "14B" not in file_name: + scripts.append(os.path.join("examples/wanvideo/model_training/full", file_name)) + + processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] + for p in processes: + p.start() + for p in processes: + p.join() + + # 14B + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/full"): + if file_name != "run_test.py" and "14B" in file_name: + scripts.append(os.path.join("examples/wanvideo/model_training/full", file_name)) + for script in scripts: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"bash {script} > data/log/{log_file_name} 2>&1" + print(cmd, flush=True) + os.system(cmd) + + print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh new file mode 100644 index 0000000..4fb08bd --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_motion_bucket_id.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth,DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1:model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_motion_bucket_id \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh new file mode 100644 index 0000000..8b98631 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-FLF2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-FLF2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-FLF2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-FLF2V-14B-720P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_input_image \ + --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh new file mode 100644 index 0000000..72522f2 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-1.3B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh new file mode 100644 index 0000000..182fccc --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-1.3B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-1.3B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-1.3B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-1.3B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-1.3B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_input_image \ + --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh new file mode 100644 index 0000000..a45203c --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_control.csv \ + --data_file_keys "video,control_video" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_control_video \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh new file mode 100644 index 0000000..5392658 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh @@ -0,0 +1,16 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-14B-InP:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-14B-InP:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-14B-InP:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-14B-InP:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-14B-InP_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_input_image \ + --input_contains_end_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh new file mode 100644 index 0000000..a342981 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh @@ -0,0 +1,17 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_control_video \ + --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh new file mode 100644 index 0000000..a902522 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh @@ -0,0 +1,17 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_reference_control.csv \ + --data_file_keys "video,control_video,reference_image" \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_control_video \ + --input_contains_reference_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh new file mode 100644 index 0000000..3c085fa --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-480P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-480P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-480P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-480P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-I2V-14B-480P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh new file mode 100644 index 0000000..6193df7 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh @@ -0,0 +1,15 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-I2V-14B-720P:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-I2V-14B-720P:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-I2V-14B-720P:Wan2.1_VAE.pth,Wan-AI/Wan2.1-I2V-14B-720P:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-I2V-14B-720P_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --input_contains_input_image \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh new file mode 100644 index 0000000..d16a287 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.1-T2V-1.3B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh new file mode 100644 index 0000000..1fb55ac --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh @@ -0,0 +1,14 @@ +accelerate launch examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --height 480 \ + --width 832 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-14B:Wan2.1_VAE.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Wan2.1-T2V-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/run_test.py b/examples/wanvideo/model_training/lora/run_test.py new file mode 100644 index 0000000..ec0f9e2 --- /dev/null +++ b/examples/wanvideo/model_training/lora/run_test.py @@ -0,0 +1,25 @@ +import multiprocessing, os + + +def run_task(scripts, thread_id, thread_num): + for script_id, script in enumerate(scripts): + if script_id % thread_num == thread_id: + log_file_name = script.replace("/", "_") + ".txt" + cmd = f"CUDA_VISIBLE_DEVICES={thread_id} bash {script} > data/log/{log_file_name} 2>&1" + os.makedirs("data/log", exist_ok=True) + print(cmd, flush=True) + os.system(cmd) + + +if __name__ == "__main__": + scripts = [] + for file_name in os.listdir("examples/wanvideo/model_training/lora"): + if file_name != "run_test.py": + scripts.append(os.path.join("examples/wanvideo/model_training/lora", file_name)) + + processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)] + for p in processes: + p.start() + for p in processes: + p.join() + print("Done!") \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py new file mode 100644 index 0000000..cbace5a --- /dev/null +++ b/examples/wanvideo/model_training/train.py @@ -0,0 +1,129 @@ +import torch, os, json +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + + +class WanTrainingModule(DiffusionTrainingModule): + def __init__( + self, + model_paths=None, model_id_with_origin_paths=None, + trainable_models=None, + lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + # Extra inputs + input_contains_input_image=False, + input_contains_end_image=False, + input_contains_control_video=False, + input_contains_reference_image=False, + input_contains_vace_video=False, + input_contains_vace_reference_image=False, + input_contains_motion_bucket_id=False, + ): + super().__init__() + # Load models + model_configs = [] + if model_paths is not None: + model_paths = json.loads(model_paths) + model_configs += [ModelConfig(path=path) for path in model_paths] + if model_id_with_origin_paths is not None: + model_id_with_origin_paths = model_id_with_origin_paths.split(",") + model_configs += [ModelConfig(model_id=i.split(":")[0], origin_file_pattern=i.split(":")[1]) for i in model_id_with_origin_paths] + self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) + + # Reset training scheduler + self.pipe.scheduler.set_timesteps(1000, training=True) + + # Freeze untrainable models + self.pipe.freeze_except([] if trainable_models is None else trainable_models.split(",")) + + # Add LoRA to the base models + if lora_base_model is not None: + model = self.add_lora_to_model( + getattr(self.pipe, lora_base_model), + target_modules=lora_target_modules.split(","), + lora_rank=lora_rank + ) + setattr(self.pipe, lora_base_model, model) + + # Store other configs + self.use_gradient_checkpointing = use_gradient_checkpointing + self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload 
+
+        self.input_contains_input_image = input_contains_input_image
+        self.input_contains_end_image = input_contains_end_image
+        self.input_contains_control_video = input_contains_control_video
+        self.input_contains_reference_image = input_contains_reference_image
+        self.input_contains_vace_video = input_contains_vace_video
+        self.input_contains_vace_reference_image = input_contains_vace_reference_image
+        self.input_contains_motion_bucket_id = input_contains_motion_bucket_id
+
+
+    def forward_preprocess(self, data):
+        # CFG-sensitive parameters
+        inputs_posi = {"prompt": data["prompt"]}
+        inputs_nega = {}
+
+        # CFG-insensitive parameters
+        inputs_shared = {
+            # These are the same input parameters you would pass
+            # to the pipeline when running inference.
+            "input_video": data["video"],
+            "height": data["video"][0].size[1],
+            "width": data["video"][0].size[0],
+            "num_frames": len(data["video"]),
+            # Please do not modify the following parameters
+            # unless you clearly understand the consequences.
+            "cfg_scale": 1,
+            "tiled": False,
+            "rand_device": self.pipe.device,
+            "use_gradient_checkpointing": self.use_gradient_checkpointing,
+            "use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
+            "cfg_merge": False,
+            "vace_scale": 1,
+        }
+
+        # Extra inputs
+        if self.input_contains_input_image: inputs_shared["input_image"] = data["video"][0]
+        if self.input_contains_end_image: inputs_shared["end_image"] = data["video"][-1]
+        if self.input_contains_control_video: inputs_shared["control_video"] = data["control_video"]
+        if self.input_contains_reference_image: inputs_shared["reference_image"] = data["reference_image"]
+        if self.input_contains_vace_video: inputs_shared["vace_video"] = data["vace_video"]
+        if self.input_contains_vace_reference_image: inputs_shared["vace_reference_image"] = data["vace_reference_image"]
+        if self.input_contains_motion_bucket_id: inputs_shared["motion_bucket_id"] = data["motion_bucket_id"]
+
+        # Pipeline units will automatically process the input parameters.
+ for unit in self.pipe.units: + inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) + return {**inputs_shared, **inputs_posi} + + + def forward(self, data, inputs=None): + if inputs is None: inputs = self.forward_preprocess(data) + models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} + loss = self.pipe.training_loss(**models, **inputs) + return loss + + +if __name__ == "__main__": + parser = wan_parser() + args = parser.parse_args() + dataset = VideoDataset(args=args) + model = WanTrainingModule( + model_paths=args.model_paths, + model_id_with_origin_paths=args.model_id_with_origin_paths, + trainable_models=args.trainable_models, + lora_base_model=args.lora_base_model, + lora_target_modules=args.lora_target_modules, + lora_rank=args.lora_rank, + use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, + input_contains_input_image=args.input_contains_input_image, + input_contains_end_image=args.input_contains_end_image, + input_contains_control_video=args.input_contains_control_video, + input_contains_reference_image=args.input_contains_reference_image, + input_contains_vace_video=args.input_contains_vace_video, + input_contains_vace_reference_image=args.input_contains_vace_reference_image, + input_contains_motion_bucket_id=args.input_contains_motion_bucket_id, + ) + launch_training_task(model, dataset, args=args) diff --git a/examples/wanvideo/model_training/train_i2v.py b/examples/wanvideo/model_training/train_i2v.py deleted file mode 100644 index 1c5c757..0000000 --- a/examples/wanvideo/model_training/train_i2v.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch, os, json -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - -class WanTrainingModule(DiffusionTrainingModule): - def __init__(self, model_paths, task="train_lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32): - super().__init__() - self.pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cpu", - model_configs=[ModelConfig(path=path) for path in model_paths], - ) - self.pipe.scheduler.set_timesteps(1000, training=True) - if task == "train_lora": - self.pipe.freeze_except([]) - self.pipe.dit = self.add_lora_to_model(self.pipe.dit, target_modules=lora_target_modules.split(","), lora_rank=lora_rank) - else: - self.pipe.freeze_except(["dit"]) - - def forward_preprocess(self, data): - inputs_posi = {"prompt": data["prompt"]} - inputs_nega = {} - inputs_shared = { - "input_image": data["video"][0], - "input_video": data["video"], - "height": data["video"][0].size[1], - "width": data["video"][0].size[0], - "num_frames": len(data["video"]), - # Please do not modify the following parameters. 
- "cfg_scale": 1, - "tiled": False, - "rand_device": self.pipe.device, - "use_gradient_checkpointing": True, - "cfg_merge": False, - } - for unit in self.pipe.units: - inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) - return {**inputs_shared, **inputs_posi} - - def forward(self, data): - inputs = self.forward_preprocess(data) - models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} - loss = self.pipe.training_loss(**models, **inputs) - return loss - - -if __name__ == "__main__": - parser = wan_parser() - args = parser.parse_args() - dataset = VideoDataset(args=args) - model = WanTrainingModule(json.loads(args.model_paths), task=args.task, lora_target_modules=args.lora_target_modules, lora_rank=args.lora_rank) - launch_training_task(model, dataset, args=args) diff --git a/examples/wanvideo/model_training/train_t2v.py b/examples/wanvideo/model_training/train_t2v.py deleted file mode 100644 index 50b49ef..0000000 --- a/examples/wanvideo/model_training/train_t2v.py +++ /dev/null @@ -1,53 +0,0 @@ -import torch, os, json -from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, launch_training_task, wan_parser -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - -class WanTrainingModule(DiffusionTrainingModule): - def __init__(self, model_paths, task="train_lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32): - super().__init__() - self.pipe = WanVideoPipeline.from_pretrained( - torch_dtype=torch.bfloat16, - device="cpu", - model_configs=[ModelConfig(path=path) for path in model_paths], - ) - self.pipe.scheduler.set_timesteps(1000, training=True) - if task == "train_lora": - self.pipe.freeze_except([]) - self.pipe.dit = self.add_lora_to_model(self.pipe.dit, target_modules=lora_target_modules.split(","), lora_rank=lora_rank) - else: - self.pipe.freeze_except(["dit"]) - - def forward_preprocess(self, data): - inputs_posi = {"prompt": data["prompt"]} - inputs_nega = {} - inputs_shared = { - "input_video": data["video"], - "height": data["video"][0].size[1], - "width": data["video"][0].size[0], - "num_frames": len(data["video"]), - # Please do not modify the following parameters. 
- "cfg_scale": 1, - "tiled": False, - "rand_device": self.pipe.device, - "use_gradient_checkpointing": True, - "cfg_merge": False, - } - for unit in self.pipe.units: - inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega) - return {**inputs_shared, **inputs_posi} - - def forward(self, data): - inputs = self.forward_preprocess(data) - models = {name: getattr(self.pipe, name) for name in self.pipe.in_iteration_models} - loss = self.pipe.training_loss(**models, **inputs) - return loss - - -if __name__ == "__main__": - parser = wan_parser() - args = parser.parse_args() - dataset = VideoDataset(args=args) - model = WanTrainingModule(json.loads(args.model_paths), task=args.task, lora_target_modules=args.lora_target_modules, lora_rank=args.lora_rank) - launch_training_task(model, dataset, args=args) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py new file mode 100644 index 0000000..124749a --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py @@ -0,0 +1,28 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-1.3b-speedcontrol-v1_full/epoch-1.safetensors") +pipe.motion_controller.load_state_dict(state_dict) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=50 +) +save_video(video, "video_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000..41a67ed --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", 
origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-FLF2V-14B-720P_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + end_image=video[80], + seed=0, tiled=True, + sigma_shift=16, +) +save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000..6726e9c --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-1.3B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py new file mode 100644 index 0000000..3e1e6f3 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", 
offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-1.3B-InP_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000..08b0acb --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-14B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000..d7e39d7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + 
torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-14B-InP_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000..6497e1b --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py new file mode 100644 index 0000000..0dd2516 --- /dev/null +++ 
b/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-Control_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000..c1c8615 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-480P_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) 
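+# `tiled=True` enables tiled VAE encoding/decoding, which (as with the
+# `--tiled` option in the training script) can reduce the VRAM required.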
+save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py new file mode 100644 index 0000000..a8610f3 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-I2V-14B-720P_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py new file mode 100644 index 0000000..1420514 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py @@ -0,0 +1,25 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.1-T2V-1.3B_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +pipe.enable_vram_management() + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py new file mode 100644 index 0000000..a0107ae --- /dev/null +++ 
b/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py
@@ -0,0 +1,25 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+state_dict = load_state_dict("models/train/Wan2.1-T2V-14B_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+pipe.enable_vram_management()
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_full/run_test.py b/examples/wanvideo/model_training/validate_full/run_test.py
new file mode 100644
index 0000000..a4e3203
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/run_test.py
@@ -0,0 +1,25 @@
+import multiprocessing, os
+
+
+# Shard the validation scripts across workers: worker `thread_id` runs every
+# `thread_num`-th script, pinned to its own GPU via CUDA_VISIBLE_DEVICES, and
+# logs the output of each run to data/log/.
+def run_task(scripts, thread_id, thread_num):
+    for script_id, script in enumerate(scripts):
+        if script_id % thread_num == thread_id:
+            log_file_name = script.replace("/", "_") + ".txt"
+            cmd = f"CUDA_VISIBLE_DEVICES={thread_id} python -u {script} > data/log/{log_file_name} 2>&1"
+            os.makedirs("data/log", exist_ok=True)
+            print(cmd, flush=True)
+            os.system(cmd)
+
+
+if __name__ == "__main__":
+    scripts = []
+    for file_name in os.listdir("examples/wanvideo/model_training/validate_full"):
+        if file_name != "run_test.py":
+            scripts.append(os.path.join("examples/wanvideo/model_training/validate_full", file_name))
+
+    processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)]
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    print("Done!")
\ No newline at end of file
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py
new file mode 100644
index 0000000..167b871
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py
@@ -0,0 +1,27 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData, load_state_dict
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+        ModelConfig(model_id="DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", origin_file_pattern="model.safetensors", offload_device="cpu"),
+    ],
+)
+pipe.load_lora(pipe.dit,
"models/train/Wan2.1-1.3b-speedcontrol-v1_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +# Text-to-video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=1, tiled=True, + motion_bucket_id=50 +) +save_video(video, "video_Wan2.1-1.3b-speedcontrol-v1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py new file mode 100644 index 0000000..cd68f0e --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-FLF2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-FLF2V-14B-720P_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], + end_image=video[80], + seed=0, tiled=True, + sigma_shift=16, +) +save_video(video, "video_Wan2.1-FLF2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py new file mode 100644 index 0000000..7270c38 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, 
"models/train/Wan2.1-Fun-1.3B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py new file mode 100644 index 0000000..c904dfa --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-1.3B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-1.3B-InP_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-1.3B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py new file mode 100644 index 0000000..8631d05 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py @@ -0,0 +1,31 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-Control", 
origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-14B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py new file mode 100644 index 0000000..e020aac --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py @@ -0,0 +1,30 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-14B-InP", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-14B-InP_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832) + +# First and last frame to video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=video[0], end_image=video[80], + seed=0, tiled=True +) +save_video(video, "video_Wan2.1-Fun-14B-InP.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py new file mode 100644 index 0000000..ebcfd2f --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", 
origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py new file mode 100644 index 0000000..6b11098 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-14B-Control_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +video = VideoData("data/example_video_dataset/video1_softedge.mp4", height=480, width=832) +video = [video[i] for i in range(81)] +reference_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +# Control video +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + control_video=video, reference_image=reference_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py new file mode 100644 index 0000000..1687e36 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = 
WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-480P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-480P_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-480P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py new file mode 100644 index 0000000..9893e26 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.1-I2V-14B-720P", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.1-I2V-14B-720P_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)[0] + +video = pipe( + prompt="from sunset to night, a small town, light, house, river", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + input_image=input_image, + seed=1, tiled=True +) +save_video(video, "video_Wan2.1-I2V-14B-720P.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py new file mode 100644 index 0000000..7cb6c02 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py @@ -0,0 +1,24 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + 
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-1.3B_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-T2V-1.3B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py
new file mode 100644
index 0000000..3b66a49
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py
@@ -0,0 +1,24 @@
+import torch
+from PIL import Image
+from diffsynth import save_video, VideoData
+from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
+        ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
+    ],
+)
+pipe.load_lora(pipe.dit, "models/train/Wan2.1-T2V-14B_lora/epoch-4.safetensors", alpha=1)
+pipe.enable_vram_management()
+
+video = pipe(
+    prompt="from sunset to night, a small town, light, house, river",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=1, tiled=True
+)
+save_video(video, "video_Wan2.1-T2V-14B.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/run_test.py b/examples/wanvideo/model_training/validate_lora/run_test.py
new file mode 100644
index 0000000..367ee9d
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/run_test.py
@@ -0,0 +1,25 @@
+import multiprocessing, os
+
+
+# Shard the validation scripts across workers: worker `thread_id` runs every
+# `thread_num`-th script, pinned to its own GPU via CUDA_VISIBLE_DEVICES, and
+# logs the output of each run to data/log/.
+def run_task(scripts, thread_id, thread_num):
+    for script_id, script in enumerate(scripts):
+        if script_id % thread_num == thread_id:
+            log_file_name = script.replace("/", "_") + ".txt"
+            cmd = f"CUDA_VISIBLE_DEVICES={thread_id} python -u {script} > data/log/{log_file_name} 2>&1"
+            os.makedirs("data/log", exist_ok=True)
+            print(cmd, flush=True)
+            os.system(cmd)
+
+
+if __name__ == "__main__":
+    scripts = []
+    for file_name in os.listdir("examples/wanvideo/model_training/validate_lora"):
+        if file_name != "run_test.py":
+            scripts.append(os.path.join("examples/wanvideo/model_training/validate_lora", file_name))
+
+    processes = [multiprocessing.Process(target=run_task, args=(scripts, i, 8)) for i in range(8)]
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    print("Done!")
\ No newline at end of file
diff --git a/examples/wanvideo/train_wan_t2v.py b/examples/wanvideo/train_wan_t2v.py
deleted file mode 100644
index cd10096..0000000
--- a/examples/wanvideo/train_wan_t2v.py
+++ /dev/null
@@
-1,593 +0,0 @@ -import torch, os, imageio, argparse -from torchvision.transforms import v2 -from einops import rearrange -import lightning as pl -import pandas as pd -from diffsynth import WanVideoPipeline, ModelManager, load_state_dict -from peft import LoraConfig, inject_adapter_in_model -import torchvision -from PIL import Image -import numpy as np - - - -class TextVideoDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, max_num_frames=81, frame_interval=1, num_frames=81, height=480, width=832, is_i2v=False): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - self.text = metadata["text"].to_list() - - self.max_num_frames = max_num_frames - self.frame_interval = frame_interval - self.num_frames = num_frames - self.height = height - self.width = width - self.is_i2v = is_i2v - - self.frame_process = v2.Compose([ - v2.CenterCrop(size=(height, width)), - v2.Resize(size=(height, width), antialias=True), - v2.ToTensor(), - v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), - ]) - - - def crop_and_resize(self, image): - width, height = image.size - scale = max(self.width / width, self.height / height) - image = torchvision.transforms.functional.resize( - image, - (round(height*scale), round(width*scale)), - interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - return image - - - def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process): - reader = imageio.get_reader(file_path) - if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval: - reader.close() - return None - - frames = [] - first_frame = None - for frame_id in range(num_frames): - frame = reader.get_data(start_frame_id + frame_id * interval) - frame = Image.fromarray(frame) - frame = self.crop_and_resize(frame) - if first_frame is None: - first_frame = frame - frame = frame_process(frame) - frames.append(frame) - reader.close() - - frames = torch.stack(frames, dim=0) - frames = rearrange(frames, "T C H W -> C T H W") - - first_frame = v2.functional.center_crop(first_frame, output_size=(self.height, self.width)) - first_frame = np.array(first_frame) - - if self.is_i2v: - return frames, first_frame - else: - return frames - - - def load_video(self, file_path): - start_frame_id = torch.randint(0, self.max_num_frames - (self.num_frames - 1) * self.frame_interval, (1,))[0] - frames = self.load_frames_using_imageio(file_path, self.max_num_frames, start_frame_id, self.frame_interval, self.num_frames, self.frame_process) - return frames - - - def is_image(self, file_path): - file_ext_name = file_path.split(".")[-1] - if file_ext_name.lower() in ["jpg", "jpeg", "png", "webp"]: - return True - return False - - - def load_image(self, file_path): - frame = Image.open(file_path).convert("RGB") - frame = self.crop_and_resize(frame) - first_frame = frame - frame = self.frame_process(frame) - frame = rearrange(frame, "C H W -> C 1 H W") - return frame - - - def __getitem__(self, data_id): - text = self.text[data_id] - path = self.path[data_id] - if self.is_image(path): - if self.is_i2v: - raise ValueError(f"{path} is not a video. 
I2V model doesn't support image-to-image training.") - video = self.load_image(path) - else: - video = self.load_video(path) - if self.is_i2v: - video, first_frame = video - data = {"text": text, "video": video, "path": path, "first_frame": first_frame} - else: - data = {"text": text, "video": video, "path": path} - return data - - - def __len__(self): - return len(self.path) - - - -class LightningModelForDataProcess(pl.LightningModule): - def __init__(self, text_encoder_path, vae_path, image_encoder_path=None, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - super().__init__() - model_path = [text_encoder_path, vae_path] - if image_encoder_path is not None: - model_path.append(image_encoder_path) - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - model_manager.load_models(model_path) - self.pipe = WanVideoPipeline.from_model_manager(model_manager) - - self.tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride} - - def test_step(self, batch, batch_idx): - text, video, path = batch["text"][0], batch["video"], batch["path"][0] - - self.pipe.device = self.device - if video is not None: - # prompt - prompt_emb = self.pipe.encode_prompt(text) - # video - video = video.to(dtype=self.pipe.torch_dtype, device=self.pipe.device) - latents = self.pipe.encode_video(video, **self.tiler_kwargs)[0] - # image - if "first_frame" in batch: - first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy()) - _, _, num_frames, height, width = video.shape - image_emb = self.pipe.encode_image(first_frame, None, num_frames, height, width) - else: - image_emb = {} - data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb} - torch.save(data, path + ".tensors.pth") - - - -class TensorDataset(torch.utils.data.Dataset): - def __init__(self, base_path, metadata_path, steps_per_epoch): - metadata = pd.read_csv(metadata_path) - self.path = [os.path.join(base_path, "train", file_name) for file_name in metadata["file_name"]] - print(len(self.path), "videos in metadata.") - self.path = [i + ".tensors.pth" for i in self.path if os.path.exists(i + ".tensors.pth")] - print(len(self.path), "tensors cached in metadata.") - assert len(self.path) > 0 - - self.steps_per_epoch = steps_per_epoch - - - def __getitem__(self, index): - data_id = torch.randint(0, len(self.path), (1,))[0] - data_id = (data_id + index) % len(self.path) # For fixed seed. 
- path = self.path[data_id] - data = torch.load(path, weights_only=True, map_location="cpu") - return data - - - def __len__(self): - return self.steps_per_epoch - - - -class LightningModelForTrain(pl.LightningModule): - def __init__( - self, - dit_path, - learning_rate=1e-5, - lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", - use_gradient_checkpointing=True, use_gradient_checkpointing_offload=False, - pretrained_lora_path=None - ): - super().__init__() - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu") - if os.path.isfile(dit_path): - model_manager.load_models([dit_path]) - else: - dit_path = dit_path.split(",") - model_manager.load_models([dit_path]) - - self.pipe = WanVideoPipeline.from_model_manager(model_manager) - self.pipe.scheduler.set_timesteps(1000, training=True) - self.freeze_parameters() - if train_architecture == "lora": - self.add_lora_to_model( - self.pipe.denoising_model(), - lora_rank=lora_rank, - lora_alpha=lora_alpha, - lora_target_modules=lora_target_modules, - init_lora_weights=init_lora_weights, - pretrained_lora_path=pretrained_lora_path, - ) - else: - self.pipe.denoising_model().requires_grad_(True) - - self.learning_rate = learning_rate - self.use_gradient_checkpointing = use_gradient_checkpointing - self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload - - - def freeze_parameters(self): - # Freeze parameters - self.pipe.requires_grad_(False) - self.pipe.eval() - self.pipe.denoising_model().train() - - - def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None): - # Add LoRA to UNet - self.lora_alpha = lora_alpha - if init_lora_weights == "kaiming": - init_lora_weights = True - - lora_config = LoraConfig( - r=lora_rank, - lora_alpha=lora_alpha, - init_lora_weights=init_lora_weights, - target_modules=lora_target_modules.split(","), - ) - model = inject_adapter_in_model(lora_config, model) - for param in model.parameters(): - # Upcast LoRA parameters into fp32 - if param.requires_grad: - param.data = param.to(torch.float32) - - # Lora pretrained lora weights - if pretrained_lora_path is not None: - state_dict = load_state_dict(pretrained_lora_path) - if state_dict_converter is not None: - state_dict = state_dict_converter(state_dict) - missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) - all_keys = [i for i, _ in model.named_parameters()] - num_updated_keys = len(all_keys) - len(missing_keys) - num_unexpected_keys = len(unexpected_keys) - print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. 
{num_unexpected_keys} parameters are unexpected.") - - - def training_step(self, batch, batch_idx): - # Data - latents = batch["latents"].to(self.device) - prompt_emb = batch["prompt_emb"] - prompt_emb["context"] = prompt_emb["context"][0].to(self.device) - image_emb = batch["image_emb"] - if "clip_feature" in image_emb: - image_emb["clip_feature"] = image_emb["clip_feature"][0].to(self.device) - if "y" in image_emb: - image_emb["y"] = image_emb["y"][0].to(self.device) - - # Loss - self.pipe.device = self.device - noise = torch.randn_like(latents) - timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,)) - timestep = self.pipe.scheduler.timesteps[timestep_id].to(dtype=self.pipe.torch_dtype, device=self.pipe.device) - extra_input = self.pipe.prepare_extra_input(latents) - noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) - training_target = self.pipe.scheduler.training_target(latents, noise, timestep) - - # Compute loss - noise_pred = self.pipe.denoising_model()( - noisy_latents, timestep=timestep, **prompt_emb, **extra_input, **image_emb, - use_gradient_checkpointing=self.use_gradient_checkpointing, - use_gradient_checkpointing_offload=self.use_gradient_checkpointing_offload - ) - loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) - loss = loss * self.pipe.scheduler.training_weight(timestep) - - # Record log - self.log("train_loss", loss, prog_bar=True) - return loss - - - def configure_optimizers(self): - trainable_modules = filter(lambda p: p.requires_grad, self.pipe.denoising_model().parameters()) - optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) - return optimizer - - - def on_save_checkpoint(self, checkpoint): - checkpoint.clear() - trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.denoising_model().named_parameters())) - trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) - state_dict = self.pipe.denoising_model().state_dict() - lora_state_dict = {} - for name, param in state_dict.items(): - if name in trainable_param_names: - lora_state_dict[name] = param - checkpoint.update(lora_state_dict) - - - -def parse_args(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument( - "--task", - type=str, - default="data_process", - required=True, - choices=["data_process", "train"], - help="Task. `data_process` or `train`.", - ) - parser.add_argument( - "--dataset_path", - type=str, - default=None, - required=True, - help="The path of the Dataset.", - ) - parser.add_argument( - "--output_path", - type=str, - default="./", - help="Path to save the model.", - ) - parser.add_argument( - "--text_encoder_path", - type=str, - default=None, - help="Path of text encoder.", - ) - parser.add_argument( - "--image_encoder_path", - type=str, - default=None, - help="Path of image encoder.", - ) - parser.add_argument( - "--vae_path", - type=str, - default=None, - help="Path of VAE.", - ) - parser.add_argument( - "--dit_path", - type=str, - default=None, - help="Path of DiT.", - ) - parser.add_argument( - "--tiled", - default=False, - action="store_true", - help="Whether enable tile encode in VAE. 
This option can reduce VRAM required.", - ) - parser.add_argument( - "--tile_size_height", - type=int, - default=34, - help="Tile size (height) in VAE.", - ) - parser.add_argument( - "--tile_size_width", - type=int, - default=34, - help="Tile size (width) in VAE.", - ) - parser.add_argument( - "--tile_stride_height", - type=int, - default=18, - help="Tile stride (height) in VAE.", - ) - parser.add_argument( - "--tile_stride_width", - type=int, - default=16, - help="Tile stride (width) in VAE.", - ) - parser.add_argument( - "--steps_per_epoch", - type=int, - default=500, - help="Number of steps per epoch.", - ) - parser.add_argument( - "--num_frames", - type=int, - default=81, - help="Number of frames.", - ) - parser.add_argument( - "--height", - type=int, - default=480, - help="Image height.", - ) - parser.add_argument( - "--width", - type=int, - default=832, - help="Image width.", - ) - parser.add_argument( - "--dataloader_num_workers", - type=int, - default=1, - help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.", - ) - parser.add_argument( - "--learning_rate", - type=float, - default=1e-5, - help="Learning rate.", - ) - parser.add_argument( - "--accumulate_grad_batches", - type=int, - default=1, - help="The number of batches in gradient accumulation.", - ) - parser.add_argument( - "--max_epochs", - type=int, - default=1, - help="Number of epochs.", - ) - parser.add_argument( - "--lora_target_modules", - type=str, - default="q,k,v,o,ffn.0,ffn.2", - help="Layers with LoRA modules.", - ) - parser.add_argument( - "--init_lora_weights", - type=str, - default="kaiming", - choices=["gaussian", "kaiming"], - help="The initializing method of LoRA weight.", - ) - parser.add_argument( - "--training_strategy", - type=str, - default="auto", - choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], - help="Training strategy", - ) - parser.add_argument( - "--lora_rank", - type=int, - default=4, - help="The dimension of the LoRA update matrices.", - ) - parser.add_argument( - "--lora_alpha", - type=float, - default=4.0, - help="The weight of the LoRA update matrices.", - ) - parser.add_argument( - "--use_gradient_checkpointing", - default=False, - action="store_true", - help="Whether to use gradient checkpointing.", - ) - parser.add_argument( - "--use_gradient_checkpointing_offload", - default=False, - action="store_true", - help="Whether to use gradient checkpointing offload.", - ) - parser.add_argument( - "--train_architecture", - type=str, - default="lora", - choices=["lora", "full"], - help="Model structure to train. LoRA training or full training.", - ) - parser.add_argument( - "--pretrained_lora_path", - type=str, - default=None, - help="Pretrained LoRA path. 
Required if the training is resumed.", - ) - parser.add_argument( - "--use_swanlab", - default=False, - action="store_true", - help="Whether to use SwanLab logger.", - ) - parser.add_argument( - "--swanlab_mode", - default=None, - help="SwanLab mode (cloud or local).", - ) - args = parser.parse_args() - return args - - -def data_process(args): - dataset = TextVideoDataset( - args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), - max_num_frames=args.num_frames, - frame_interval=1, - num_frames=args.num_frames, - height=args.height, - width=args.width, - is_i2v=args.image_encoder_path is not None - ) - dataloader = torch.utils.data.DataLoader( - dataset, - shuffle=False, - batch_size=1, - num_workers=args.dataloader_num_workers - ) - model = LightningModelForDataProcess( - text_encoder_path=args.text_encoder_path, - image_encoder_path=args.image_encoder_path, - vae_path=args.vae_path, - tiled=args.tiled, - tile_size=(args.tile_size_height, args.tile_size_width), - tile_stride=(args.tile_stride_height, args.tile_stride_width), - ) - trainer = pl.Trainer( - accelerator="gpu", - devices="auto", - default_root_dir=args.output_path, - ) - trainer.test(model, dataloader) - - -def train(args): - dataset = TensorDataset( - args.dataset_path, - os.path.join(args.dataset_path, "metadata.csv"), - steps_per_epoch=args.steps_per_epoch, - ) - dataloader = torch.utils.data.DataLoader( - dataset, - shuffle=True, - batch_size=1, - num_workers=args.dataloader_num_workers - ) - model = LightningModelForTrain( - dit_path=args.dit_path, - learning_rate=args.learning_rate, - train_architecture=args.train_architecture, - lora_rank=args.lora_rank, - lora_alpha=args.lora_alpha, - lora_target_modules=args.lora_target_modules, - init_lora_weights=args.init_lora_weights, - use_gradient_checkpointing=args.use_gradient_checkpointing, - use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload, - pretrained_lora_path=args.pretrained_lora_path, - ) - if args.use_swanlab: - from swanlab.integration.pytorch_lightning import SwanLabLogger - swanlab_config = {"UPPERFRAMEWORK": "DiffSynth-Studio"} - swanlab_config.update(vars(args)) - swanlab_logger = SwanLabLogger( - project="wan", - name="wan", - config=swanlab_config, - mode=args.swanlab_mode, - logdir=os.path.join(args.output_path, "swanlog"), - ) - logger = [swanlab_logger] - else: - logger = None - trainer = pl.Trainer( - max_epochs=args.max_epochs, - accelerator="gpu", - devices="auto", - precision="bf16", - strategy=args.training_strategy, - default_root_dir=args.output_path, - accumulate_grad_batches=args.accumulate_grad_batches, - callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)], - logger=logger, - ) - trainer.fit(model, dataloader) - - -if __name__ == '__main__': - args = parse_args() - if args.task == "data_process": - data_process(args) - elif args.task == "train": - train(args) diff --git a/examples/wanvideo/wan_1.3b_motion_controller.py b/examples/wanvideo/wan_1.3b_motion_controller.py deleted file mode 100644 index 8036819..0000000 --- a/examples/wanvideo/wan_1.3b_motion_controller.py +++ /dev/null @@ -1,41 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download - - -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") -snapshot_download("DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1", local_dir="models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1") - -# Load models 
-model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", - "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", - "models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1/model.safetensors", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -# Text-to-video -video = pipe( - prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - seed=1, tiled=True, - motion_bucket_id=0 -) -save_video(video, "video_slow.mp4", fps=15, quality=5) - -video = pipe( - prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - seed=1, tiled=True, - motion_bucket_id=100 -) -save_video(video, "video_fast.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/wan_1.3b_text_to_video.py b/examples/wanvideo/wan_1.3b_text_to_video.py deleted file mode 100644 index e444cd2..0000000 --- a/examples/wanvideo/wan_1.3b_text_to_video.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download - - -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors", - "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
-) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -# Text-to-video -video = pipe( - prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - seed=0, tiled=True -) -save_video(video, "video1.mp4", fps=15, quality=5) - -# Video-to-video -video = VideoData("video1.mp4", height=480, width=832) -video = pipe( - prompt="纪实摄影风格画面,一只活泼的小狗戴着黑色墨镜在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,戴着黑色墨镜,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_video=video, denoising_strength=0.7, - num_inference_steps=50, - seed=1, tiled=True -) -save_video(video, "video2.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_1.3b_vace.py b/examples/wanvideo/wan_1.3b_vace.py deleted file mode 100644 index fb987a7..0000000 --- a/examples/wanvideo/wan_1.3b_vace.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors", - "models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth", - "models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth", - ], - torch_dtype=torch.bfloat16, -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -# Download example video -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] -) - -# Depth video -> Video -control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) -video = pipe( - prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - height=480, width=832, num_frames=81, - vace_video=control_video, - seed=1, tiled=True -) -save_video(video, "video1.mp4", fps=15, quality=5) - -# Reference image -> Video -video = pipe( - prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - height=480, width=832, num_frames=81, - vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), - seed=1, tiled=True -) -save_video(video, "video2.mp4", fps=15, quality=5) - -# Depth video + Reference image -> Video -video = pipe( - prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", - 
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - height=480, width=832, num_frames=81, - vace_video=control_video, - vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), - seed=1, tiled=True -) -save_video(video, "video3.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14B_flf2v.py b/examples/wanvideo/wan_14B_flf2v.py deleted file mode 100644 index 23109df..0000000 --- a/examples/wanvideo/wan_14B_flf2v.py +++ /dev/null @@ -1,52 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("Wan-AI/Wan2.1-FLF2V-14B-720P", local_dir="models/Wan-AI/Wan2.1-FLF2V-14B-720P") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - ["models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], - torch_dtype=torch.float32, # Image Encoder is loaded with float32 -) -model_manager.load_models( - [ - [ - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00001-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00002-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00003-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00004-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00005-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00006-of-00007.safetensors", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/diffusion_pytorch_model-00007-of-00007.safetensors", - ], - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-FLF2V-14B-720P/Wan2.1_VAE.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
-) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -# Download example image -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/wan/first_frame.jpeg", "data/examples/wan/last_frame.jpeg"] -) - -# First and last frame to video -video = pipe( - prompt="写实风格,一个女生手持枯萎的花站在花园中,镜头逐渐拉远,记录下花园的全貌。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=30, - input_image=Image.open("data/examples/wan/first_frame.jpeg").resize((960, 960)), - end_image=Image.open("data/examples/wan/last_frame.jpeg").resize((960, 960)), - height=960, width=960, - seed=1, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_image_to_video.py b/examples/wanvideo/wan_14b_image_to_video.py deleted file mode 100644 index 91894ae..0000000 --- a/examples/wanvideo/wan_14b_image_to_video.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("Wan-AI/Wan2.1-I2V-14B-480P", local_dir="models/Wan-AI/Wan2.1-I2V-14B-480P") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - ["models/Wan-AI/Wan2.1-I2V-14B-480P/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"], - torch_dtype=torch.float32, # Image Encoder is loaded with float32 -) -model_manager.load_models( - [ - [ - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00001-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00002-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00003-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00004-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00005-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00006-of-00007.safetensors", - "models/Wan-AI/Wan2.1-I2V-14B-480P/diffusion_pytorch_model-00007-of-00007.safetensors", - ], - "models/Wan-AI/Wan2.1-I2V-14B-480P/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-I2V-14B-480P/Wan2.1_VAE.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=6*10**9) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. 
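-# Rough sizing note: `num_persistent_param_in_dit` caps how many DiT parameters stay
-# resident in VRAM, with the remainder kept off the GPU and moved in as needed. At bf16
-# (2 bytes per parameter), 6*10**9 persistent parameters correspond to roughly 12 GB;
-# smaller values trade inference speed for lower VRAM, and `None` keeps the full DiT resident.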
- -# Download example image -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# Image-to-video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - input_image=image, - num_inference_steps=50, - seed=0, tiled=True -) -save_video(video, "video.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_14b_text_to_video.py b/examples/wanvideo/wan_14b_text_to_video.py deleted file mode 100644 index 654565d..0000000 --- a/examples/wanvideo/wan_14b_text_to_video.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download - - -# Download models -snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - [ - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors", - "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors", - ], - "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth", - "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth", - ], - torch_dtype=torch.float8_e4m3fn, # You can set `torch_dtype=torch.bfloat16` to disable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) # You can set `num_persistent_param_in_dit` to a small number to reduce VRAM required. 
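-# This T2V-14B example combines both memory savers: the weights above are stored as
-# float8_e4m3fn while the pipeline still computes in bf16, and VRAM management is left
-# uncapped (`None`), presumably because FP8 storage already halves the weight footprint.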
-
-# Text-to-video
-video = pipe(
-    prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
-    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-    num_inference_steps=50,
-    seed=0, tiled=True
-)
-save_video(video, "video1.mp4", fps=25, quality=5)
diff --git a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py b/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
deleted file mode 100644
index 77c230c..0000000
--- a/examples/wanvideo/wan_14b_text_to_video_tensor_parallel.py
+++ /dev/null
@@ -1,149 +0,0 @@
-import torch
-import lightning as pl
-from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, SequenceParallel, PrepareModuleInput, PrepareModuleOutput
-from torch.distributed._tensor import Replicate, Shard
-from torch.distributed.tensor.parallel import parallelize_module
-from lightning.pytorch.strategies import ModelParallelStrategy
-from diffsynth import ModelManager, WanVideoPipeline, save_video
-from tqdm import tqdm
-from modelscope import snapshot_download
-
-
-
-class ToyDataset(torch.utils.data.Dataset):
-    def __init__(self, tasks=[]):
-        self.tasks = tasks
-
-    def __getitem__(self, data_id):
-        return self.tasks[data_id]
-
-    def __len__(self):
-        return len(self.tasks)
-
-
-class LitModel(pl.LightningModule):
-    def __init__(self):
-        super().__init__()
-        model_manager = ModelManager(device="cpu")
-        model_manager.load_models(
-            [
-                [
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
-                    "models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
-                ],
-                "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
-                "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
-            ],
-            torch_dtype=torch.bfloat16,
-        )
-        self.pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-
-    def configure_model(self):
-        tp_mesh = self.device_mesh["tensor_parallel"]
-        # Shard the text/time embedding projections across the tensor-parallel mesh;
-        # inputs to the transformer blocks and the output head remain replicated.
-        plan = {
-            "text_embedding.0": ColwiseParallel(),
-            "text_embedding.2": RowwiseParallel(),
-            "time_projection.1": ColwiseParallel(output_layouts=Replicate()),
-            "blocks.0": PrepareModuleInput(
-                input_layouts=(Replicate(), None, None, None),
-                desired_input_layouts=(Replicate(), None, None, None),
-            ),
-            "head": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Replicate(), None),
-                use_local_output=True,
-            )
-        }
-        self.pipe.dit = parallelize_module(self.pipe.dit, tp_mesh, plan)
-        for block_id, block in enumerate(self.pipe.dit.blocks):
-            layer_tp_plan = {
-                "self_attn": PrepareModuleInput(
-                    input_layouts=(Shard(1), Replicate()),
-                    desired_input_layouts=(Shard(1), Shard(0)),
-                ),
-                "self_attn.q": SequenceParallel(),
-                "self_attn.k": SequenceParallel(),
-                "self_attn.v": SequenceParallel(),
-                "self_attn.norm_q": SequenceParallel(),
-                "self_attn.norm_k": SequenceParallel(),
-                "self_attn.attn": PrepareModuleInput(
-                    input_layouts=(Shard(1),
Shard(1), Shard(1)), - desired_input_layouts=(Shard(2), Shard(2), Shard(2)), - ), - "self_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate()), - - "cross_attn": PrepareModuleInput( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Shard(1), Replicate()), - ), - "cross_attn.q": SequenceParallel(), - "cross_attn.k": SequenceParallel(), - "cross_attn.v": SequenceParallel(), - "cross_attn.norm_q": SequenceParallel(), - "cross_attn.norm_k": SequenceParallel(), - "cross_attn.attn": PrepareModuleInput( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(2), Shard(2), Shard(2)), - ), - "cross_attn.o": RowwiseParallel(input_layouts=Shard(2), output_layouts=Replicate(), use_local_output=False), - - "ffn.0": ColwiseParallel(input_layouts=Shard(1)), - "ffn.2": RowwiseParallel(output_layouts=Replicate()), - - "norm1": SequenceParallel(use_local_output=True), - "norm2": SequenceParallel(use_local_output=True), - "norm3": SequenceParallel(use_local_output=True), - "gate": PrepareModuleInput( - input_layouts=(Shard(1), Replicate(), Replicate()), - desired_input_layouts=(Replicate(), Replicate(), Replicate()), - ) - } - parallelize_module( - module=block, - device_mesh=tp_mesh, - parallelize_plan=layer_tp_plan, - ) - - - def test_step(self, batch): - data = batch[0] - data["progress_bar_cmd"] = tqdm if self.local_rank == 0 else lambda x: x - output_path = data.pop("output_path") - with torch.no_grad(), torch.inference_mode(False): - video = self.pipe(**data) - if self.local_rank == 0: - save_video(video, output_path, fps=15, quality=5) - - -if __name__ == "__main__": - snapshot_download("Wan-AI/Wan2.1-T2V-14B", local_dir="models/Wan-AI/Wan2.1-T2V-14B") - dataloader = torch.utils.data.DataLoader( - ToyDataset([ - { - "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", - "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - "num_inference_steps": 50, - "seed": 0, - "tiled": False, - "output_path": "video1.mp4", - }, - { - "prompt": "一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", - "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - "num_inference_steps": 50, - "seed": 1, - "tiled": False, - "output_path": "video2.mp4", - }, - ]), - collate_fn=lambda x: x - ) - model = LitModel() - trainer = pl.Trainer(accelerator="gpu", devices=torch.cuda.device_count(), strategy=ModelParallelStrategy()) - trainer.test(model, dataloader) \ No newline at end of file diff --git a/examples/wanvideo/wan_fun_InP.py b/examples/wanvideo/wan_fun_InP.py deleted file mode 100644 index ae23ee0..0000000 --- a/examples/wanvideo/wan_fun_InP.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("PAI/Wan2.1-Fun-1.3B-InP", local_dir="models/PAI/Wan2.1-Fun-1.3B-InP") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - 
"models/PAI/Wan2.1-Fun-1.3B-InP/diffusion_pytorch_model.safetensors", - "models/PAI/Wan2.1-Fun-1.3B-InP/models_t5_umt5-xxl-enc-bf16.pth", - "models/PAI/Wan2.1-Fun-1.3B-InP/Wan2.1_VAE.pth", - "models/PAI/Wan2.1-Fun-1.3B-InP/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. -) -pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") -pipe.enable_vram_management(num_persistent_param_in_dit=None) - -# Download example image -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=f"data/examples/wan/input_image.jpg" -) -image = Image.open("data/examples/wan/input_image.jpg") - -# Image-to-video -video = pipe( - prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。", - negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - num_inference_steps=50, - input_image=image, - # You can input `end_image=xxx` to control the last frame of the video. - # The model will automatically generate the dynamic content between `input_image` and `end_image`. - seed=1, tiled=True -) -save_video(video, "video1.mp4", fps=15, quality=5) diff --git a/examples/wanvideo/wan_fun_control.py b/examples/wanvideo/wan_fun_control.py deleted file mode 100644 index e2c4d0c..0000000 --- a/examples/wanvideo/wan_fun_control.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData -from modelscope import snapshot_download, dataset_snapshot_download -from PIL import Image - - -# Download models -snapshot_download("PAI/Wan2.1-Fun-1.3B-Control", local_dir="models/PAI/Wan2.1-Fun-1.3B-Control") - -# Load models -model_manager = ModelManager(device="cpu") -model_manager.load_models( - [ - "models/PAI/Wan2.1-Fun-1.3B-Control/diffusion_pytorch_model.safetensors", - "models/PAI/Wan2.1-Fun-1.3B-Control/models_t5_umt5-xxl-enc-bf16.pth", - "models/PAI/Wan2.1-Fun-1.3B-Control/Wan2.1_VAE.pth", - "models/PAI/Wan2.1-Fun-1.3B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", - ], - torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization. 
-)
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None)
-
-# Download example video
-dataset_snapshot_download(
-    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
-    local_dir="./",
-    allow_file_pattern="data/examples/wan/control_video.mp4"
-)
-
-# Control-to-video
-control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
-video = pipe(
-    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
-    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-    num_inference_steps=50,
-    control_video=control_video, height=832, width=576, num_frames=49,
-    seed=1, tiled=True
-)
-save_video(video, "video1.mp4", fps=15, quality=5)
diff --git a/examples/wanvideo/wan_fun_reference_control.py b/examples/wanvideo/wan_fun_reference_control.py
deleted file mode 100644
index bc82157..0000000
--- a/examples/wanvideo/wan_fun_reference_control.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import torch
-from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
-from modelscope import snapshot_download, dataset_snapshot_download
-from PIL import Image
-
-
-# Download models
-snapshot_download("PAI/Wan2.1-Fun-V1.1-14B-Control", local_dir="models/PAI/Wan2.1-Fun-V1.1-14B-Control")
-
-# Load models
-model_manager = ModelManager(device="cpu")
-model_manager.load_models(
-    [
-        "models/PAI/Wan2.1-Fun-V1.1-14B-Control/diffusion_pytorch_model.safetensors",
-        "models/PAI/Wan2.1-Fun-V1.1-14B-Control/models_t5_umt5-xxl-enc-bf16.pth",
-        "models/PAI/Wan2.1-Fun-V1.1-14B-Control/Wan2.1_VAE.pth",
-        "models/PAI/Wan2.1-Fun-V1.1-14B-Control/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
-    ],
-    torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
-)
-pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
-pipe.enable_vram_management(num_persistent_param_in_dit=None)
-
-# Control-to-video
-control_video = VideoData("xxx/pose.mp4", height=832, width=480)
-control_video = [control_video[i] for i in range(49)]
-video = pipe(
-    prompt="一位年轻女性穿着一件粉色的连衣裙,裙子上有白色的装饰和粉色的纽扣。她的头发是紫色的,头上戴着一个红色的大蝴蝶结,显得非常可爱和精致。她还戴着一个红色的领结,整体造型充满了少女感和活力。她的表情温柔,双手轻轻交叉放在身前,姿态优雅。背景是简单的灰色,没有任何多余的装饰,使得人物更加突出。她的妆容清淡自然,突显了她的清新气质。整体画面给人一种甜美、梦幻的感觉,仿佛置身于童话世界中。",
-    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-    num_inference_steps=50,
-    reference_image=Image.open("xxx/6.png").convert("RGB").resize((480, 832)),
-    control_video=control_video, height=832, width=480, num_frames=49,
-    seed=1, tiled=True
-)
-save_video(video, "video1.mp4", fps=15, quality=5)
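-# Note: `tiled=True` throughout these examples enables tiled VAE encoding/decoding, which
-# works on the video in spatial tiles so that peak VRAM stays bounded at high resolutions,
-# at a small cost in speed.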