Support WanToDance (#1361)

* support wantodance

* update docs

* bugfix
This commit is contained in:
Zhongjie Duan
2026-03-20 16:40:35 +08:00
committed by GitHub
parent ba0626e38f
commit 52ba5d414e
22 changed files with 1210 additions and 13 deletions

View File

@@ -870,6 +870,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/)
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) |
| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) |
</details>

View File

@@ -870,6 +870,8 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/)
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) |
| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) |
</details>

View File

@@ -307,6 +307,13 @@ wan_series = [
"model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors")
"model_hash": "eb18873fc0ba77b541eb7b62dbcd2059",
"model_name": "wan_video_dit",
"model_class": "diffsynth.models.wan_video_dit.WanModel",
"extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True}
},
]
flux_series = [

View File

@@ -6,6 +6,7 @@ from typing import Tuple, Optional
from einops import rearrange
from .wan_video_camera_controller import SimpleAdapter
from ..core.gradient import gradient_checkpoint_forward
from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer
try:
import flash_attn_interface
@@ -283,6 +284,57 @@ class Head(nn.Module):
return x
def wantodance_torch_dfs(model: nn.Module, parent_name='root'):
module_names, modules = [], []
current_name = parent_name if parent_name else 'root'
module_names.append(current_name)
modules.append(model)
for name, child in model.named_children():
if parent_name:
child_name = f'{parent_name}.{name}'
else:
child_name = name
child_modules, child_names = wantodance_torch_dfs(child, child_name)
module_names += child_names
modules += child_modules
return modules, module_names
class WanToDanceInjector(nn.Module):
def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]):
super().__init__()
self.injected_block_id = {}
injector_id = 0
for mod_name, mod in zip(all_modules_names, all_modules):
if isinstance(mod, DiTBlock):
for inject_id in inject_layer:
if f'root.transformer_blocks.{inject_id}' == mod_name:
self.injected_block_id[inject_id] = injector_id
injector_id += 1
self.injector = nn.ModuleList(
[
CrossAttention(
dim=dim,
num_heads=num_heads,
)
for _ in range(injector_id)
]
)
self.injector_pre_norm_feat = nn.ModuleList(
[
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
for _ in range(injector_id)
]
)
self.injector_pre_norm_vec = nn.ModuleList(
[
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,)
for _ in range(injector_id)
]
)
class WanModel(torch.nn.Module):
def __init__(
self,
@@ -305,6 +357,13 @@ class WanModel(torch.nn.Module):
require_vae_embedding: bool = True,
require_clip_embedding: bool = True,
fuse_vae_embedding_in_latents: bool = False,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
super().__init__()
self.dim = dim
@@ -337,7 +396,12 @@ class WanModel(torch.nn.Module):
])
self.head = Head(dim, out_dim, patch_size, eps)
head_dim = dim // num_heads
self.freqs = precompute_freqs_cis_3d(head_dim)
if wantodance_enable_dynamicfps or wantodance_enable_unimodel:
end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350
self.freqs = precompute_freqs_cis_3d(head_dim, end=end)
else:
self.freqs = precompute_freqs_cis_3d(head_dim)
if has_image_input:
self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
@@ -350,8 +414,83 @@ class WanModel(torch.nn.Module):
else:
self.control_adapter = None
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x)
self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface,
wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel)
def prepare_wantodance(
self,
in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps,
wantodance_enable_music_inject: bool = False,
wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27],
wantodance_enable_refimage: bool = False,
wantodance_enable_refface: bool = False,
wantodance_enable_global: bool = False,
wantodance_enable_dynamicfps: bool = False,
wantodance_enable_unimodel: bool = False,
):
if wantodance_enable_music_inject:
all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks")
self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers)
if wantodance_enable_refimage:
self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if wantodance_enable_refface:
self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280
if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel:
music_feature_dim = 35
ff_size = 1024
dropout = 0.1
latent_dim = 256
nhead = 4
activation = F.gelu
rotary = WanToDanceRotaryEmbedding(dim=latent_dim)
self.music_projection = nn.Linear(music_feature_dim, latent_dim)
self.music_encoder = nn.Sequential()
for _ in range(2):
self.music_encoder.append(
WanToDanceMusicEncoderLayer(
d_model=latent_dim,
nhead=nhead,
dim_feedforward=ff_size,
dropout=dropout,
activation=activation,
batch_first=True,
rotary=rotary,
device='cuda',
)
)
if wantodance_enable_unimodel:
self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size)
if wantodance_enable_unimodel:
self.head_global = Head(dim, out_dim, patch_size, eps)
self.wantodance_enable_music_inject = wantodance_enable_music_inject
self.wantodance_enable_refimage = wantodance_enable_refimage
self.wantodance_enable_refface = wantodance_enable_refface
self.wantodance_enable_global = wantodance_enable_global
self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps
self.wantodance_enable_unimodel = wantodance_enable_unimodel
def wantodance_after_transformer_block(self, block_idx, hidden_states):
if self.wantodance_enable_music_inject:
if block_idx in self.music_injector.injected_block_id.keys():
audio_attn_id = self.music_injector.injected_block_id[block_idx]
audio_emb = self.merged_audio_emb # b f n c
num_frames = audio_emb.shape[1]
input_hidden_states = hidden_states.clone() # b (f h w) c
input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states)
audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames)
attn_audio_emb = audio_emb
residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb)
residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
hidden_states = hidden_states + residual_out
return hidden_states
def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False):
if enable_wantodance_global:
x = self.patch_embedding_global(x)
else:
x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input)
x = [u + v for u, v in zip(x, y_camera)]

