Wan 2.1
Wan 2.1 is a collection of video synthesis models open-sourced by Alibaba.
DiffSynth-Studio has adopted a new inference and training framework. To use the previous version, please click here.
Installation
Before using this model, please install DiffSynth-Studio from source code.
git clone https://github.com/modelscope/DiffSynth-Studio.git
cd DiffSynth-Studio
pip install -e .
Overview
Model Inference
The following sections will help you understand our functionalities and write inference code.
Loading the Model
The model is loaded using from_pretrained (the imports below follow the repository's example scripts):
import torch
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig

pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
],
)
Here, torch_dtype and device specify the computation precision and device respectively. The model_configs can be used to configure model paths in various ways:
- Downloading the model from ModelScope and loading it. In this case, both model_id and origin_file_pattern need to be specified, for example:
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors")
- Loading the model from a local file path. In this case, the path parameter needs to be specified, for example:
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors")
For models that are loaded from multiple files, simply use a list, for example:
ModelConfig(path=[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
])
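For sharded checkpoints such as the 14B model above, the path list can also be generated instead of written out by hand. A minimal sketch, assuming the shard count and naming scheme shown above:

```python
# Build the list of sharded checkpoint paths programmatically.
# The naming scheme follows the 14B example above: -NNNNN-of-00006.
def shard_paths(directory, prefix="diffusion_pytorch_model", num_shards=6):
    return [
        f"{directory}/{prefix}-{i:05d}-of-{num_shards:05d}.safetensors"
        for i in range(1, num_shards + 1)
    ]

paths = shard_paths("models/Wan-AI/Wan2.1-T2V-14B")
print(paths[0])  # models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors
```

The resulting list can be passed directly as ModelConfig(path=paths).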
The from_pretrained function also provides additional parameters to control the behavior during model loading:
- tokenizer_config: Path to the tokenizer of the Wan model. Default value is ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/*").
- local_model_path: Path where downloaded models are saved. Default value is "./models".
- skip_download: Whether to skip downloading models. Default value is False. When your network cannot access ModelScope, manually download the necessary files and set this to True.
- redirect_common_files: Whether to redirect duplicate model files. Default value is True. Since the Wan series includes multiple base models that share some modules, such as the text encoder, we redirect the paths of those shared files to avoid redundant downloads.
- use_usp: Whether to enable Unified Sequence Parallel for multi-GPU parallel inference. Default value is False.
VRAM Management
DiffSynth-Studio provides fine-grained VRAM management for the Wan model, allowing it to run on devices with limited VRAM. You can enable offloading via the following code, which moves parts of the model to system memory while they are not in use:
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
],
)
pipe.enable_vram_management()
FP8 quantization is also supported:
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_dtype=torch.float8_e4m3fn),
],
)
pipe.enable_vram_management()
Both FP8 quantization and offloading can be enabled simultaneously:
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
)
pipe.enable_vram_management()
FP8 quantization significantly reduces VRAM usage, but it does not accelerate computation, and the reduced precision can cause blurry, torn, or distorted outputs in some models. Use FP8 quantization with caution.
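The weight-memory saving from FP8 is easy to estimate: bfloat16 stores each parameter in 2 bytes, while float8_e4m3fn uses 1. A rough back-of-the-envelope sketch (actual VRAM usage also includes activations and framework overhead, which this ignores):

```python
def param_memory_gb(num_params, bytes_per_param):
    """Approximate memory needed just to hold the weights, in GB."""
    return num_params * bytes_per_param / 1024**3

# 1.3B-parameter DiT: bf16 uses 2 bytes per parameter, float8_e4m3fn uses 1.
bf16 = param_memory_gb(1.3e9, 2)
fp8 = param_memory_gb(1.3e9, 1)
print(f"bf16: {bf16:.2f} GB, fp8: {fp8:.2f} GB")
```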
The enable_vram_management function provides the following parameters to control VRAM usage:
- vram_limit: VRAM usage limit (in GB). By default, all free VRAM on the device is used. Note that this is not an absolute limit: if the specified amount is insufficient but more VRAM is actually available, inference will proceed using the minimum required VRAM.
- vram_buffer: Size of the VRAM buffer (in GB). Default is 0.5 GB. A buffer is needed because some large neural network layers consume more VRAM than expected while they execute; ideally, it should match the maximum VRAM consumed by any single layer in the model.
- num_persistent_param_in_dit: Number of persistent parameters in the DiT model. By default, there is no limit. We plan to remove this parameter in the future, so please avoid relying on it.
Inference Acceleration
Wan supports multiple acceleration techniques, including:
- Efficient attention implementations: If any of these attention implementations are installed in your Python environment, they will be automatically enabled in the following priority:
- Flash Attention 3
- Flash Attention 2
- Sage Attention
- torch SDPA (default setting; we recommend installing torch>=2.5.0)
- Unified Sequence Parallel: Sequence parallelism based on xDiT. Please refer to this example, and run it using the command:
pip install "xfuser>=0.4.3"
torchrun --standalone --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
- TeaCache: Acceleration via the TeaCache technique. Please refer to this example.
Input Parameters
The pipeline accepts the following input parameters during inference:
- prompt: Prompt describing the content to appear in the video.
- negative_prompt: Negative prompt describing content that should not appear in the video. Default is "".
- input_image: Input image, applicable to image-to-video models such as Wan-AI/Wan2.1-I2V-14B-480P and PAI/Wan2.1-Fun-1.3B-InP, as well as first-and-last-frame models like Wan-AI/Wan2.1-FLF2V-14B-720P.
- end_image: End frame, applicable to first-and-last-frame models such as Wan-AI/Wan2.1-FLF2V-14B-720P.
- input_video: Input video used for video-to-video generation. Applicable to any Wan series model; must be used together with denoising_strength.
- denoising_strength: Denoising strength in the range [0, 1]. A smaller value results in a video closer to input_video.
- control_video: Control video, applicable to Wan models with control capabilities such as PAI/Wan2.1-Fun-1.3B-Control.
- reference_image: Reference image, applicable to Wan models supporting reference images such as PAI/Wan2.1-Fun-V1.1-1.3B-Control.
- camera_control_direction: Camera control direction; optional values are "Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown". Applicable to Camera-Control models such as PAI/Wan2.1-Fun-V1.1-14B-Control-Camera.
- camera_control_speed: Camera control speed. Applicable to Camera-Control models such as PAI/Wan2.1-Fun-V1.1-14B-Control-Camera.
- camera_control_origin: Origin coordinate of the camera control sequence. Please refer to the original paper for proper configuration. Applicable to Camera-Control models such as PAI/Wan2.1-Fun-V1.1-14B-Control-Camera.
- vace_video: Input video for VACE models, applicable to the VACE series such as iic/VACE-Wan2.1-1.3B-Preview.
- vace_video_mask: Mask video for VACE models, applicable to the VACE series such as iic/VACE-Wan2.1-1.3B-Preview.
- vace_reference_image: Reference image for VACE models, applicable to the VACE series such as iic/VACE-Wan2.1-1.3B-Preview.
- vace_scale: Influence of the VACE model on the base model; default is 1. Higher values increase control strength but may lead to visual artifacts.
- seed: Random seed. Default is None, meaning fully random.
- rand_device: Device used to generate the random Gaussian noise matrix. Default is "cpu". When set to "cuda", different GPUs may produce different generation results.
- height: Frame height, default is 480. Must be a multiple of 16; if not, it will be rounded up.
- width: Frame width, default is 832. Must be a multiple of 16; if not, it will be rounded up.
- num_frames: Number of frames, default is 81. Must be a multiple of 4 plus 1 (minimum 1); if not, it will be rounded up.
- cfg_scale: Classifier-free guidance scale, default is 5. Higher values increase adherence to the prompt but may cause visual artifacts.
- cfg_merge: Whether to merge both sides of classifier-free guidance into a single inference pass. Default is False. Currently only works for basic text-to-video and image-to-video models.
- num_inference_steps: Number of inference steps, default is 50.
- sigma_shift: Parameter from Rectified Flow theory, default is 5. Higher values make the model spend more steps at the early denoising stage, which may improve video quality, but may also cause the output to deviate from training behavior.
- motion_bucket_id: Motion intensity in the range [0, 100]; larger values produce more intense motion. Applicable to motion control modules such as DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1.
- tiled: Whether to enable tiled VAE inference, default is False. Setting it to True significantly reduces VRAM usage during VAE encoding/decoding, at the cost of small errors and slightly longer inference time.
- tile_size: Tile size during VAE encoding/decoding, default is (30, 52); only effective when tiled=True.
- tile_stride: Stride of tiles during VAE encoding/decoding, default is (15, 26); only effective when tiled=True. Must be less than or equal to tile_size.
- sliding_window_size: Sliding window size for the DiT part. Experimental feature; effects are unstable.
- sliding_window_stride: Sliding window stride for the DiT part. Experimental feature; effects are unstable.
- tea_cache_l1_thresh: Threshold for TeaCache. Larger values give faster speed but lower quality. Note that with TeaCache enabled, inference speed is not uniform, so the remaining time shown on the progress bar becomes inaccurate.
- tea_cache_model_id: TeaCache parameter template; options are "Wan2.1-T2V-1.3B", "Wan2.1-T2V-14B", "Wan2.1-I2V-14B-480P", "Wan2.1-I2V-14B-720P".
- progress_bar_cmd: Progress bar implementation, default is tqdm.tqdm. Set it to lambda x: x to disable the progress bar.
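The rounding rules for height, width, and num_frames described above can be expressed as a small helper. This is a sketch of the documented behavior, not the pipeline's internal code:

```python
def round_up_resolution(height, width, num_frames):
    """Round inputs up to the sizes the pipeline accepts:
    height/width to a multiple of 16, num_frames to 4k + 1 (minimum 1)."""
    height = (height + 15) // 16 * 16
    width = (width + 15) // 16 * 16
    num_frames = max(num_frames, 1)
    num_frames = (num_frames - 1 + 3) // 4 * 4 + 1  # smallest 4k+1 >= num_frames
    return height, width, num_frames

print(round_up_resolution(480, 832, 81))  # defaults are already valid: (480, 832, 81)
print(round_up_resolution(500, 840, 80))  # rounded up: (512, 848, 81)
```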
Model Training
Wan series models are trained using a unified script located at ./model_training/train.py.
Script Parameters
The script includes the following parameters:
- Dataset
- --dataset_base_path: Base path of the dataset.
- --dataset_metadata_path: Path to the metadata file of the dataset.
- --height: Height of images or videos. Leave height and width empty to enable dynamic resolution.
- --width: Width of images or videos. Leave height and width empty to enable dynamic resolution.
- --num_frames: Number of frames per video. Frames are sampled from the beginning of the video.
- --data_file_keys: Data file keys in the metadata, comma-separated.
- --dataset_repeat: Number of times to repeat the dataset per epoch.
- Models
- --model_paths: Paths of models to load, in JSON format.
- --model_id_with_origin_paths: Model IDs with origin file patterns, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
- Training
- --learning_rate: Learning rate.
- --num_epochs: Number of training epochs.
- --output_path: Output save path.
- --remove_prefix_in_ckpt: Prefix to remove from parameter names in the saved checkpoint.
- Trainable Modules
- --trainable_models: Models to train, e.g., dit, vae, text_encoder.
- --lora_base_model: The model LoRA is added to.
- --lora_target_modules: The layers LoRA is added to.
- --lora_rank: Rank of the LoRA.
- Extra Inputs
--extra_inputs: Additional model inputs, comma-separated.
- VRAM Management
--use_gradient_checkpointing_offload: Whether to offload gradient checkpointing to CPU memory.
Additionally, the training framework is built upon accelerate. Before starting training, run accelerate config to configure GPU-related parameters. For certain training scripts (e.g., full fine-tuning of 14B models), we provide recommended accelerate configuration files, which can be found in the corresponding training scripts.
Step 1: Prepare the Dataset
The dataset consists of a series of files. We recommend organizing your dataset as follows:
data/example_video_dataset/
├── metadata.csv
├── video1.mp4
└── video2.mp4
Here, video1.mp4 and video2.mp4 are training video files, and metadata.csv is the metadata list, for example:
video,prompt
video1.mp4,"from sunset to night, a small town, light, house, river"
video2.mp4,"a dog is running"
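Since prompts often contain commas, the metadata file must quote them correctly, as in the example above. If you generate metadata.csv programmatically, Python's csv module handles the quoting for you (a small sketch using the example rows above):

```python
import csv

# Write metadata.csv; csv.DictWriter quotes prompts that contain commas.
rows = [
    {"video": "video1.mp4", "prompt": "from sunset to night, a small town, light, house, river"},
    {"video": "video2.mp4", "prompt": "a dog is running"},
]
with open("metadata.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["video", "prompt"])
    writer.writeheader()
    writer.writerows(rows)
```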
We have prepared a sample video dataset to help you test. You can download it using the following command:
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
The dataset supports mixed training of videos and images. Supported video formats include "mp4", "avi", "mov", "wmv", "mkv", "flv", "webm", and supported image formats include "jpg", "jpeg", "png", "webp".
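Based on the supported-format lists above, dataset files can be classified by extension before training, e.g. to sanity-check a folder. A sketch of that check (not the loader's actual code):

```python
# Extension sets taken from the supported-format lists above.
VIDEO_EXTS = {"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"}
IMAGE_EXTS = {"jpg", "jpeg", "png", "webp"}

def classify(filename):
    """Return "video", "image", or "unsupported" based on the file extension."""
    ext = filename.rsplit(".", 1)[-1].lower()
    if ext in VIDEO_EXTS:
        return "video"
    if ext in IMAGE_EXTS:
        return "image"
    return "unsupported"

print(classify("video1.mp4"), classify("photo.PNG"), classify("notes.txt"))
```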
The resolution of videos can be controlled via script parameters --height, --width, and --num_frames. For each video, the first num_frames frames will be used for training; therefore, an error will occur if the video length is less than num_frames. Image files will be treated as single-frame videos. When both --height and --width are left empty, dynamic resolution will be enabled, meaning training will use the actual resolution of each video or image in the dataset.
We strongly recommend using fixed-resolution training and avoiding mixing images and videos in the same dataset due to load balancing issues in multi-GPU training.
When the model requires additional inputs, such as the control_video needed by control-capable models like PAI/Wan2.1-Fun-1.3B-Control, please add corresponding columns in the metadata file, for example:
video,prompt,control_video
video1.mp4,"from sunset to night, a small town, light, house, river",video1_softedge.mp4
If additional inputs include video or image files, their column names must be listed in the --data_file_keys parameter. Its default value is "image,video", meaning columns named image and video are parsed as files. Extend this list to match the additional inputs, for example: --data_file_keys "image,video,control_video".
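Conceptually, --data_file_keys just selects which metadata columns are treated as media files. A sketch of that selection logic (illustrative only; the column names come from the example above):

```python
def select_file_columns(metadata_row, data_file_keys="image,video"):
    """Return the subset of metadata columns named in data_file_keys."""
    keys = [k.strip() for k in data_file_keys.split(",")]
    return {k: metadata_row[k] for k in keys if k in metadata_row}

row = {
    "video": "video1.mp4",
    "prompt": "from sunset to night, a small town, light, house, river",
    "control_video": "video1_softedge.mp4",
}
print(select_file_columns(row, "image,video,control_video"))
```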
Step 2: Load the Model
Similar to the model loading logic during inference, you can configure the model to be loaded directly via its model ID. For instance, during inference we load the model using:
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
]
During training, simply use the following parameter to load the corresponding model:
--model_id_with_origin_paths "Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.1-T2V-1.3B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.1-T2V-1.3B:Wan2.1_VAE.pth"
If you want to load the model from local files, for example during inference:
model_configs=[
ModelConfig(path=[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors",
]),
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(path="models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"),
]
Then during training, set the parameter as:
--model_paths '[
[
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00001-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00002-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00003-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00004-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00005-of-00006.safetensors",
"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-00006-of-00006.safetensors"
],
"models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
"models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth"
]'
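Since --model_paths takes JSON, the argument can be built with json.dumps instead of hand-writing the string; a nested list groups the shards of one sharded checkpoint. A sketch using the paths above:

```python
import json

# Build the --model_paths argument from Python lists; json.dumps produces
# valid JSON so quoting mistakes are impossible.
model_paths = [
    [
        f"models/Wan-AI/Wan2.1-T2V-14B/diffusion_pytorch_model-{i:05d}-of-00006.safetensors"
        for i in range(1, 7)
    ],
    "models/Wan-AI/Wan2.1-T2V-14B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-14B/Wan2.1_VAE.pth",
]
print(json.dumps(model_paths))
```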
Step 3: Configure Trainable Modules
The training framework supports full fine-tuning of base models or LoRA-based training. Here are some examples:
- Full fine-tuning of the DiT module:
--trainable_models dit
- Training a LoRA model for the DiT module:
--lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32
- Training both a LoRA model for DiT and the full Motion Controller module:
--trainable_models motion_controller --lora_base_model dit --lora_target_modules "q,k,v,o,ffn.0,ffn.2" --lora_rank 32
Additionally, since multiple modules (text encoder, dit, vae) are loaded in the training script, you need to remove prefixes when saving model files. For example, when fully fine-tuning the DiT module or training a LoRA version of DiT, please set --remove_prefix_in_ckpt pipe.dit.
Step 4: Launch the Training Process
We have prepared training commands for each model. Please refer to the table at the beginning of this document.
Note that full fine-tuning of the 14B model requires 8 GPUs, each with at least 80GB VRAM. During full fine-tuning of these 14B models, you must install deepspeed (pip install deepspeed). We have provided recommended configuration files, which will be loaded automatically in the corresponding training scripts. These scripts have been tested on 8*A100.
The training script defaults to 81 frames at 480×832 resolution. Increasing the resolution may cause out-of-memory errors; to reduce VRAM usage, add the parameter --use_gradient_checkpointing_offload.
Showcase
1.3B text-to-video:
https://github.com/user-attachments/assets/124397be-cd6a-4f29-a87c-e4c695aaabb8
Put sunglasses on the dog (1.3B video-to-video):
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
14B text-to-video:
https://github.com/user-attachments/assets/3908bc64-d451-485a-8b61-28f6d32dd92f
14B image-to-video:
https://github.com/user-attachments/assets/c0bdd5ca-292f-45ed-b9bc-afe193156e75
LoRA training:
https://github.com/user-attachments/assets/9bd8e30b-97e8-44f9-bb6f-da004ba376a9