mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
wan-refactor
This commit is contained in:
@@ -2,8 +2,9 @@ import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
from packaging import version as pver
|
||||
import os
|
||||
from typing_extensions import Literal
|
||||
|
||||
class SimpleAdapter(nn.Module):
|
||||
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1):
|
||||
super(SimpleAdapter, self).__init__()
|
||||
@@ -42,6 +43,22 @@ class SimpleAdapter(nn.Module):
|
||||
|
||||
return out
|
||||
|
||||
def process_camera_coordinates(
|
||||
self,
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
length: int,
|
||||
height: int,
|
||||
width: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
):
|
||||
if origin is None:
|
||||
origin = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
coordinates = generate_camera_coordinates(direction, length, speed, origin)
|
||||
plucker_embedding = process_pose_file(coordinates, width, height)
|
||||
return plucker_embedding
|
||||
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, dim):
|
||||
@@ -90,13 +107,8 @@ def get_relative_pose(cam_params):
|
||||
return ret_poses
|
||||
|
||||
def custom_meshgrid(*args):
|
||||
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
|
||||
if pver.parse(torch.__version__) < pver.parse('1.10'):
|
||||
return torch.meshgrid(*args)
|
||||
else:
|
||||
return torch.meshgrid(*args, indexing='ij')
|
||||
# torch>=2.0.0 only
|
||||
return torch.meshgrid(*args, indexing='ij')
|
||||
|
||||
|
||||
def ray_condition(K, c2w, H, W, device):
|
||||
@@ -128,23 +140,14 @@ def ray_condition(K, c2w, H, W, device):
|
||||
rays_o = c2w[..., :3, 3] # B, V, 3
|
||||
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
|
||||
# c2w @ dirctions
|
||||
rays_dxo = torch.cross(rays_o, rays_d)
|
||||
rays_dxo = torch.linalg.cross(rays_o, rays_d)
|
||||
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
|
||||
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
|
||||
# plucker = plucker.permute(0, 1, 4, 2, 3)
|
||||
return plucker
|
||||
|
||||
def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
||||
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
|
||||
"""
|
||||
if os.path.isfile(pose_file_path):
|
||||
with open(pose_file_path, 'r') as f:
|
||||
poses = f.readlines()
|
||||
else:
|
||||
poses = pose_file_path.splitlines()
|
||||
|
||||
poses = [pose.strip().split(' ') for pose in poses[1:]]
|
||||
cam_params = [[float(x) for x in pose] for pose in poses]
|
||||
def process_pose_file(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
|
||||
if return_poses:
|
||||
return cam_params
|
||||
else:
|
||||
@@ -175,3 +178,25 @@ def process_pose_file(pose_file_path, width=672, height=384, original_pose_width
|
||||
plucker_embedding = plucker_embedding[None]
|
||||
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
|
||||
return plucker_embedding
|
||||
|
||||
|
||||
|
||||
def generate_camera_coordinates(
|
||||
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
|
||||
length: int,
|
||||
speed: float = 1/54,
|
||||
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
|
||||
):
|
||||
coordinates = [list(origin)]
|
||||
while len(coordinates) < length:
|
||||
coor = coordinates[-1].copy()
|
||||
if "Left" in direction:
|
||||
coor[9] += speed
|
||||
if "Right" in direction:
|
||||
coor[9] -= speed
|
||||
if "Up" in direction:
|
||||
coor[13] += speed
|
||||
if "Down" in direction:
|
||||
coor[13] -= speed
|
||||
coordinates.append(coor)
|
||||
return coordinates
|
||||
|
||||
@@ -10,6 +10,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
|
||||
from ..models import ModelManager, load_state_dict
|
||||
from ..models.wan_video_dit import WanModel, RMSNorm, sinusoidal_embedding_1d
|
||||
@@ -208,9 +209,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
WanVideoUnit_InputVideoEmbedder(),
|
||||
WanVideoUnit_PromptEmbedder(),
|
||||
WanVideoUnit_ImageEmbedder(),
|
||||
WanVideoUnit_FunCamera(),
|
||||
WanVideoUnit_FunControl(),
|
||||
WanVideoUnit_FunReference(),
|
||||
WanVideoUnit_FunCameraControl(),
|
||||
WanVideoUnit_SpeedControl(),
|
||||
WanVideoUnit_VACE(),
|
||||
WanVideoUnit_UnifiedSequenceParallel(),
|
||||
@@ -472,6 +473,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
# ControlNet
|
||||
control_video: Optional[list[Image.Image]] = None,
|
||||
reference_image: Optional[Image.Image] = None,
|
||||
# Camera control
|
||||
camera_control_direction: Optional[Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"]] = None,
|
||||
camera_control_speed: Optional[float] = 1/54,
|
||||
camera_control_origin: Optional[tuple] = (0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0),
|
||||
# VACE
|
||||
vace_video: Optional[list[Image.Image]] = None,
|
||||
vace_video_mask: Optional[Image.Image] = None,
|
||||
@@ -504,8 +509,6 @@ class WanVideoPipeline(BasePipeline):
|
||||
tea_cache_model_id: Optional[str] = "",
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
# Camera control
|
||||
control_camera_video: Optional[torch.Tensor] = None
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, shift=sigma_shift)
|
||||
@@ -524,7 +527,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
"end_image": end_image,
|
||||
"input_video": input_video, "denoising_strength": denoising_strength,
|
||||
"control_video": control_video, "reference_image": reference_image,
|
||||
"control_camera_video": control_camera_video,
|
||||
"camera_control_direction": camera_control_direction, "camera_control_speed": camera_control_speed, "camera_control_origin": camera_control_origin,
|
||||
"vace_video": vace_video, "vace_video_mask": vace_video_mask, "vace_reference_image": vace_reference_image, "vace_scale": vace_scale,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"height": height, "width": width, "num_frames": num_frames,
|
||||
@@ -724,37 +727,7 @@ class WanVideoUnit_ImageEmbedder(PipelineUnit):
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"clip_feature": clip_context, "y": y}
|
||||
|
||||
class WanVideoUnit_FunCamera(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("control_camera_video", "cfg_merge", "num_frames", "height", "width", "input_image", "latents"),
|
||||
onload_model_names=("vae")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, control_camera_video, cfg_merge, num_frames, height, width, input_image, latents):
|
||||
if control_camera_video is None:
|
||||
return {}
|
||||
control_camera_video = control_camera_video[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
|
||||
control_camera_latents = torch.concat(
|
||||
[
|
||||
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
||||
control_camera_video[:, :, 1:]
|
||||
], dim=2
|
||||
).transpose(1, 2)
|
||||
b, f, c, h, w = control_camera_latents.shape
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
if latents.size()[2] != 1:
|
||||
y[:, :, :1] = input_latents
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
return {"control_camera_latents": control_camera_latents, "control_camera_latents_input": control_camera_latents_input, "y":y}
|
||||
|
||||
|
||||
class WanVideoUnit_FunControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
@@ -800,6 +773,40 @@ class WanVideoUnit_FunReference(PipelineUnit):
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_FunCameraControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image")
|
||||
)
|
||||
|
||||
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
|
||||
if camera_control_direction is None:
|
||||
return {}
|
||||
camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
|
||||
camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
|
||||
|
||||
control_camera_video = camera_control_plucker_embedding[:num_frames].permute([3, 0, 1, 2]).unsqueeze(0)
|
||||
control_camera_latents = torch.concat(
|
||||
[
|
||||
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
|
||||
control_camera_video[:, :, 1:]
|
||||
], dim=2
|
||||
).transpose(1, 2)
|
||||
b, f, c, h, w = control_camera_latents.shape
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
|
||||
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
|
||||
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
|
||||
input_image = input_image.resize((width, height))
|
||||
input_latents = pipe.preprocess_video([input_image])
|
||||
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
|
||||
y = torch.zeros_like(latents).to(pipe.device)
|
||||
y[:, :, :1] = input_latents
|
||||
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
|
||||
|
||||
|
||||
|
||||
class WanVideoUnit_SpeedControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(input_params=("motion_bucket_id",))
|
||||
|
||||
@@ -228,30 +228,24 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat
|
||||
|
||||
def wan_parser():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the Dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Metadata path of the Dataset.")
|
||||
parser.add_argument("--height", type=int, default=None, help="Image or video height. Leave `height` and `width` None to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Image or video width. Leave `height` and `width` None to enable dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames in each video. The frames are sampled from the prefix.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in metadata. Separated by commas.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times the dataset is repeated in each epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Model paths to be loaded. JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas.")
|
||||
parser.add_argument("--dataset_base_path", type=str, default="", help="Base path of the dataset.")
|
||||
parser.add_argument("--dataset_metadata_path", type=str, default="", required=True, help="Path to the metadata file of the dataset.")
|
||||
parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.")
|
||||
parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.")
|
||||
parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.")
|
||||
parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.")
|
||||
parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.")
|
||||
parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.")
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
|
||||
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Save path.")
|
||||
parser.add_argument("--output_path", type=str, default="./models", help="Output save path.")
|
||||
parser.add_argument("--remove_prefix_in_ckpt", type=str, default="pipe.dit.", help="Remove prefix in ckpt.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Trainable models, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Add LoRA on which model.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Add LoRA on which layer.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="LoRA rank.")
|
||||
parser.add_argument("--input_contains_input_image", default=False, action="store_true", help="Model input contains 'input_image'.")
|
||||
parser.add_argument("--input_contains_end_image", default=False, action="store_true", help="Model input contains 'end_image'.")
|
||||
parser.add_argument("--input_contains_control_video", default=False, action="store_true", help="Model input contains 'control_video'.")
|
||||
parser.add_argument("--input_contains_reference_image", default=False, action="store_true", help="Model input contains 'reference_image'.")
|
||||
parser.add_argument("--input_contains_vace_video", default=False, action="store_true", help="Model input contains 'vace_video'.")
|
||||
parser.add_argument("--input_contains_vace_reference_image", default=False, action="store_true", help="Model input contains 'vace_reference_image'.")
|
||||
parser.add_argument("--input_contains_motion_bucket_id", default=False, action="store_true", help="Model input contains 'motion_bucket_id'.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Offload gradient checkpointing to RAM.")
|
||||
parser.add_argument("--trainable_models", type=str, default=None, help="Models to train, e.g., dit, vae, text_encoder.")
|
||||
parser.add_argument("--lora_base_model", type=str, default=None, help="Which model LoRA is added to.")
|
||||
parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.")
|
||||
parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.")
|
||||
parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.")
|
||||
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
|
||||
return parser
|
||||
|
||||
|
||||
@@ -1,32 +1,389 @@
|
||||
# Wan 2.1
|
||||
|
||||
[切换到中文](./README_zh.md)
|
||||
|
||||
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
|
||||
|
||||
Before using this model, please install DiffSynth-Studio from **source code**.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
|
||||
* dataset
|
||||
* `--dataset_base_path`: Base path of the Dataset.
|
||||
* `--dataset_metadata_path`: Metadata path of the Dataset.
|
||||
* `--height`: Image or video height. Leave `height` and `width` None to enable dynamic resolution.
|
||||
* `--width`: Image or video width. Leave `height` and `width` None to enable dynamic resolution.
|
||||
* `--num_frames`: Number of frames in each video. The frames are sampled from the prefix.
|
||||
* `--data_file_keys`: Data file keys in metadata. Separated by commas.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* Model
|
||||
* `--model_paths`: Model paths to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with original path, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separated by commas.
|
||||
## Model Inference
|
||||
|
||||
The following sections will help you understand our functionalities and write inference code.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Loading the Model</summary>
|
||||
|
||||
The model is loaded using `from_pretrained`:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
Here, `torch_dtype` and `device` specify the computation precision and device respectively. The `model_configs` can be used to configure model paths in various ways:
|
||||
|
||||
* Downloading the model from [ModelScope](https://modelscope.cn/) and loading it. In this case, both `model_id` and `origin_file_pattern` need to be specified, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
|
||||
```
|
||||
|
||||
* Loading the model from a local file path. In this case, the `path` parameter needs to be specified, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
|
||||
```
|
||||
|
||||
For models that are loaded from multiple files, simply use a list, for example:
|
||||
|
||||
```python
|
||||
ModelConfig(path=[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
])
|
||||
```
|
||||
|
||||
The `from_pretrained` function also provides additional parameters to control the behavior during model loading:
|
||||
|
||||
* `tokenizer_config`: Path to the tokenizer of the Wan model. Default value is `ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*")`.
|
||||
* `local_model_path`: Path where downloaded models are saved. Default value is `"./models"`.
|
||||
* `skip_download`: Whether to skip downloading models. Default value is `False`. When your network cannot access [ModelScope](https://modelscope.cn/), manually download the necessary files and set this to `True`.
|
||||
* `redirect_common_files`: Whether to redirect duplicate model files. Default value is `True`. Since the Wan series models include multiple base models, some modules like text encoder are shared across these models. To avoid redundant downloads, we redirect the model paths.
|
||||
* `use_usp`: Whether to enable Unified Sequence Parallel. Default value is `False`. Used for multi-GPU parallel inference.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>VRAM Management</summary>
|
||||
|
||||
DiffSynth-Studio provides fine-grained VRAM management for the Wan model, allowing it to run on devices with limited VRAM. You can enable offloading functionality via the following code, which moves parts of the model to system memory on devices with limited VRAM:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 quantization is also supported:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
Both FP8 quantization and offloading can be enabled simultaneously:
|
||||
|
||||
```python
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
```
|
||||
|
||||
FP8 quantization significantly reduces VRAM usage but does not accelerate computations. Some models may experience issues such as blurry, torn, or distorted outputs due to insufficient precision when using FP8 quantization. Use FP8 quantization with caution.
|
||||
|
||||
The `enable_vram_management` function provides the following parameters to control VRAM usage:
|
||||
|
||||
* `vram_limit`: VRAM usage limit (in GB). By default, it uses all available VRAM on the device. Note that this is not an absolute limit; if the specified VRAM is insufficient but more VRAM is actually available, inference will proceed using the minimum required VRAM.
|
||||
* `vram_buffer`: Size of the VRAM buffer (in GB). Default is 0.5GB. Since certain large neural network layers may consume more VRAM unpredictably during their execution phase, a VRAM buffer is necessary. Ideally, this should match the maximum VRAM consumed by any single layer in the model.
|
||||
* `num_persistent_param_in_dit`: Number of persistent parameters in DiT models. By default, there is no limit. We plan to remove this parameter in the future, so please avoid relying on it.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Inference Acceleration</summary>
|
||||
|
||||
Wan supports multiple acceleration techniques, including:
|
||||
|
||||
* **Efficient attention implementations**: If any of these attention implementations are installed in your Python environment, they will be automatically enabled in the following priority:
|
||||
* [Flash Attention 3](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (default setting; we recommend installing `torch>=2.5.0`)
|
||||
* **Unified Sequence Parallel**: Sequence parallelism based on [xDiT](https://github.com/xdit-project/xDiT). Please refer to [this example](./acceleration/unified_sequence_parallel.py), and run it using the command:
|
||||
|
||||
```shell
|
||||
pip install xfuser>=0.4.3
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
* **TeaCache**: Acceleration technique [TeaCache](https://github.com/ali-vilab/TeaCache). Please refer to [this example](./acceleration/teacache.py).
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Input Parameters</summary>
|
||||
|
||||
The pipeline accepts the following input parameters during inference:
|
||||
|
||||
* `prompt`: Prompt describing the content to appear in the video.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the video. Default is `""`.
|
||||
* `input_image`: Input image, applicable for image-to-video models such as [`Wan-AI/Wan2.1-I2V-14B-480P`](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) and [`PAI/Wan2.1-Fun-1.3B-InP`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP), as well as first-and-last-frame models like [`Wan-AI/Wan2.1-FLF2V-14B-720P`](Wan-AI/Wan2.1-FLF2V-14B-720P).
|
||||
* `end_image`: End frame, applicable for first-and-last-frame models such as [`Wan-AI/Wan2.1-FLF2V-14B-720P`](Wan-AI/Wan2.1-FLF2V-14B-720P).
|
||||
* `input_video`: Input video used for video-to-video generation. Applicable to any Wan series model and must be used together with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength in range [0, 1]. A smaller value results in a video closer to `input_video`.
|
||||
* `control_video`: Control video, applicable to Wan models with control capabilities such as [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control).
|
||||
* `reference_image`: Reference image, applicable to Wan models supporting reference images such as [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control).
|
||||
* `camera_control_direction`: Camera control direction, optional values are "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown". Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera).
|
||||
* `camera_control_speed`: Camera control speed. Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera).
|
||||
* `camera_control_origin`: Origin coordinate of the camera control sequence. Please refer to the [original paper](https://arxiv.org/pdf/2404.02101) for proper configuration. Applicable to Camera-Control models, such as [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera).
|
||||
* `vace_video`: Input video for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
|
||||
* `vace_video_mask`: Mask video for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
|
||||
* `vace_reference_image`: Reference image for VACE models, applicable to the VACE series such as [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview).
|
||||
* `vace_scale`: Influence of the VACE model on the base model, default is 1. Higher values increase control strength but may lead to visual artifacts or breakdowns.
|
||||
* `seed`: Random seed. Default is `None`, meaning fully random.
|
||||
* `rand_device`: Device used to generate random Gaussian noise matrix. Default is `"cpu"`. When set to `"cuda"`, different GPUs may produce different generation results.
|
||||
* `height`: Frame height, default is 480. Must be a multiple of 16; if not, it will be rounded up.
|
||||
* `width`: Frame width, default is 832. Must be a multiple of 16; if not, it will be rounded up.
|
||||
* `num_frames`: Number of frames, default is 81. Must be a multiple of 4 plus 1; if not, it will be rounded up, minimum is 1.
|
||||
* `cfg_scale`: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts.
|
||||
* `cfg_merge`: Whether to merge both sides of classifier-free guidance for unified inference. Default is `False`. This parameter currently only works for basic text-to-video and image-to-video models.
|
||||
* `num_inference_steps`: Number of inference steps, default is 50.
|
||||
* `sigma_shift`: Parameter from Rectified Flow theory, default is 5. Higher values make the model stay longer at the initial denoising stage. Increasing this may improve video quality but may also cause inconsistency between generated videos and training data due to deviation from training behavior.
|
||||
* `motion_bucket_id`: Motion intensity, range [0, 100], applicable to motion control modules such as [`DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1`](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1). Larger values indicate more intense motion.
|
||||
* `tiled`: Whether to enable tiled VAE inference, default is `False`. Setting to `True` significantly reduces VRAM usage during VAE encoding/decoding but introduces small errors and slightly increases inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding, default is (30, 52), only effective when `tiled=True`.
|
||||
* `tile_stride`: Stride of tiles during VAE encoding/decoding, default is (15, 26), only effective when `tiled=True`. Must be less than or equal to `tile_size`.
|
||||
* `sliding_window_size`: Sliding window size for DiT part. Experimental feature, effects are unstable.
|
||||
* `sliding_window_stride`: Sliding window stride for DiT part. Experimental feature, effects are unstable.
|
||||
* `tea_cache_l1_thresh`: Threshold for TeaCache. Larger values result in faster speed but lower quality. Note that after enabling TeaCache, the inference speed is not uniform, so the remaining time shown on the progress bar becomes inaccurate.
|
||||
* `tea_cache_model_id`: TeaCache parameter template, options include `"Wan2.1-T2V-1.3B"`, `"Wan2.1-T2V-14B"`, `"Wan2.1-I2V-14B-480P"`, `"Wan2.1-I2V-14B-720P"`.
|
||||
* `progress_bar_cmd`: Progress bar implementation, default is `tqdm.tqdm`. You can set it to `lambda x:x` to disable the progress bar.
|
||||
|
||||
</details>
|
||||
|
||||
## Model Training
|
||||
|
||||
Wan series models are trained using a unified script located at [`./model_training/train.py`](./model_training/train.py).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Script Parameters</summary>
|
||||
|
||||
The script includes the following parameters:
|
||||
|
||||
* Dataset
|
||||
* `--dataset_base_path`: Base path of the dataset.
|
||||
* `--dataset_metadata_path`: Path to the metadata file of the dataset.
|
||||
* `--height`: Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--width`: Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.
|
||||
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
|
||||
* `--data_file_keys`: Data file keys in the metadata. Comma-separated.
|
||||
* `--dataset_repeat`: Number of times to repeat the dataset per epoch.
|
||||
* Models
|
||||
* `--model_paths`: Paths to load models. In JSON format.
|
||||
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
|
||||
* Training
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--output_path`: Save path.
|
||||
* `--output_path`: Output save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
|
||||
* Trainable module
|
||||
* `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Add LoRA on which model.
|
||||
* `--lora_target_modules`: Add LoRA on which layer.
|
||||
* `--lora_rank`: LoRA rank.
|
||||
* Extra model input
|
||||
* `--input_contains_input_image`: Model input contains `input_image`
|
||||
* `--input_contains_end_image`: Model input contains `end_image`.
|
||||
* `--input_contains_control_video`: Model input contains `control_video`.
|
||||
* `--input_contains_reference_image`: Model input contains `reference_image`.
|
||||
* `--input_contains_vace_video`: Model input contains `vace_video`.
|
||||
* `--input_contains_vace_reference_image`: Model input contains `vace_reference_image`.
|
||||
* `--input_contains_motion_bucket_id`: Model input contains `motion_bucket_id`.
|
||||
* Trainable Modules
|
||||
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
|
||||
* `--lora_base_model`: Which model LoRA is added to.
|
||||
* `--lora_target_modules`: Which layers LoRA is added to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* Extra Inputs
|
||||
* `--extra_inputs`: Additional model inputs, comma-separated.
|
||||
* VRAM Management
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to CPU memory.
|
||||
|
||||
Additionally, the training framework is built upon [`accelerate`](https://huggingface.co/docs/accelerate/index). Before starting training, run `accelerate config` to configure GPU-related parameters. For certain training scripts (e.g., full fine-tuning of 14B models), we provide recommended `accelerate` configuration files, which can be found in the corresponding training scripts.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 1: Prepare the Dataset</summary>
|
||||
|
||||
The dataset consists of a series of files. We recommend organizing your dataset as follows:
|
||||
|
||||
```
|
||||
data/example_video_dataset/
|
||||
├── metadata.csv
|
||||
├── video1.mp4
|
||||
└── video2.mp4
|
||||
```
|
||||
|
||||
Here, `video1.mp4` and `video2.mp4` are training video files, and `metadata.csv` is the metadata list, for example:
|
||||
|
||||
```
|
||||
video,prompt
|
||||
video1.mp4,"from sunset to night, a small town, light, house, river"
|
||||
video2.mp4,"a dog is running"
|
||||
```
|
||||
|
||||
We have prepared a sample video dataset to help you test. You can download it using the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_video_dataset README.md --local_dir ./data/example_video_dataset
|
||||
```
|
||||
|
||||
The dataset supports mixed training of videos and images. Supported video formats include `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"`, and supported image formats include `"jpg", "jpeg", "png", "webp"`.
|
||||
|
||||
The resolution of videos can be controlled via script parameters `--height`, `--width`, and `--num_frames`. For each video, the first `num_frames` frames will be used for training; therefore, an error will occur if the video length is less than `num_frames`. Image files will be treated as single-frame videos. When both `--height` and `--width` are left empty, dynamic resolution will be enabled, meaning training will use the actual resolution of each video or image in the dataset.
|
||||
|
||||
**We strongly recommend using fixed-resolution training and avoiding mixing images and videos in the same dataset due to load balancing issues in multi-GPU training.**
|
||||
|
||||
When the model requires additional inputs, such as the `control_video` needed by control-capable models like [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control), please add corresponding columns in the metadata file, for example:
|
||||
|
||||
```
|
||||
video,prompt,control_video
|
||||
video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4
|
||||
```
|
||||
|
||||
If additional inputs contain video or image files, their column names need to be specified in the `--data_file_keys` parameter. The default value of this parameter is `"image,video"`, meaning it parses columns named `image` and `video`. You can extend this list based on the additional input requirements, for example: `--data_file_keys "image,video,control_video"`, and also enable `--input_contains_control_video`.
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 2: Load the Model</summary>
|
||||
|
||||
Similar to the model loading logic during inference, you can configure the model to be loaded directly via its model ID. For instance, during inference we load the model using:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
]
|
||||
```
|
||||
|
||||
During training, simply use the following parameter to load the corresponding model:
|
||||
|
||||
```shell
|
||||
--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth"
|
||||
```
|
||||
|
||||
If you want to load the model from local files, for example during inference:
|
||||
|
||||
```python
|
||||
model_configs=[
|
||||
ModelConfig(path=[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
|
||||
]),
|
||||
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"),
|
||||
]
|
||||
```
|
||||
|
||||
Then during training, set the parameter as:
|
||||
|
||||
```shell
|
||||
--model_paths '[
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors"
|
||||
],
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"
|
||||
]' \
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 3: Configure Trainable Modules</summary>
|
||||
|
||||
The training framework supports full fine-tuning of base models or LoRA-based training. Here are some examples:
|
||||
|
||||
* Full fine-tuning of the DiT module: `--trainable_models dit`
|
||||
* Training a LoRA model for the DiT module: `--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32`
|
||||
* Training both a LoRA model for DiT and the Motion Controller (yes, you can train such advanced structures): `--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32`
|
||||
|
||||
Additionally, since multiple modules (text encoder, dit, vae) are loaded in the training script, you need to remove prefixes when saving model files. For example, when fully fine-tuning the DiT module or training a LoRA version of DiT, please set `--remove_prefix_in_ckpt pipe.dit.`
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Step 4: Launch the Training Process</summary>
|
||||
|
||||
We have prepared training commands for each model. Please refer to the table at the beginning of this document.
|
||||
|
||||
Note that full fine-tuning of the 14B model requires 8 GPUs, each with at least 80GB VRAM. During full fine-tuning of these 14B models, you must install `deepspeed` (`pip install deepspeed`). We have provided recommended [configuration files](./model_training/full/accelerate_config_14B.yaml), which will be loaded automatically in the corresponding training scripts. These scripts have been tested on 8*A100.
|
||||
|
||||
The default video resolution in the training script is `480*832*81`. Increasing the resolution may cause out-of-memory errors. To reduce VRAM usage, add the parameter `--use_gradient_checkpointing_offload`.
|
||||
|
||||
</details>
|
||||
|
||||
@@ -1,26 +1,38 @@
|
||||
# 通义万相 2.1(Wan 2.1)
|
||||
|
||||
|模型 ID|类型|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)|基础模型||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)|基础模型||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|基础模型|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|基础模型|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|基础模型|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|基础模型|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|基础模型|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|基础模型|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|基础模型|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|基础模型|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|||||
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|基础模型|||||||
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|适配器|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|适配器|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|适配器|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|适配器|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
[Switch to English](./README.md)
|
||||
|
||||
Wan 2.1 是由阿里巴巴通义实验室开源的一系列视频生成模型。
|
||||
|
||||
在使用本系列模型之前,请通过源码安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](./model_inference/Wan2.1-T2V-1.3B.py)|[code](./model_training/full/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](./model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](./model_inference/Wan2.1-T2V-14B.py)|[code](./model_training/full/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_full/Wan2.1-T2V-14B.py)|[code](./model_training/lora/Wan2.1-T2V-14B.sh)|[code](./model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-480P.py)|[code](./model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](./model_inference/Wan2.1-I2V-14B-720P.py)|[code](./model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](./model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](./model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](./model_inference/Wan2.1-Fun-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](./model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](./model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](./model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-1.3B.py)|[code](./model_training/full/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](./model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](./model_inference/Wan2.1-VACE-14B.py)|[code](./model_training/full/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_full/Wan2.1-VACE-14B.py)|[code](./model_training/lora/Wan2.1-VACE-14B.sh)|[code](./model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](./model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](./model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](./model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -154,7 +166,13 @@ Wan 支持多种加速方案,包括
|
||||
* [Flash Attention 2](https://github.com/Dao-AILab/flash-attention)
|
||||
* [Sage Attention](https://github.com/thu-ml/SageAttention)
|
||||
* [torch SDPA](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) (默认设置,建议安装 `torch>=2.5.0`)
|
||||
* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用命令 `torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py` 运行。
|
||||
* 统一序列并行:基于 [xDiT](https://github.com/xdit-project/xDiT) 实现的序列并行,请参考[示例代码](./acceleration/unified_sequence_parallel.py),使用以下命令运行:
|
||||
|
||||
```shell
|
||||
pip install xfuser>=0.4.3
|
||||
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
|
||||
```
|
||||
|
||||
* TeaCache:加速技术 [TeaCache](https://github.com/ali-vilab/TeaCache),请参考[示例代码](./acceleration/teacache.py)。
|
||||
|
||||
</details>
|
||||
@@ -174,6 +192,9 @@ Pipeline 在推理阶段能够接收以下输入参数:
|
||||
* `denoising_strength`: 去噪强度,范围为 [0, 1]。数值越小,生成的视频越接近 `input_video`。
|
||||
* `control_video`: 控制视频,适用于带控制能力的 Wan 系列模型,例如 [`PAI/Wan2.1-Fun-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)。
|
||||
* `reference_image`: 参考图片,适用于带参考图能力的 Wan 系列模型,例如 [`PAI/Wan2.1-Fun-V1.1-1.3B-Control`](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)。
|
||||
* `camera_control_direction`: 镜头控制方向,可选 "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown" 之一,适用于 Camera-Control 模型,例如 [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)。
|
||||
* `camera_control_speed`: 镜头控制速度,适用于 Camera-Control 模型,例如 [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)。
|
||||
* `camera_control_origin`: 镜头控制序列的原点坐标,请参考[原论文](https://arxiv.org/pdf/2404.02101)进行设置,适用于 Camera-Control 模型,例如 [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://www.modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)。
|
||||
* `vace_video`: VACE 模型的输入视频,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。
|
||||
* `vace_video_mask`: VACE 模型的 mask 视频,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。
|
||||
* `vace_reference_image`: VACE 模型的参考图片,适用于 VACE 系列模型,例如 [`iic/VACE-Wan2.1-1.3B-Preview`](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)。
|
||||
@@ -232,13 +253,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
|
||||
* `--lora_target_modules`: LoRA 添加到哪一层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* 额外模型输入
|
||||
* `--input_contains_input_image`: 模型输入包含 `input_image`
|
||||
* `--input_contains_end_image`: 模型输入包含 `end_image`。
|
||||
* `--input_contains_control_video`: 模型输入包含 `control_video`。
|
||||
* `--input_contains_reference_image`: 模型输入包含 `reference_image`。
|
||||
* `--input_contains_vace_video`: 模型输入包含 `vace_video`。
|
||||
* `--input_contains_vace_reference_image`: 模型输入包含 `vace_reference_image`。
|
||||
* `--input_contains_motion_bucket_id`: 模型输入包含 `motion_bucket_id`。
|
||||
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
|
||||
* 显存管理
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
from diffsynth.models.wan_video_camera_controller import process_pose_file
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
@@ -18,92 +18,6 @@ pipe = WanVideoPipeline.from_pretrained(
|
||||
pipe.enable_vram_management()
|
||||
|
||||
|
||||
# Control camera video
|
||||
# text or txt file path, e.g. control_camera_text = "Pan_Left.txt"
|
||||
control_camera_text = '''
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.018518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.037037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.05555555555555555 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.07407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.09259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.1111111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.12962962962962962 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.14814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.16666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.18518518518518517 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.24074074074074073 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.25925925925925924 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2777777777777778 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.31481481481481477 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.35185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.37037037037037035 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.38888888888888884 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.42592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.4629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.48148148148148145 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5185185185185185 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.537037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5740740740740741 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.5925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.611111111111111 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6296296296296295 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6666666666666666 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.6851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7592592592592593 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.7962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8148148148148148 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8333333333333334 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8703703703703705 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.8888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9259259259259258 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9629629629629629 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 0.9814814814814815 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0185185185185186 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0555555555555556 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.0925925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1111111111111112 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1296296296296298 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1481481481481481 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1666666666666667 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.1851851851851851 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2037037037037037 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.222222222222222 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2407407407407407 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.259259259259259 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2777777777777777 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.2962962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3148148148148149 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3333333333333333 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3518518518518519 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3703703703703702 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.3888888888888888 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4074074074074074 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.425925925925926 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4444444444444444 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.462962962962963 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
0 0.532139961 0.946026558 0.5 0.5 0 0 1.0 0.0 0.0 1.4814814814814814 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0
|
||||
'''
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
@@ -111,19 +25,20 @@ dataset_snapshot_download(
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
height = 480
|
||||
width = 832
|
||||
|
||||
control_camera_video = process_pose_file(control_camera_text, width, height)
|
||||
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Left", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_left.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
height=height, width=width, num_frames=81,
|
||||
seed=43, tiled=True,
|
||||
control_camera_video = control_camera_video,
|
||||
input_image = input_image,
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Up", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
save_video(video, "video_up.mp4", fps=15, quality=5)
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=f"data/examples/wan/input_image.jpg"
|
||||
)
|
||||
input_image = Image.open("data/examples/wan/input_image.jpg")
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Left", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_left.mp4", fps=15, quality=5)
|
||||
|
||||
video = pipe(
|
||||
prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
input_image=input_image,
|
||||
camera_control_direction="Up", camera_control_speed=0.01,
|
||||
)
|
||||
save_video(video, "video_up.mp4", fps=15, quality=5)
|
||||
@@ -31,7 +31,7 @@ video = pipe(
|
||||
vace_video=control_video,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video1_14b.mp4", fps=15, quality=5)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
# Reference image -> Video
|
||||
video = pipe(
|
||||
@@ -40,7 +40,7 @@ video = pipe(
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video2_14b.mp4", fps=15, quality=5)
|
||||
save_video(video, "video2.mp4", fps=15, quality=5)
|
||||
|
||||
# Depth video + Reference image -> Video
|
||||
video = pipe(
|
||||
@@ -50,4 +50,4 @@ video = pipe(
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video3_14b.mp4", fps=15, quality=5)
|
||||
save_video(video, "video3.mp4", fps=15, quality=5)
|
||||
|
||||
@@ -10,4 +10,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.motion_controller." \
|
||||
--output_path "./models/train/Wan2.1-1.3b-speedcontrol-v1_full" \
|
||||
--trainable_models "motion_controller" \
|
||||
--input_contains_motion_bucket_id
|
||||
--extra_inputs "motion_bucket_id"
|
||||
@@ -10,5 +10,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-FLF2V-14B-720P_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -11,4 +11,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-1.3B-Control_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_control_video
|
||||
--extra_inputs "control_video"
|
||||
@@ -10,5 +10,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-1.3B-InP_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -11,4 +11,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-14B-Control_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_control_video
|
||||
--extra_inputs "control_video"
|
||||
@@ -10,5 +10,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-14B-InP_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -0,0 +1,13 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed"
|
||||
@@ -11,5 +11,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_control_video \
|
||||
--input_contains_reference_image
|
||||
--extra_inputs "control_video,reference_image"
|
||||
@@ -10,5 +10,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-InP_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -0,0 +1,13 @@
|
||||
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed"
|
||||
@@ -11,5 +11,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_control_video \
|
||||
--input_contains_reference_image
|
||||
--extra_inputs "control_video,reference_image"
|
||||
@@ -10,5 +10,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-14B-InP_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -10,4 +10,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-I2V-14B-480P_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image
|
||||
--extra_inputs "input_image"
|
||||
@@ -10,4 +10,4 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-I2V-14B-720P_full" \
|
||||
--trainable_models "dit" \
|
||||
--input_contains_input_image
|
||||
--extra_inputs "input_image"
|
||||
@@ -12,6 +12,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.1-VACE-1.3B-Preview_full" \
|
||||
--trainable_models "vace" \
|
||||
--input_contains_vace_video \
|
||||
--input_contains_vace_reference_image \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -12,6 +12,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.1-VACE-1.3B_full" \
|
||||
--trainable_models "vace" \
|
||||
--input_contains_vace_video \
|
||||
--input_contains_vace_reference_image \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -12,6 +12,5 @@ accelerate launch --config_file examples/wanvideo/model_training/full/accelerate
|
||||
--remove_prefix_in_ckpt "pipe.vace." \
|
||||
--output_path "./models/train/Wan2.1-VACE-14B_full" \
|
||||
--trainable_models "vace" \
|
||||
--input_contains_vace_video \
|
||||
--input_contains_vace_reference_image \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -12,4 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_motion_bucket_id
|
||||
--extra_inputs "motion_bucket_id"
|
||||
@@ -12,5 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -13,4 +13,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_control_video
|
||||
--extra_inputs "control_video"
|
||||
@@ -12,5 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -13,4 +13,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_control_video
|
||||
--extra_inputs "control_video"
|
||||
@@ -12,5 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -0,0 +1,15 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed"
|
||||
@@ -13,5 +13,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_control_video \
|
||||
--input_contains_reference_image
|
||||
--extra_inputs "control_video,reference_image"
|
||||
@@ -12,5 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -0,0 +1,15 @@
|
||||
accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset \
|
||||
--dataset_metadata_path data/example_video_dataset/metadata_camera_control.csv \
|
||||
--height 480 \
|
||||
--width 832 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:diffusion_pytorch_model*.safetensors,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_t5_umt5-xxl-enc-bf16.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:Wan2.1_VAE.pth,PAI/Wan2.1-Fun-V1.1-14B-Control-Camera:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--extra_inputs "input_image,camera_control_direction,camera_control_speed"
|
||||
@@ -13,5 +13,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_control_video \
|
||||
--input_contains_reference_image
|
||||
--extra_inputs "control_video,reference_image"
|
||||
@@ -12,5 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image \
|
||||
--input_contains_end_image
|
||||
--extra_inputs "input_image,end_image"
|
||||
@@ -12,4 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image
|
||||
--extra_inputs "input_image"
|
||||
@@ -12,4 +12,4 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_input_image
|
||||
--extra_inputs "input_image"
|
||||
@@ -13,6 +13,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "vace" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_vace_video \
|
||||
--input_contains_vace_reference_image \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -13,6 +13,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "vace" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_vace_video \
|
||||
--input_contains_vace_reference_image \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -14,6 +14,5 @@ accelerate launch examples/wanvideo/model_training/train.py \
|
||||
--lora_base_model "vace" \
|
||||
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
|
||||
--lora_rank 32 \
|
||||
--input_contains_vace_video \
|
||||
--input_contains_vace_reference_image \
|
||||
--extra_inputs "vace_video,vace_reference_image" \
|
||||
--use_gradient_checkpointing_offload
|
||||
@@ -13,14 +13,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
lora_base_model=None, lora_target_modules="q,k,v,o,ffn.0,ffn.2", lora_rank=32,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
# Extra inputs
|
||||
input_contains_input_image=False,
|
||||
input_contains_end_image=False,
|
||||
input_contains_control_video=False,
|
||||
input_contains_reference_image=False,
|
||||
input_contains_vace_video=False,
|
||||
input_contains_vace_reference_image=False,
|
||||
input_contains_motion_bucket_id=False,
|
||||
extra_inputs=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Load models
|
||||
@@ -51,13 +44,7 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
# Store other configs
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.input_contains_input_image = input_contains_input_image
|
||||
self.input_contains_end_image = input_contains_end_image
|
||||
self.input_contains_control_video = input_contains_control_video
|
||||
self.input_contains_reference_image = input_contains_reference_image
|
||||
self.input_contains_vace_video = input_contains_vace_video
|
||||
self.input_contains_vace_reference_image = input_contains_vace_reference_image
|
||||
self.input_contains_motion_bucket_id = input_contains_motion_bucket_id
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
|
||||
|
||||
def forward_preprocess(self, data):
|
||||
@@ -85,13 +72,13 @@ class WanTrainingModule(DiffusionTrainingModule):
|
||||
}
|
||||
|
||||
# Extra inputs
|
||||
if self.input_contains_input_image: inputs_shared["input_image"] = data["video"][0]
|
||||
if self.input_contains_end_image: inputs_shared["end_image"] = data["video"][-1]
|
||||
if self.input_contains_control_video: inputs_shared["control_video"] = data["control_video"]
|
||||
if self.input_contains_reference_image: inputs_shared["reference_image"] = data["reference_image"]
|
||||
if self.input_contains_vace_video: inputs_shared["vace_video"] = data["vace_video"]
|
||||
if self.input_contains_vace_reference_image: inputs_shared["vace_reference_image"] = data["vace_reference_image"]
|
||||
if self.input_contains_motion_bucket_id: inputs_shared["motion_bucket_id"] = data["motion_bucket_id"]
|
||||
for extra_input in self.extra_inputs:
|
||||
if extra_input == "input_image":
|
||||
inputs_shared["input_image"] = data["video"][0]
|
||||
elif extra_input == "end_image":
|
||||
inputs_shared["end_image"] = data["video"][-1]
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
|
||||
# Pipeline units will automatically process the input parameters.
|
||||
for unit in self.pipe.units:
|
||||
@@ -118,12 +105,6 @@ if __name__ == "__main__":
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
input_contains_input_image=args.input_contains_input_image,
|
||||
input_contains_end_image=args.input_contains_end_image,
|
||||
input_contains_control_video=args.input_contains_control_video,
|
||||
input_contains_reference_image=args.input_contains_reference_image,
|
||||
input_contains_vace_video=args.input_contains_vace_video,
|
||||
input_contains_vace_reference_image=args.input_contains_vace_reference_image,
|
||||
input_contains_motion_bucket_id=args.input_contains_motion_bucket_id,
|
||||
extra_inputs=args.extra_inputs,
|
||||
)
|
||||
launch_training_task(model, dataset, args=args)
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0],
|
||||
camera_control_direction="Left", camera_control_speed=0.0,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,32 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
state_dict = load_state_dict("models/train/Wan2.1-Fun-V1.1-14B-Control-Camera_full/epoch-1.safetensors")
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0],
|
||||
camera_control_direction="Left", camera_control_speed=0.0,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0],
|
||||
camera_control_direction="Left", camera_control_speed=0.0,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-1.3B-Control-Camera.mp4", fps=15, quality=5)
|
||||
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffsynth import save_video, VideoData, load_state_dict
|
||||
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
|
||||
ModelConfig(model_id="PAI/Wan2.1-Fun-V1.1-14B-Control-Camera", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
|
||||
],
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/Wan2.1-Fun-V1.1-1.3B-Control-Camera_lora/epoch-4.safetensors", alpha=1)
|
||||
pipe.enable_vram_management()
|
||||
|
||||
video = VideoData("data/example_video_dataset/video1.mp4", height=480, width=832)
|
||||
|
||||
# First and last frame to video
|
||||
video = pipe(
|
||||
prompt="from sunset to night, a small town, light, house, river",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
input_image=video[0],
|
||||
camera_control_direction="Left", camera_control_speed=0.0,
|
||||
seed=0, tiled=True
|
||||
)
|
||||
save_video(video, "video_Wan2.1-Fun-V1.1-14B-Control-Camera.mp4", fps=15, quality=5)
|
||||
Reference in New Issue
Block a user