View File

@@ -1247,6 +1247,22 @@ class WanVideoVAE(nn.Module):
return videos
def encode_framewise(self, videos, device):
hidden_states = []
for i in range(videos.shape[2]):
hidden_states.append(self.single_encode(videos[:, :, i:i+1], device))
hidden_states = torch.concat(hidden_states, dim=2)
return hidden_states
def decode_framewise(self, hidden_states, device):
video = []
for i in range(hidden_states.shape[2]):
video.append(self.single_decode(hidden_states[:, :, i:i+1], device))
video = torch.concat(video, dim=2)
return video
@staticmethod
def state_dict_converter():
return WanVideoVAEStateDictConverter()

View File

@@ -0,0 +1,209 @@
from inspect import isfunction
from math import log, pi
import torch
from einops import rearrange, repeat
from torch import einsum, nn
from typing import Any, Callable, List, Optional, Union
from torch import Tensor
import torch.nn.functional as F
# helper functions
def exists(val):
return val is not None
def broadcat(tensors, dim=-1):
num_tensors = len(tensors)
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
shape_len = list(shape_lens)[0]
dim = (dim + shape_len) if dim < 0 else dim
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
assert all(
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
), "invalid dimensions for broadcastable concatentation"
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
expanded_dims.insert(dim, (dim, dims[dim]))
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
return torch.cat(tensors, dim=dim)
# rotary embedding helper functions
def rotate_half(x):
x = rearrange(x, "... (d r) -> ... d r", r=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d r -> ... (d r)")
def apply_rotary_emb(freqs, t, start_index=0):
freqs = freqs.to(t)
rot_dim = freqs.shape[-1]
end_index = start_index + rot_dim
assert (
rot_dim <= t.shape[-1]
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
t_left, t, t_right = (
t[..., :start_index],
t[..., start_index:end_index],
t[..., end_index:],
)
t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin())
return torch.cat((t_left, t, t_right), dim=-1)
# learned rotation helpers
def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None):
if exists(freq_ranges):
rotations = einsum("..., f -> ... f", rotations, freq_ranges)
rotations = rearrange(rotations, "... r f -> ... (r f)")
rotations = repeat(rotations, "... n -> ... (n r)", r=2)
return apply_rotary_emb(rotations, t, start_index=start_index)
# classes
class WanToDanceRotaryEmbedding(nn.Module):
def __init__(
self,
dim,
custom_freqs=None,
freqs_for="lang",
theta=10000,
max_freq=10,
num_freqs=1,
learned_freq=False,
):
super().__init__()
if exists(custom_freqs):
freqs = custom_freqs
elif freqs_for == "lang":
freqs = 1.0 / (
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
)
elif freqs_for == "pixel":
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
elif freqs_for == "constant":
freqs = torch.ones(num_freqs).float()
else:
raise ValueError(f"unknown modality {freqs_for}")
self.cache = dict()
if learned_freq:
self.freqs = nn.Parameter(freqs)
else:
self.register_buffer("freqs", freqs, persistent=False)
def rotate_queries_or_keys(self, t, seq_dim=-2):
device = t.device
seq_len = t.shape[seq_dim]
freqs = self.forward(
lambda: torch.arange(seq_len, device=device), cache_key=seq_len
)
return apply_rotary_emb(freqs, t)
def forward(self, t, cache_key=None):
if exists(cache_key) and cache_key in self.cache:
return self.cache[cache_key]
if isfunction(t):
t = t()
# freqs = self.freqs
freqs = self.freqs.to(t.device)
freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
if exists(cache_key):
self.cache[cache_key] = freqs
return freqs
class WanToDanceMusicEncoderLayer(nn.Module):
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = True,
device=None,
dtype=None,
rotary=None,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation
self.rotary = rotary
self.use_rotary = rotary is not None
# self-attention block
def _sa_block(
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
) -> Tensor:
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
x = src
if self.norm_first:
self.norm1.to(device=x.device)
self.norm2.to(device=x.device)
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x

