Support WanToDance (#1361)

* support wantodance

* update docs

* bugfix
This commit is contained in:
Zhongjie Duan
2026-03-20 16:40:35 +08:00
committed by GitHub
parent ba0626e38f
commit 52ba5d414e
22 changed files with 1210 additions and 13 deletions

View File

@@ -0,0 +1,49 @@
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
pipe.load_lora(pipe.dit, "models/train/WanToDance-14B-global_lora/epoch-4.safetensors", alpha=1)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-global/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.
# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.
# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞舞蹈种类是韩舞。帧率是7.5000",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=1280, width=720, num_frames=149,
num_inference_steps=48,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
wantodance_fps=7.5,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1] + [0] * 148,
framewise_decoding=True,
)
save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)

View File

@@ -0,0 +1,53 @@
import torch, os
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
pipe.load_lora(pipe.dit, "models/train/WanToDance-14B-global_lora/epoch-4.safetensors", alpha=1)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-local/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
height=1280, width=720, num_frames=149,
num_inference_steps=24,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
wantodance_fps=30,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1],
)
save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)