View File

@@ -75,6 +75,9 @@ class WanVideoPipeline(BasePipeline):
WanVideoUnit_TeaCache(),
WanVideoUnit_CfgMerger(),
WanVideoUnit_LongCatVideo(),
WanVideoUnit_WanToDance_ProcessInputs(),
WanVideoUnit_WanToDance_RefImageEmbedder(),
WanVideoUnit_WanToDance_ImageKeyframesEmbedder(),
]
self.post_units = [
WanVideoPostUnit_S2V(),
@@ -244,6 +247,13 @@ class WanVideoPipeline(BasePipeline):
# Teacache
tea_cache_l1_thresh: Optional[float] = None,
tea_cache_model_id: Optional[str] = "",
# WanToDance
wantodance_music_path: Optional[str] = None,
wantodance_reference_image: Optional[Image.Image] = None,
wantodance_fps: Optional[float] = 30,
wantodance_keyframes: Optional[list[Image.Image]] = None,
wantodance_keyframes_mask: Optional[list[int]] = None,
framewise_decoding: bool = False,
# progress_bar
progress_bar_cmd=tqdm,
output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized",
@@ -280,6 +290,9 @@ class WanVideoPipeline(BasePipeline):
"input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
"vap_video": vap_video,
"wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps,
"wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask,
"framewise_decoding": framewise_decoding,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -325,7 +338,10 @@ class WanVideoPipeline(BasePipeline):
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['vae'])
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if framewise_decoding:
video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device)
else:
video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
if output_type == "quantized":
video = self.vae_output_to_video(video)
elif output_type == "floatpoint":
@@ -371,17 +387,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit):
class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"),
input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"),
output_params=("latents", "input_latents"),
onload_model_names=("vae",)
)
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image):
def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding):
if input_video is None:
return {"latents": noise}
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if framewise_decoding:
input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device)
else:
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_reference_image is not None:
if not isinstance(vace_reference_image, list):
vace_reference_image = [vace_reference_image]
@@ -1018,6 +1037,111 @@ class WanVideoUnit_LongCatVideo(PipelineUnit):
return {"longcat_latents": longcat_latents}
class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
)
def get_music_base_feature(self, music_path, fps=30):
import librosa
hop_length = 512
sr = fps * hop_length
data, sr = librosa.load(music_path, sr=sr)
sr = 22050
envelope = librosa.onset.onset_strength(y=data, sr=sr)
mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T
chroma = librosa.feature.chroma_cens(
y=data, sr=sr, hop_length=hop_length, n_chroma=12
).T
peak_idxs = librosa.onset.onset_detect(
onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length
)
peak_onehot = np.zeros_like(envelope, dtype=np.float32)
peak_onehot[peak_idxs] = 1.0
start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0]
_, beat_idxs = librosa.beat.beat_track(
onset_envelope=envelope,
sr=sr,
hop_length=hop_length,
start_bpm=start_bpm,
tightness=100,
)
beat_onehot = np.zeros_like(envelope, dtype=np.float32)
beat_onehot[beat_idxs] = 1.0
audio_feature = np.concatenate(
[envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]],
axis=-1,
)
return torch.from_numpy(audio_feature)
def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if pipe.dit.wantodance_enable_global:
inputs_nega["skip_9th_layer"] = True
if inputs_shared.get("wantodance_music_path", None) is not None:
inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device)
return inputs_shared, inputs_posi, inputs_nega
class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("wantodance_refimage_feature",),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride):
if wantodance_reference_image is None:
return {}
pipe.load_models_to_device(self.onload_model_names)
if isinstance(wantodance_reference_image, list):
wantodance_reference_image = wantodance_reference_image[0]
image = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device) # B,C,H,W;B=1
refimage_feature = pipe.image_encoder.encode_image([image])
refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"wantodance_refimage_feature": refimage_feature}
class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"),
output_params=("clip_feature", "y"),
onload_model_names=("image_encoder", "vae")
)
def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride):
if wantodance_keyframes is None:
return {}
wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask)
pipe.load_models_to_device(self.onload_model_names)
images = []
for input_image in wantodance_keyframes:
input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
images.append(input_image)
clip_context = pipe.image_encoder.encode_image(images[:1]) # 取第一帧作为clip输入
msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device)
msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1
images = [image.transpose(0, 1) for image in images] # 3, num_frames, h, w
images = torch.concat(images, dim=1)
vae_input = images
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device)
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
return {"clip_feature": clip_context, "y": y}
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
self.num_inference_steps = num_inference_steps
@@ -1123,6 +1247,22 @@ class TemporalTiler_BCTHW:
return value
def wantodance_get_single_freqs(freqs, frame_num, fps):
total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5)
interval_frame = 30.0 / (fps + 1e-6)
freqs_0 = freqs[:total_frame]
freqs_new = torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype)
freqs_new[0] = freqs_0[0]
freqs_new[-1] = freqs_0[total_frame - 1]
for i in range(1, frame_num-1):
pos = i * interval_frame
low_idx = int(pos)
high_idx = min(low_idx + 1, total_frame - 1)
weight_high = pos - low_idx
weight_low = 1.0 - weight_high
freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high
return freqs_new
def model_fn_wan_video(
dit: WanModel,
@@ -1158,6 +1298,10 @@ def model_fn_wan_video(
use_gradient_checkpointing_offload: bool = False,
control_camera_latents_input = None,
fuse_vae_embedding_in_latents: bool = False,
wantodance_refimage_feature = None,
wantodance_fps: float = 30.0,
music_feature = None,
skip_9th_layer: bool = False,
**kwargs,
):
if sliding_window_size is not None and sliding_window_stride is not None:
@@ -1255,7 +1399,10 @@ def model_fn_wan_video(
context = torch.cat([clip_embdding, context], dim=1)
# Camera control
x = dit.patchify(x, control_camera_latents_input)
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30:
x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True)
else:
x = dit.patchify(x, control_camera_latents_input)
# Animate
if pose_latents is not None and face_pixel_values is not None:
@@ -1310,7 +1457,61 @@ def model_fn_wan_video(
use_gradient_checkpointing=use_gradient_checkpointing,
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload
)
# WanToDance
if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global:
if wantodance_refimage_feature is not None:
refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature)
context = torch.cat([refimage_feature_embedding, context], dim=1)
if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30:
freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps)
freqs = torch.cat([
freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1),
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel:
if use_unified_sequence_parallel:
length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
music_feature = music_feature[:length]
music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
if not dit.training:
dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation
music_feature = music_feature.to(x.device, dtype=x.dtype)
music_feature = dit.music_projection(music_feature)
music_feature = dit.music_encoder(music_feature)
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
music_feature = get_sp_group().all_gather(music_feature, dim=1)
music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800]
N = 149
M = 4800
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800]
if music_feature is not None:
if music_feature.dim() == 2:
music_feature = music_feature.unsqueeze(0)
music_feature = music_feature.to(x.device, dtype=x.dtype)
interp_mode = 'bilinear'
if interp_mode == 'bilinear':
frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21
context_shape_end = context.shape[2] ## 14B 5120
music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800]
if use_unified_sequence_parallel:
N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size()
else:
N = frame_num * 8
music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear')
music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end]
if use_unified_sequence_parallel:
dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
else:
dit.merged_audio_emb = music_feature
else:
dit.merged_audio_emb = music_feature
# blocks
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
@@ -1326,8 +1527,12 @@ def model_fn_wan_video(
return vap(block, *inputs)
return custom_forward
# Block
for block_id, block in enumerate(dit.blocks):
# Block
if skip_9th_layer:
# This is only used in WanToDance
if block_id == 9:
continue
if vap is not None and block_id in vap.mot_layers_mapping:
if use_gradient_checkpointing_offload:
with torch.autograd.graph.save_on_cpu():
@@ -1364,10 +1569,18 @@ def model_fn_wan_video(
# Animate
if pose_latents is not None and face_pixel_values is not None:
x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
# WanToDance
if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject:
x = dit.wantodance_after_transformer_block(block_id, x)
if tea_cache is not None:
tea_cache.store(x)
x = dit.head(x, t)
if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30:
x = dit.head_global(x, t)
else:
x = dit.head(x, t)
if use_unified_sequence_parallel:
if dist.is_initialized() and dist.get_world_size() > 1:
x = get_sp_group().all_gather(x, dim=1)

View File

@@ -139,6 +139,8 @@ graph LR;
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) |
| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py) |
* FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
@@ -203,6 +205,50 @@ Input parameters for `WanVideoPipeline` inference include:
If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
### Multi-GPU Parallel Acceleration
To enable multi-GPU parallel acceleration, please install `flash_attn` and `xfuser`:
```shell
pip install flash-attn --no-build-isolation
pip install xfuser
```
Please modify your code as follows ([example code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)):
```diff
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+ import torch.distributed as dist
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
+ use_usp=True,
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
video = pipe(
prompt="An astronaut in a spacesuit rides a mechanical horse across the Martian surface, facing the camera. The red, desolate terrain stretches into the distance, dotted with massive craters and unusual rock formations. The mechanical horse moves with steady strides, kicking up faint dust, embodying a perfect fusion of futuristic technology and primal exploration. The astronaut holds a control device, with a determined gaze, as if pioneering new frontiers for humanity. Against a backdrop of the deep cosmos and the blue Earth, the scene is both sci-fi and hopeful, evoking imagination about future interstellar life.",
negative_prompt="oversaturated colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still image, overall gray tone, worst quality, low quality, JPEG compression artifacts, ugly, malformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, frozen frame, cluttered background, three legs, crowd in background, walking backwards",
seed=0, tiled=True,
)
+ if dist.get_rank() == 0:
+ save_video(video, "video1.mp4", fps=15, quality=5)
```
When running multi-GPU parallel inference, please use `torchrun`, where `--nproc_per_node` specifies the number of GPUs:
```shell
torchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
```
## Model Training
Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include:

View File

@@ -140,6 +140,8 @@ graph LR;
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) |
| [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) |
| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) |
| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py) |
* FP8 精度训练:[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/)
* 两阶段拆分训练:[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/)
@@ -204,6 +206,50 @@ DeepSpeed ZeRO 3 训练Wan 系列模型支持 DeepSpeed ZeRO 3 训练,将
如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
### 多卡并行加速
如需开启多卡并行加速,请先安装 `flash_attn``xfuser`
```shell
pip install flash-attn --no-build-isolation
pip install xfuser
```
对代码进行如下修改([样例代码](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)
```diff
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+ import torch.distributed as dist
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
+ use_usp=True,
model_configs=[
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
video = pipe(
prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
)
+ if dist.get_rank() == 0:
+ save_video(video, "video1.mp4", fps=15, quality=5)
```
运行多卡并行推理时,请使用 `torchrun` 运行,其中 `--nproc_per_node` 为 GPU 数量:
```shell
torchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py
```
## 模型训练
Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括:

View File

@@ -0,0 +1,48 @@
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-global/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.
# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.
# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞舞蹈种类是韩舞。帧率是7.5000",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=1280, width=720, num_frames=149,
num_inference_steps=48,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
wantodance_fps=7.5,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1] + [0] * 148,
framewise_decoding=True,
)
save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)

View File

@@ -0,0 +1,52 @@
import torch, os
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-local/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
height=1280, width=720, num_frames=149,
num_inference_steps=24,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
wantodance_fps=30,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1],
)
save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)

View File

@@ -0,0 +1,59 @@
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-global/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.
# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.
# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞舞蹈种类是韩舞。帧率是7.5000",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=1280, width=720, num_frames=149,
num_inference_steps=48,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
wantodance_fps=7.5,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1] + [0] * 148,
framewise_decoding=True,
)
save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)

View File

@@ -0,0 +1,63 @@
import torch, os
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
vram_config = {
"offload_dtype": "disk",
"offload_device": "disk",
"onload_dtype": torch.bfloat16,
"onload_device": "cpu",
"preparing_dtype": torch.bfloat16,
"preparing_device": "cuda",
"computation_dtype": torch.bfloat16,
"computation_device": "cuda",
}
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors", **vram_config),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-local/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
height=1280, width=720, num_frames=149,
num_inference_steps=24,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
wantodance_fps=30,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1],
)
save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)

View File

@@ -0,0 +1,20 @@
# 8*H200 required
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-global/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global \
--dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/metadata.json \
--data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
--height 1280 \
--width 720 \
--num_frames 149 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/WanToDance-14B:global_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/WanToDance-14B-global_full" \
--trainable_models "dit" \
--extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask,framewise_decoding" \
--use_gradient_checkpointing_offload \
--framewise_decoding

View File

@@ -0,0 +1,19 @@
# 8*H200 required
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-local/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local \
--dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/metadata.json \
--data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
--height 1280 \
--width 720 \
--num_frames 149 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/WanToDance-14B:local_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/WanToDance-14B-local_full" \
--trainable_models "dit" \
--extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask" \
--use_gradient_checkpointing_offload

View File

@@ -0,0 +1,22 @@
# 8*H200 required
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-global/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global \
--dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/metadata.json \
--data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
--height 1280 \
--width 720 \
--num_frames 149 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/WanToDance-14B:global_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/WanToDance-14B-global_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask,framewise_decoding" \
--use_gradient_checkpointing_offload \
--framewise_decoding

View File

@@ -0,0 +1,21 @@
# 8*H200 required
modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-local/*" --local_dir ./data/diffsynth_example_dataset
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
--dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local \
--dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/metadata.json \
--data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
--height 1280 \
--width 720 \
--num_frames 149 \
--dataset_repeat 100 \
--model_id_with_origin_paths "Wan-AI/WanToDance-14B:local_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/WanToDance-14B-local_lora" \
--lora_base_model "dit" \
--lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
--lora_rank 32 \
--extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask" \
--use_gradient_checkpointing_offload

View File

@@ -72,6 +72,9 @@ class WanTrainingModule(DiffusionTrainingModule):
inputs_shared[extra_input] = data[extra_input][0]
else:
inputs_shared[extra_input] = data[extra_input]
if inputs_shared.get("framewise_decoding", False):
# WanToDance global model
inputs_shared["num_frames"] = 4 * (len(data["video"]) - 1) + 1
return inputs_shared
def get_pipeline_inputs(self, data):
@@ -117,6 +120,7 @@ def wan_parser():
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
parser.add_argument("--framewise_decoding", default=False, action="store_true", help="Enable it if this model is a WanToDance global model.")
return parser
@@ -140,12 +144,13 @@ if __name__ == "__main__":
height_division_factor=16,
width_division_factor=16,
num_frames=args.num_frames,
time_division_factor=4,
time_division_remainder=1,
time_division_factor=4 if not args.framewise_decoding else 1,
time_division_remainder=1 if not args.framewise_decoding else 0,
),
special_operator_map={
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
"wantodance_music_path": ToAbsolutePath(args.dataset_base_path),
}
)
model = WanTrainingModule(

View File

@@ -0,0 +1,51 @@
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
from diffsynth.core import load_state_dict
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
state_dict = load_state_dict("models/train/WanToDance-14B-global_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-global/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.
# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.
# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞舞蹈种类是韩舞。帧率是7.5000",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=1280, width=720, num_frames=149,
num_inference_steps=48,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
wantodance_fps=7.5,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1] + [0] * 148,
framewise_decoding=True,
)
save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)

View File

@@ -0,0 +1,55 @@
import torch, os
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
from diffsynth.core import load_state_dict
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
state_dict = load_state_dict("models/train/WanToDance-14B-local_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-local/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
height=1280, width=720, num_frames=149,
num_inference_steps=24,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
wantodance_fps=30,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1],
)
save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)

View File

@@ -0,0 +1,49 @@
import torch
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
pipe.load_lora(pipe.dit, "models/train/WanToDance-14B-global_lora/epoch-4.safetensors", alpha=1)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-global/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.
# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1.
# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞舞蹈种类是韩舞。帧率是7.5000",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=False,
height=1280, width=720, num_frames=149,
num_inference_steps=48,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
wantodance_fps=7.5,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1] + [0] * 148,
framewise_decoding=True,
)
save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)

View File

@@ -0,0 +1,53 @@
import torch, os
from PIL import Image
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download
pipe = WanVideoPipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
],
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
)
pipe.load_lora(pipe.dit, "models/train/WanToDance-14B-global_lora/epoch-4.safetensors", alpha=1)
dataset_snapshot_download(
"DiffSynth-Studio/diffsynth_example_dataset",
local_dir="data/diffsynth_example_dataset",
allow_file_pattern="wanvideo/WanToDance-14B-local/*"
)
# This is a specialized model with the following constraints on its input parameters:
# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
# * The width and height of `wantodance_reference_image` must be multiples of 16.
# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
video = pipe(
prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
seed=0, tiled=True,
height=1280, width=720, num_frames=149,
num_inference_steps=24,
wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
wantodance_fps=30,
wantodance_keyframes=wantodance_keyframes,
wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
1],
)
save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)