diff --git a/README.md b/README.md index b9f8ab0..f7ef6ce 100644 --- a/README.md +++ b/README.md @@ -870,6 +870,8 @@ Example code for Wan is available at: [/examples/wanvideo/](/examples/wanvideo/) |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| | [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | | [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | +| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) | +| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-local.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-local.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py) | diff --git a/README_zh.md b/README_zh.md index 76b29d8..d000498 100644 --- a/README_zh.md +++ b/README_zh.md @@ -870,6 +870,8 @@ Wan 的示例代码位于:[/examples/wanvideo/](/examples/wanvideo/) |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
| [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | | [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | +| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) | +| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](/examples/wanvideo/model_inference/WanToDance-14B-local.py) | [code](/examples/wanvideo/model_training/full/WanToDance-14B-local.sh) | [code](/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py) | [code](/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh) | [code](/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py) | diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 2bb5747..9593f0b 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -307,6 +307,13 @@ wan_series = [ "model_class": "diffsynth.models.wav2vec.WanS2VAudioEncoder", "state_dict_converter": "diffsynth.utils.state_dict_converters.wans2v_audio_encoder.WanS2VAudioEncoderStateDictConverter", }, + { + # Example: ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors") + "model_hash": "eb18873fc0ba77b541eb7b62dbcd2059", + "model_name": "wan_video_dit", + "model_class": "diffsynth.models.wan_video_dit.WanModel", + "extra_kwargs": {'has_image_input': True, 'patch_size': [1, 2, 2], 'in_dim': 36, 'dim': 5120, 'ffn_dim': 13824, 'freq_dim': 256, 'text_dim': 4096, 'out_dim': 16, 'num_heads': 40, 'num_layers': 40, 'eps': 1e-06, 'wantodance_enable_music_inject': True, 'wantodance_music_inject_layers': [0, 4, 8, 12, 16, 20, 24, 27], 'wantodance_enable_refimage': True, 'has_ref_conv': True, 'wantodance_enable_refface': False, 'wantodance_enable_global': True, 'wantodance_enable_dynamicfps': True, 'wantodance_enable_unimodel': True} + }, ] flux_series = [ diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 25ddb92..7e5cec6 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -6,6 +6,7 @@ from typing import Tuple, Optional from einops import rearrange from
.wan_video_camera_controller import SimpleAdapter from ..core.gradient import gradient_checkpoint_forward +from .wantodance import WanToDanceRotaryEmbedding, WanToDanceMusicEncoderLayer try: import flash_attn_interface @@ -283,6 +284,57 @@ class Head(nn.Module): return x +def wantodance_torch_dfs(model: nn.Module, parent_name='root'): + module_names, modules = [], [] + current_name = parent_name if parent_name else 'root' + module_names.append(current_name) + modules.append(model) + for name, child in model.named_children(): + if parent_name: + child_name = f'{parent_name}.{name}' + else: + child_name = name + child_modules, child_names = wantodance_torch_dfs(child, child_name) + module_names += child_names + modules += child_modules + return modules, module_names + + +class WanToDanceInjector(nn.Module): + def __init__(self, all_modules, all_modules_names, dim=2048, num_heads=32, inject_layer=[0, 27]): + super().__init__() + self.injected_block_id = {} + injector_id = 0 + for mod_name, mod in zip(all_modules_names, all_modules): + if isinstance(mod, DiTBlock): + for inject_id in inject_layer: + if f'root.transformer_blocks.{inject_id}' == mod_name: + self.injected_block_id[inject_id] = injector_id + injector_id += 1 + + self.injector = nn.ModuleList( + [ + CrossAttention( + dim=dim, + num_heads=num_heads, + ) + for _ in range(injector_id) + ] + ) + self.injector_pre_norm_feat = nn.ModuleList( + [ + nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,) + for _ in range(injector_id) + ] + ) + self.injector_pre_norm_vec = nn.ModuleList( + [ + nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6,) + for _ in range(injector_id) + ] + ) + + class WanModel(torch.nn.Module): def __init__( self, @@ -305,6 +357,13 @@ class WanModel(torch.nn.Module): require_vae_embedding: bool = True, require_clip_embedding: bool = True, fuse_vae_embedding_in_latents: bool = False, + wantodance_enable_music_inject: bool = False, + wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27], + wantodance_enable_refimage: bool = False, + wantodance_enable_refface: bool = False, + wantodance_enable_global: bool = False, + wantodance_enable_dynamicfps: bool = False, + wantodance_enable_unimodel: bool = False, ): super().__init__() self.dim = dim @@ -337,7 +396,12 @@ class WanModel(torch.nn.Module): ]) self.head = Head(dim, out_dim, patch_size, eps) head_dim = dim // num_heads - self.freqs = precompute_freqs_cis_3d(head_dim) + + if wantodance_enable_dynamicfps or wantodance_enable_unimodel: + end = int(22350 / 8 + 0.5) # 149f * 30fps * 5s = 22350 + self.freqs = precompute_freqs_cis_3d(head_dim, end=end) + else: + self.freqs = precompute_freqs_cis_3d(head_dim) if has_image_input: self.img_emb = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 @@ -350,8 +414,83 @@ class WanModel(torch.nn.Module): else: self.control_adapter = None - def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): - x = self.patch_embedding(x) + self.prepare_wantodance(in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps, + wantodance_enable_music_inject, wantodance_music_inject_layers, wantodance_enable_refimage, wantodance_enable_refface, + wantodance_enable_global, wantodance_enable_dynamicfps, wantodance_enable_unimodel) + + def prepare_wantodance( + self, + in_dim, dim, num_heads, has_image_pos_emb, out_dim, patch_size, eps, + wantodance_enable_music_inject: bool = False, + wantodance_music_inject_layers = [0, 4, 8, 12, 16, 20, 24, 27], + 
wantodance_enable_refimage: bool = False, + wantodance_enable_refface: bool = False, + wantodance_enable_global: bool = False, + wantodance_enable_dynamicfps: bool = False, + wantodance_enable_unimodel: bool = False, + ): + if wantodance_enable_music_inject: + all_modules, all_modules_names = wantodance_torch_dfs(self.blocks, parent_name="root.transformer_blocks") + self.music_injector = WanToDanceInjector(all_modules, all_modules_names, dim=dim, num_heads=num_heads, inject_layer=wantodance_music_inject_layers) + if wantodance_enable_refimage: + self.img_emb_refimage = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if wantodance_enable_refface: + self.img_emb_refface = MLP(1280, dim, has_pos_emb=has_image_pos_emb) # clip_feature_dim = 1280 + if wantodance_enable_global or wantodance_enable_dynamicfps or wantodance_enable_unimodel: + music_feature_dim = 35 + ff_size = 1024 + dropout = 0.1 + latent_dim = 256 + nhead = 4 + activation = F.gelu + rotary = WanToDanceRotaryEmbedding(dim=latent_dim) + self.music_projection = nn.Linear(music_feature_dim, latent_dim) + self.music_encoder = nn.Sequential() + for _ in range(2): + self.music_encoder.append( + WanToDanceMusicEncoderLayer( + d_model=latent_dim, + nhead=nhead, + dim_feedforward=ff_size, + dropout=dropout, + activation=activation, + batch_first=True, + rotary=rotary, + device='cuda', + ) + ) + if wantodance_enable_unimodel: + self.patch_embedding_global = nn.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size) + if wantodance_enable_unimodel: + self.head_global = Head(dim, out_dim, patch_size, eps) + self.wantodance_enable_music_inject = wantodance_enable_music_inject + self.wantodance_enable_refimage = wantodance_enable_refimage + self.wantodance_enable_refface = wantodance_enable_refface + self.wantodance_enable_global = wantodance_enable_global + self.wantodance_enable_dynamicfps = wantodance_enable_dynamicfps + self.wantodance_enable_unimodel = wantodance_enable_unimodel + + def wantodance_after_transformer_block(self, block_idx, hidden_states): + if self.wantodance_enable_music_inject: + if block_idx in self.music_injector.injected_block_id.keys(): + audio_attn_id = self.music_injector.injected_block_id[block_idx] + audio_emb = self.merged_audio_emb # b f n c + num_frames = audio_emb.shape[1] + input_hidden_states = hidden_states.clone() # b (f h w) c + input_hidden_states = rearrange(input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames) + attn_hidden_states = self.music_injector.injector_pre_norm_feat[audio_attn_id](input_hidden_states) + audio_emb = rearrange(audio_emb, "b t c -> (b t) 1 c", t=num_frames) + attn_audio_emb = audio_emb + residual_out = self.music_injector.injector[audio_attn_id](attn_hidden_states, attn_audio_emb) + residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames) + hidden_states = hidden_states + residual_out + return hidden_states + + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None, enable_wantodance_global=False): + if enable_wantodance_global: + x = self.patch_embedding_global(x) + else: + x = self.patch_embedding(x) if self.control_adapter is not None and control_camera_latents_input is not None: y_camera = self.control_adapter(control_camera_latents_input) x = [u + v for u, v in zip(x, y_camera)] diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 3d5db68..f4e4a8a 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ 
-1247,6 +1247,22 @@ class WanVideoVAE(nn.Module): return videos + def encode_framewise(self, videos, device): + hidden_states = [] + for i in range(videos.shape[2]): + hidden_states.append(self.single_encode(videos[:, :, i:i+1], device)) + hidden_states = torch.concat(hidden_states, dim=2) + return hidden_states + + + def decode_framewise(self, hidden_states, device): + video = [] + for i in range(hidden_states.shape[2]): + video.append(self.single_decode(hidden_states[:, :, i:i+1], device)) + video = torch.concat(video, dim=2) + return video + + + @staticmethod def state_dict_converter(): return WanVideoVAEStateDictConverter() diff --git a/diffsynth/models/wantodance.py b/diffsynth/models/wantodance.py new file mode 100644 index 0000000..bc9ddc9 --- /dev/null +++ b/diffsynth/models/wantodance.py @@ -0,0 +1,209 @@ +from inspect import isfunction +from math import log, pi + +import torch +from einops import rearrange, repeat +from torch import einsum, nn + +from typing import Any, Callable, List, Optional, Union +from torch import Tensor +import torch.nn.functional as F + +# helper functions + + +def exists(val): + return val is not None + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatenation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +# rotary embedding helper functions + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +def apply_rotary_emb(freqs, t, start_index=0): + freqs = freqs.to(t) + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) + return torch.cat((t_left, t, t_right), dim=-1) + + +# learned rotation helpers + + +def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None): + if exists(freq_ranges): + rotations = einsum("..., f -> ... f", rotations, freq_ranges) + rotations = rearrange(rotations, "... r f -> ... (r f)") + + rotations = repeat(rotations, "...
(n r)", r=2) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +# classes + + +class WanToDanceRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + ): + super().__init__() + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + self.cache = dict() + + if learned_freq: + self.freqs = nn.Parameter(freqs) + else: + self.register_buffer("freqs", freqs, persistent=False) + + def rotate_queries_or_keys(self, t, seq_dim=-2): + device = t.device + seq_len = t.shape[seq_dim] + freqs = self.forward( + lambda: torch.arange(seq_len, device=device), cache_key=seq_len + ) + return apply_rotary_emb(freqs, t) + + def forward(self, t, cache_key=None): + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if isfunction(t): + t = t() + + # freqs = self.freqs + freqs = self.freqs.to(t.device) + + freqs = torch.einsum("..., f -> ... f", t.type(freqs.dtype), freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + + if exists(cache_key): + self.cache[cache_key] = freqs + + return freqs + + +class WanToDanceMusicEncoderLayer(nn.Module): + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + layer_norm_eps: float = 1e-5, + batch_first: bool = False, + norm_first: bool = True, + device=None, + dtype=None, + rotary=None, + ) -> None: + super().__init__() + self.self_attn = nn.MultiheadAttention( + d_model, nhead, dropout=dropout, batch_first=batch_first, device=device, dtype=dtype + ) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm_first = norm_first + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.activation = activation + + self.rotary = rotary + self.use_rotary = rotary is not None + + # self-attention block + def _sa_block( + self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] + ) -> Tensor: + qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x + x = self.self_attn( + qk, + qk, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + x = src + if self.norm_first: + self.norm1.to(device=x.device) + self.norm2.to(device=x.device) + x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) + x = x + self._ff_block(self.norm2(x)) + else: + x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) + x = self.norm2(x + self._ff_block(x)) + return x \ No newline at end of file diff 
--git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index bbc479e..bb9f2b0 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -75,6 +75,9 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), WanVideoUnit_LongCatVideo(), + WanVideoUnit_WanToDance_ProcessInputs(), + WanVideoUnit_WanToDance_RefImageEmbedder(), + WanVideoUnit_WanToDance_ImageKeyframesEmbedder(), ] self.post_units = [ WanVideoPostUnit_S2V(), @@ -244,6 +247,13 @@ class WanVideoPipeline(BasePipeline): # Teacache tea_cache_l1_thresh: Optional[float] = None, tea_cache_model_id: Optional[str] = "", + # WanToDance + wantodance_music_path: Optional[str] = None, + wantodance_reference_image: Optional[Image.Image] = None, + wantodance_fps: Optional[float] = 30, + wantodance_keyframes: Optional[list[Image.Image]] = None, + wantodance_keyframes_mask: Optional[list[int]] = None, + framewise_decoding: bool = False, # progress_bar progress_bar_cmd=tqdm, output_type: Optional[Literal["quantized", "floatpoint"]] = "quantized", @@ -280,6 +290,9 @@ class WanVideoPipeline(BasePipeline): "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, "vap_video": vap_video, + "wantodance_music_path": wantodance_music_path, "wantodance_reference_image": wantodance_reference_image, "wantodance_fps": wantodance_fps, + "wantodance_keyframes": wantodance_keyframes, "wantodance_keyframes_mask": wantodance_keyframes_mask, + "framewise_decoding": framewise_decoding, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -325,7 +338,10 @@ class WanVideoPipeline(BasePipeline): inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) # Decode self.load_models_to_device(['vae']) - video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if framewise_decoding: + video = self.vae.decode_framewise(inputs_shared["latents"], device=self.device) + else: + video = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) if output_type == "quantized": video = self.vae_output_to_video(video) elif output_type == "floatpoint": @@ -371,17 +387,20 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): class WanVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "framewise_decoding"), output_params=("latents", "input_latents"), onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, framewise_decoding): if input_video is None: return {"latents": noise} pipe.load_models_to_device(self.onload_model_names) input_video = pipe.preprocess_video(input_video) - input_latents = 
pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + if framewise_decoding: + input_latents = pipe.vae.encode_framewise(input_video, device=pipe.device) + else: + input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) if vace_reference_image is not None: if not isinstance(vace_reference_image, list): vace_reference_image = [vace_reference_image] @@ -1018,6 +1037,111 @@ class WanVideoUnit_LongCatVideo(PipelineUnit): return {"longcat_latents": longcat_latents} +class WanVideoUnit_WanToDance_ProcessInputs(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + ) + + def get_music_base_feature(self, music_path, fps=30): + import librosa + hop_length = 512 + sr = fps * hop_length + data, sr = librosa.load(music_path, sr=sr) + sr = 22050 + envelope = librosa.onset.onset_strength(y=data, sr=sr) + mfcc = librosa.feature.mfcc(y=data, sr=sr, n_mfcc=20).T + chroma = librosa.feature.chroma_cens( + y=data, sr=sr, hop_length=hop_length, n_chroma=12 + ).T + peak_idxs = librosa.onset.onset_detect( + onset_envelope=envelope.flatten(), sr=sr, hop_length=hop_length + ) + peak_onehot = np.zeros_like(envelope, dtype=np.float32) + peak_onehot[peak_idxs] = 1.0 + start_bpm = librosa.beat.tempo(y=librosa.load(music_path)[0])[0] + _, beat_idxs = librosa.beat.beat_track( + onset_envelope=envelope, + sr=sr, + hop_length=hop_length, + start_bpm=start_bpm, + tightness=100, + ) + beat_onehot = np.zeros_like(envelope, dtype=np.float32) + beat_onehot[beat_idxs] = 1.0 + audio_feature = np.concatenate( + [envelope[:, None], mfcc, chroma, peak_onehot[:, None], beat_onehot[:, None]], + axis=-1, + ) + return torch.from_numpy(audio_feature) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if pipe.dit.wantodance_enable_global: + inputs_nega["skip_9th_layer"] = True + if inputs_shared.get("wantodance_music_path", None) is not None: + inputs_shared["music_feature"] = self.get_music_base_feature(inputs_shared["wantodance_music_path"]).to(dtype=pipe.torch_dtype, device=pipe.device) + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoUnit_WanToDance_RefImageEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("wantodance_reference_image", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride"), + output_params=("wantodance_refimage_feature",), + onload_model_names=("image_encoder", "vae") + ) + + def process(self, pipe: WanVideoPipeline, wantodance_reference_image, num_frames, height, width, tiled, tile_size, tile_stride): + if wantodance_reference_image is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + if isinstance(wantodance_reference_image, list): + wantodance_reference_image = wantodance_reference_image[0] + image = pipe.preprocess_image(wantodance_reference_image.resize((width, height))).to(pipe.device) # B,C,H,W;B=1 + refimage_feature = pipe.image_encoder.encode_image([image]) + refimage_feature = refimage_feature.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"wantodance_refimage_feature": refimage_feature} + + +class WanVideoUnit_WanToDance_ImageKeyframesEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("wantodance_keyframes", "wantodance_keyframes_mask", "num_frames", "height", "width", "tiled", "tile_size", 
"tile_stride"), + output_params=("clip_feature", "y"), + onload_model_names=("image_encoder", "vae") + ) + + def process(self, pipe: WanVideoPipeline, wantodance_keyframes, wantodance_keyframes_mask, num_frames, height, width, tiled, tile_size, tile_stride): + if wantodance_keyframes is None: + return {} + wantodance_keyframes_mask = torch.tensor(wantodance_keyframes_mask) + pipe.load_models_to_device(self.onload_model_names) + images = [] + for input_image in wantodance_keyframes: + input_image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device) + images.append(input_image) + + clip_context = pipe.image_encoder.encode_image(images[:1]) # 取第一帧作为clip输入 + msk = torch.zeros(1, num_frames, height//8, width//8, device=pipe.device) + msk[:, wantodance_keyframes_mask==1, :, :] = torch.ones(1, height//8, width//8, device=pipe.device) # set keyframes mask to 1 + + images = [image.transpose(0, 1) for image in images] # 3, num_frames, h, w + images = torch.concat(images, dim=1) + vae_input = images + + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) # expand first frame mask, N to N + 3 + msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) + msk = msk.transpose(1, 2)[0] + + y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + y = torch.concat([msk, y]) + y = y.unsqueeze(0) + clip_context = clip_context.to(dtype=pipe.torch_dtype, device=pipe.device) + y = y.to(dtype=pipe.torch_dtype, device=pipe.device) + return {"clip_feature": clip_context, "y": y} + + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh, model_id): self.num_inference_steps = num_inference_steps @@ -1123,6 +1247,22 @@ class TemporalTiler_BCTHW: return value +def wantodance_get_single_freqs(freqs, frame_num, fps): + total_frame = int(30.0 / (fps + 1e-6) * frame_num + 0.5) + interval_frame = 30.0 / (fps + 1e-6) + freqs_0 = freqs[:total_frame] + freqs_new = torch.zeros((frame_num, freqs_0.shape[1]), device=freqs_0.device, dtype=freqs_0.dtype) + freqs_new[0] = freqs_0[0] + freqs_new[-1] = freqs_0[total_frame - 1] + for i in range(1, frame_num-1): + pos = i * interval_frame + low_idx = int(pos) + high_idx = min(low_idx + 1, total_frame - 1) + weight_high = pos - low_idx + weight_low = 1.0 - weight_high + freqs_new[i] = freqs_0[low_idx] * weight_low + freqs_0[high_idx] * weight_high + return freqs_new + def model_fn_wan_video( dit: WanModel, @@ -1158,6 +1298,10 @@ def model_fn_wan_video( use_gradient_checkpointing_offload: bool = False, control_camera_latents_input = None, fuse_vae_embedding_in_latents: bool = False, + wantodance_refimage_feature = None, + wantodance_fps: float = 30.0, + music_feature = None, + skip_9th_layer: bool = False, **kwargs, ): if sliding_window_size is not None and sliding_window_stride is not None: @@ -1255,7 +1399,10 @@ def model_fn_wan_video( context = torch.cat([clip_embdding, context], dim=1) # Camera control - x = dit.patchify(x, control_camera_latents_input) + if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global and int(wantodance_fps + 0.5) != 30: + x = dit.patchify(x, control_camera_latents_input, enable_wantodance_global=True) + else: + x = dit.patchify(x, control_camera_latents_input) # Animate if pose_latents is not None and face_pixel_values is not None: @@ -1310,7 +1457,61 @@ def model_fn_wan_video( 
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload ) - + + # WanToDance + if hasattr(dit, "wantodance_enable_global") and dit.wantodance_enable_global: + if wantodance_refimage_feature is not None: + refimage_feature_embedding = dit.img_emb_refimage(wantodance_refimage_feature) + context = torch.cat([refimage_feature_embedding, context], dim=1) + if (dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel) and int(wantodance_fps + 0.5) != 30: + freqs_0 = wantodance_get_single_freqs(dit.freqs[0], f, wantodance_fps) + freqs = torch.cat([ + freqs_0.view(f, 1, 1, -1).expand(f, h, w, -1), + dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(f * h * w, 1, -1).to(x.device) + if dit.wantodance_enable_global or dit.wantodance_enable_dynamicfps or dit.wantodance_enable_unimodel: + if use_unified_sequence_parallel: + length = int(float(music_feature.shape[0]) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size() + music_feature = music_feature[:length] + music_feature = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] + if not dit.training: + dit.music_encoder.to(x.device, dtype=x.dtype) # only evaluation + music_feature = music_feature.to(x.device, dtype=x.dtype) + music_feature = dit.music_projection(music_feature) + music_feature = dit.music_encoder(music_feature) + if music_feature.dim() == 2: + music_feature = music_feature.unsqueeze(0) + if use_unified_sequence_parallel: + if dist.is_initialized() and dist.get_world_size() > 1: + music_feature = get_sp_group().all_gather(music_feature, dim=1) + music_feature = music_feature.unsqueeze(1) # [1, 1, 149, 4800] + N = 149 + M = 4800 + music_feature = torch.nn.functional.interpolate(music_feature, size=(N, M), mode='bilinear') + music_feature = music_feature.squeeze(1) # shape: [1, 149, 4800] + if music_feature is not None: + if music_feature.dim() == 2: + music_feature = music_feature.unsqueeze(0) + music_feature = music_feature.to(x.device, dtype=x.dtype) + interp_mode = 'bilinear' + if interp_mode == 'bilinear': + frame_num = latents.shape[2] if len(latents.shape) == 5 else latents.shape[1] # 21 + context_shape_end = context.shape[2] ## 14B 5120 + music_feature = music_feature.unsqueeze(1) # shape: [1, 1, 149, 4800] + if use_unified_sequence_parallel: + N = int(float(frame_num * 8) / get_sequence_parallel_world_size()) * get_sequence_parallel_world_size() + else: + N = frame_num * 8 + music_feature = torch.nn.functional.interpolate(music_feature, size=(N, context_shape_end), mode='bilinear') + music_feature = music_feature.squeeze(1) # shape: [1, N, context_shape_end] + if use_unified_sequence_parallel: + dit.merged_audio_emb = torch.chunk(music_feature, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] + else: + dit.merged_audio_emb = music_feature + else: + dit.merged_audio_emb = music_feature + # blocks if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: @@ -1326,8 +1527,12 @@ def model_fn_wan_video( return vap(block, *inputs) return custom_forward + # Block for block_id, block in enumerate(dit.blocks): - # Block + if skip_9th_layer: + # This is only used in WanToDance + if block_id == 9: + continue if vap is not None and block_id in vap.mot_layers_mapping: if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): @@ -1364,10 
+1569,18 @@ def model_fn_wan_video( # Animate if pose_latents is not None and face_pixel_values is not None: x = animate_adapter.after_transformer_block(block_id, x, motion_vec) + + # WanToDance + if hasattr(dit, "wantodance_enable_music_inject") and dit.wantodance_enable_music_inject: + x = dit.wantodance_after_transformer_block(block_id, x) if tea_cache is not None: tea_cache.store(x) - x = dit.head(x, t) + if hasattr(dit, "wantodance_enable_unimodel") and dit.wantodance_enable_unimodel and int(wantodance_fps + 0.5) != 30: + x = dit.head_global(x, t) + else: + x = dit.head(x, t) + if use_unified_sequence_parallel: if dist.is_initialized() and dist.get_world_size() > 1: x = get_sp_group().all_gather(x, dim=1) diff --git a/docs/en/Model_Details/Wan.md b/docs/en/Model_Details/Wan.md index 73e4d52..25c3133 100644 --- a/docs/en/Model_Details/Wan.md +++ b/docs/en/Model_Details/Wan.md @@ -139,6 +139,8 @@ graph LR; | [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) | | [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | | [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | +| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) | +| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py) | * FP8 Precision Training: [doc](../Training/FP8_Precision.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) * Two-stage Split Training: [doc](../Training/Split_Training.md), [code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) @@ -203,6 +205,50 @@ Input parameters for `WanVideoPipeline` inference include: If VRAM is insufficient, please enable [VRAM Management](../Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above. +### Multi-GPU Parallel Acceleration + +To enable multi-GPU parallel acceleration, please install `flash_attn` and `xfuser`: + +```shell +pip install flash-attn --no-build-isolation +pip install xfuser +``` + +Please modify your code as follows ([example code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)): + +```diff +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig ++ import torch.distributed as dist + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", ++ use_usp=True, + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) +video = pipe( + prompt="An astronaut in a spacesuit rides a mechanical horse across the Martian surface, facing the camera. The red, desolate terrain stretches into the distance, dotted with massive craters and unusual rock formations. 
The mechanical horse moves with steady strides, kicking up faint dust, embodying a perfect fusion of futuristic technology and primal exploration. The astronaut holds a control device, with a determined gaze, as if pioneering new frontiers for humanity. Against a backdrop of the deep cosmos and the blue Earth, the scene is both sci-fi and hopeful, evoking imagination about future interstellar life.", + negative_prompt="oversaturated colors, overexposed, static, blurry details, subtitles, style, artwork, painting, still image, overall gray tone, worst quality, low quality, JPEG compression artifacts, ugly, malformed, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, malformed limbs, fused fingers, frozen frame, cluttered background, three legs, crowd in background, walking backwards", + seed=0, tiled=True, +) ++ if dist.get_rank() == 0: ++ save_video(video, "video1.mp4", fps=15, quality=5) +``` + +When running multi-GPU parallel inference, please use `torchrun`, where `--nproc_per_node` specifies the number of GPUs: + +```shell +torchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py +``` + ## Model Training Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py), and the script parameters include: diff --git a/docs/zh/Model_Details/Wan.md b/docs/zh/Model_Details/Wan.md index 7924e40..f79ff2b 100644 --- a/docs/zh/Model_Details/Wan.md +++ b/docs/zh/Model_Details/Wan.md @@ -140,6 +140,8 @@ graph LR; |[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)| | [openmoss/MOVA-360p](https://modelscope.cn/models/openmoss/MOVA-360p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-360p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-360P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-360p-I2AV.py) | | [openmoss/MOVA-720p](https://modelscope.cn/models/openmoss/MOVA-720p) | `input_image` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_inference/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/full/MOVA-720P-I2AV.sh) | 
[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_full/MOVA-720p-I2AV.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/lora/MOVA-720P-I2AV.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/mova/model_training/validate_lora/MOVA-720p-I2AV.py) | +| [Wan-AI/WanToDance-14B (global model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py) | +| [Wan-AI/WanToDance-14B (local model)](https://modelscope.cn/models/Wan-AI/WanToDance-14B) | `wantodance_music_path`, `wantodance_reference_image`, `wantodance_fps`, `wantodance_keyframes`, `wantodance_keyframes_mask` | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_inference/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/full/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh) | [code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py) | * FP8 精度训练:[doc](../Training/FP8_Precision.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/fp8_training/) * 两阶段拆分训练:[doc](../Training/Split_Training.md)、[code](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/model_training/special/split_training/) @@ -204,6 +206,50 @@ DeepSpeed ZeRO 3 训练:Wan 系列模型支持 DeepSpeed ZeRO 3 训练,将 如果显存不足,请开启[显存管理](../Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。 +### 多卡并行加速 + +如需开启多卡并行加速,请先安装 `flash_attn` 与 `xfuser`: + +```shell +pip install flash-attn --no-build-isolation +pip install xfuser +``` + +对代码进行如下修改([样例代码](https://github.com/modelscope/DiffSynth-Studio/tree/main/examples/wanvideo/acceleration/unified_sequence_parallel.py)): + +```diff +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig ++ import torch.distributed as dist + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", ++ use_usp=True, + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/Wan2.1-T2V-14B", 
origin_file_pattern="Wan2.1_VAE.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) +video = pipe( + prompt="一名宇航员身穿太空服,面朝镜头骑着一匹机械马在火星表面驰骋。红色的荒凉地表延伸至远方,点缀着巨大的陨石坑和奇特的岩石结构。机械马的步伐稳健,扬起微弱的尘埃,展现出未来科技与原始探索的完美结合。宇航员手持操控装置,目光坚定,仿佛正在开辟人类的新疆域。背景是深邃的宇宙和蔚蓝的地球,画面既科幻又充满希望,让人不禁畅想未来的星际生活。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, +) ++ if dist.get_rank() == 0: ++ save_video(video, "video1.mp4", fps=15, quality=5) +``` + +运行多卡并行推理时,请使用 `torchrun` 运行,其中 `--nproc_per_node` 为 GPU 数量: + +```shell +torchrun --nproc_per_node=8 examples/wanvideo/acceleration/unified_sequence_parallel.py +``` + ## 模型训练 Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括: diff --git a/examples/wanvideo/model_inference/WanToDance-14B-global.py b/examples/wanvideo/model_inference/WanToDance-14B-global.py new file mode 100644 index 0000000..642e4b1 --- /dev/null +++ b/examples/wanvideo/model_inference/WanToDance-14B-global.py @@ -0,0 +1,48 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) +dataset_snapshot_download( + "DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="wanvideo/WanToDance-14B-global/*" +) +# This is a specialized model with the following constraints on its input parameters: +# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set. +# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1. +# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure. +# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds. +# * The width and height of `wantodance_reference_image` must be multiples of 16. +# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended. +# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black. +# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`. 
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4") +wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)] +video = pipe( + prompt="一个人正在跳舞,舞蹈种类是韩舞。帧率是7.5000", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + height=1280, width=720, num_frames=149, + num_inference_steps=48, + wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV", + wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"), + wantodance_fps=7.5, + wantodance_keyframes=wantodance_keyframes, + wantodance_keyframes_mask=[1] + [0] * 148, + framewise_decoding=True, +) +save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5) diff --git a/examples/wanvideo/model_inference/WanToDance-14B-local.py b/examples/wanvideo/model_inference/WanToDance-14B-local.py new file mode 100644 index 0000000..b94891c --- /dev/null +++ b/examples/wanvideo/model_inference/WanToDance-14B-local.py @@ -0,0 +1,52 @@ +import torch, os +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) +dataset_snapshot_download( + "DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="wanvideo/WanToDance-14B-local/*" +) +# This is a specialized model with the following constraints on its input parameters: +# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly. +# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly. +# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds. +# * The width and height of `wantodance_reference_image` must be multiples of 16. +# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended. +# * In `wantodance_keyframes`, frames that are not keyframes should be solid black. +# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`. 
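+# A minimal sketch of building the hand-written mask below programmatically; the keyframe
+# indices are read off that mask itself (1s at indices 0, 24, 47, 71, 94, 117, and 148):
+#     keyframe_indices = [0, 24, 47, 71, 94, 117, 148]
+#     wantodance_keyframes_mask = [1 if i in keyframe_indices else 0 for i in range(149)]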
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4") +wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)] +video = pipe( + prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=1280, width=720, num_frames=149, + num_inference_steps=24, + wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav", + wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"), + wantodance_fps=30, + wantodance_keyframes=wantodance_keyframes, + wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1], +) +save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5) diff --git a/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py b/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py new file mode 100644 index 0000000..e3a1255 --- /dev/null +++ b/examples/wanvideo/model_inference_low_vram/WanToDance-14B-global.py @@ -0,0 +1,59 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors", **vram_config), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2, +) +dataset_snapshot_download( + "DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="wanvideo/WanToDance-14B-global/*" +) +# This is a specialized model with the following constraints on its input parameters: +# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set. +# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1. +# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure. 
+# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
+# * The width and height of `wantodance_reference_image` must be multiples of 16.
+# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
+# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
+# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
+wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
+video = pipe(
+    prompt="一个人正在跳舞,舞蹈种类是韩舞。帧率是7.5000",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=0, tiled=False,
+    height=1280, width=720, num_frames=149,
+    num_inference_steps=48,
+    wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
+    wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
+    wantodance_fps=7.5,
+    wantodance_keyframes=wantodance_keyframes,
+    wantodance_keyframes_mask=[1] + [0] * 148,
+    framewise_decoding=True,
+)
+save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)
diff --git a/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py b/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py
new file mode 100644
index 0000000..4ac2b49
--- /dev/null
+++ b/examples/wanvideo/model_inference_low_vram/WanToDance-14B-local.py
@@ -0,0 +1,63 @@
+import torch, os
+from PIL import Image
+from diffsynth.utils.data import save_video, VideoData
+from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+vram_config = {
+    "offload_dtype": "disk",
+    "offload_device": "disk",
+    "onload_dtype": torch.bfloat16,
+    "onload_device": "cpu",
+    "preparing_dtype": torch.bfloat16,
+    "preparing_device": "cuda",
+    "computation_dtype": torch.bfloat16,
+    "computation_device": "cuda",
+}
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors", **vram_config),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", **vram_config),
+    ],
+    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+    vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
+)
+dataset_snapshot_download(
+    "DiffSynth-Studio/diffsynth_example_dataset",
+    local_dir="data/diffsynth_example_dataset",
+    allow_file_pattern="wanvideo/WanToDance-14B-local/*"
+)
+# This is a specialized model with the following constraints on its input parameters:
+# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
+# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
+# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
+# * The width and height of `wantodance_reference_image` must be multiples of 16.
+# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
+# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
+# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
+wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
+video = pipe(
+    prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=0, tiled=True,
+    height=1280, width=720, num_frames=149,
+    num_inference_steps=24,
+    wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
+    wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
+    wantodance_fps=30,
+    wantodance_keyframes=wantodance_keyframes,
+    wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1],
+)
+save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)
diff --git a/examples/wanvideo/model_training/full/WanToDance-14B-global.sh b/examples/wanvideo/model_training/full/WanToDance-14B-global.sh
new file mode 100644
index 0000000..b5b88b6
--- /dev/null
+++ b/examples/wanvideo/model_training/full/WanToDance-14B-global.sh
@@ -0,0 +1,20 @@
+# 8*H200 required
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-global/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global \
+  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/metadata.json \
+  --data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
+  --height 1280 \
+  --width 720 \
+  --num_frames 149 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/WanToDance-14B:global_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/WanToDance-14B-global_full" \
+  --trainable_models "dit" \
+  --extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask,framewise_decoding" \
+  --use_gradient_checkpointing_offload \
+  --framewise_decoding
diff --git a/examples/wanvideo/model_training/full/WanToDance-14B-local.sh b/examples/wanvideo/model_training/full/WanToDance-14B-local.sh
new file mode 100644
index 0000000..8789179
--- /dev/null
+++ b/examples/wanvideo/model_training/full/WanToDance-14B-local.sh
@@ -0,0 +1,19 @@
+# 8*H200 required
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-local/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local \
+  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/metadata.json \
+  --data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
+  --height 1280 \
+  --width 720 \
+  --num_frames 149 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/WanToDance-14B:local_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/WanToDance-14B-local_full" \
+  --trainable_models "dit" \
+  --extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask" \
+  --use_gradient_checkpointing_offload
diff --git a/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh b/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh
new file mode 100644
index 0000000..6c22b2e
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/WanToDance-14B-global.sh
@@ -0,0 +1,22 @@
+# 8*H200 required
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-global/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global \
+  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/metadata.json \
+  --data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
+  --height 1280 \
+  --width 720 \
+  --num_frames 149 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/WanToDance-14B:global_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/WanToDance-14B-global_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+  --lora_rank 32 \
+  --extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask,framewise_decoding" \
+  --use_gradient_checkpointing_offload \
+  --framewise_decoding
diff --git a/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh b/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh
new file mode 100644
index 0000000..2379630
--- /dev/null
+++ b/examples/wanvideo/model_training/lora/WanToDance-14B-local.sh
@@ -0,0 +1,21 @@
+# 8*H200 required
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "wanvideo/WanToDance-14B-local/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
+  --dataset_base_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local \
+  --dataset_metadata_path data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/metadata.json \
+  --data_file_keys "video,wantodance_reference_image,wantodance_keyframes,wantodance_music_path" \
+  --height 1280 \
+  --width 720 \
+  --num_frames 149 \
+  --dataset_repeat 100 \
+  --model_id_with_origin_paths "Wan-AI/WanToDance-14B:local_model.safetensors,Wan-AI/WanToDance-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/WanToDance-14B:Wan2.1_VAE.pth,Wan-AI/WanToDance-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/WanToDance-14B-local_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
+  --lora_rank 32 \
+  --extra_inputs "wantodance_music_path,wantodance_reference_image,wantodance_fps,wantodance_keyframes,wantodance_keyframes_mask" \
+  --use_gradient_checkpointing_offload
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index d4074a6..a3971ef 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -72,6 +72,9 @@ class WanTrainingModule(DiffusionTrainingModule):
                 inputs_shared[extra_input] = data[extra_input][0]
             else:
                 inputs_shared[extra_input] = data[extra_input]
+        if inputs_shared.get("framewise_decoding", False):
+            # WanToDance global model
+            inputs_shared["num_frames"] = 4 * (len(data["video"]) - 1) + 1
         return inputs_shared
 
     def get_pipeline_inputs(self, data):
@@ -117,6 +120,7 @@ def wan_parser():
     parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
     parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
     parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
+    parser.add_argument("--framewise_decoding", default=False, action="store_true", help="Enable this if the model is a WanToDance global model.")
     return parser
@@ -140,12 +144,13 @@ if __name__ == "__main__":
         height_division_factor=16,
         width_division_factor=16,
         num_frames=args.num_frames,
-        time_division_factor=4,
-        time_division_remainder=1,
+        time_division_factor=4 if not args.framewise_decoding else 1,
+        time_division_remainder=1 if not args.framewise_decoding else 0,
     ),
     special_operator_map={
         "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
         "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
+        "wantodance_music_path": ToAbsolutePath(args.dataset_base_path),
     }
 )
 model = WanTrainingModule(
diff --git a/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py b/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py
new file mode 100644
index 0000000..ad22696
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_full/WanToDance-14B-global.py
@@ -0,0 +1,51 @@
+import torch
+from PIL import Image
+from diffsynth.utils.data import save_video, VideoData
+from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+from diffsynth.core import load_state_dict
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+)
+state_dict = load_state_dict("models/train/WanToDance-14B-global_full/epoch-1.safetensors")
+pipe.dit.load_state_dict(state_dict)
+dataset_snapshot_download(
+    "DiffSynth-Studio/diffsynth_example_dataset",
+    local_dir="data/diffsynth_example_dataset",
+    allow_file_pattern="wanvideo/WanToDance-14B-global/*"
+)
+# This is a specialized model with the following constraints on its input parameters:
+# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set.
+# * When the number of keyframes is n, `num_frames` = 4 * (n - 1) + 1.
+# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure.
+# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds.
+# * The width and height of `wantodance_reference_image` must be multiples of 16.
+# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
+# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
+# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
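+# For reference, the call below uses num_frames = 149, so the model decodes
+# n keyframes where 4 * (n - 1) + 1 = 149, i.e. n = 38, and the music file
+# should last num_frames / 7.5 = 149 / 7.5 ≈ 19.9 seconds.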
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4") +wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)] +video = pipe( + prompt="一个人正在跳舞,舞蹈种类是韩舞。帧率是7.5000", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=False, + height=1280, width=720, num_frames=149, + num_inference_steps=48, + wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV", + wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"), + wantodance_fps=7.5, + wantodance_keyframes=wantodance_keyframes, + wantodance_keyframes_mask=[1] + [0] * 148, + framewise_decoding=True, +) +save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5) diff --git a/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py b/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py new file mode 100644 index 0000000..d409ff7 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/WanToDance-14B-local.py @@ -0,0 +1,55 @@ +import torch, os +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download +from diffsynth.core import load_state_dict + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) +state_dict = load_state_dict("models/train/WanToDance-14B-local_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +dataset_snapshot_download( + "DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="wanvideo/WanToDance-14B-local/*" +) +# This is a specialized model with the following constraints on its input parameters: +# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly. +# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly. +# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds. +# * The width and height of `wantodance_reference_image` must be multiples of 16. +# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended. +# * In `wantodance_keyframes`, frames that are not keyframes should be solid black. +# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`. 
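+# For reference, num_frames = 149 below implies a music duration of
+# 149 / 30 ≈ 4.97 seconds. A rough sketch of segmented generation for longer
+# clips (segment length and bookkeeping here are illustrative assumptions,
+# not a documented API):
+#     seg = 149
+#     for s in range(num_segments):
+#         ks = all_keyframes[s * seg:(s + 1) * seg]
+#         mask = all_mask[s * seg:(s + 1) * seg]
+#         # cut the music to the matching seg / 30 seconds for this segment,
+#         # then call pipe(...) with ks, mask, and the trimmed audio file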
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4") +wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)] +video = pipe( + prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + seed=0, tiled=True, + height=1280, width=720, num_frames=149, + num_inference_steps=24, + wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav", + wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"), + wantodance_fps=30, + wantodance_keyframes=wantodance_keyframes, + wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1], +) +save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5) diff --git a/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py b/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py new file mode 100644 index 0000000..681c5e1 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/WanToDance-14B-global.py @@ -0,0 +1,49 @@ +import torch +from PIL import Image +from diffsynth.utils.data import save_video, VideoData +from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="global_model.safetensors"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"), + ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), + ], + tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"), +) +pipe.load_lora(pipe.dit, "models/train/WanToDance-14B-global_lora/epoch-4.safetensors", alpha=1) +dataset_snapshot_download( + "DiffSynth-Studio/diffsynth_example_dataset", + local_dir="data/diffsynth_example_dataset", + allow_file_pattern="wanvideo/WanToDance-14B-global/*" +) +# This is a specialized model with the following constraints on its input parameters: +# * The model outputs a sequence of keyframes rather than a video; therefore, `framewise_decoding=True` must be set. +# * When the number of keyframes is $n$, `num_frames` = 4 * (n - 1) + 1. +# * Reducing `height`, `width`, `num_frames`, or `num_inference_steps` may lead to severe artifacts or generation failure. +# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 7.5) seconds. +# * The width and height of `wantodance_reference_image` must be multiples of 16. 
+# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 7.5 FPS, setting it to other values is not recommended.
+# * The first frame of `wantodance_keyframes` is the `wantodance_reference_image`, while all subsequent frames are solid black.
+# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/keyframes.mp4")
+wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
+video = pipe(
+    prompt="一个人正在跳舞,舞蹈种类是韩舞。帧率是7.5000",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=0, tiled=False,
+    height=1280, width=720, num_frames=149,
+    num_inference_steps=48,
+    wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/music.WAV",
+    wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-global/refimage.jpg"),
+    wantodance_fps=7.5,
+    wantodance_keyframes=wantodance_keyframes,
+    wantodance_keyframes_mask=[1] + [0] * 148,
+    framewise_decoding=True,
+)
+save_video(video, "video_WanToDance-14B-global.mp4", fps=7.5, quality=5)
diff --git a/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py b/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py
new file mode 100644
index 0000000..78c6aea
--- /dev/null
+++ b/examples/wanvideo/model_training/validate_lora/WanToDance-14B-local.py
@@ -0,0 +1,53 @@
+import torch, os
+from PIL import Image
+from diffsynth.utils.data import save_video, VideoData
+from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
+from modelscope import dataset_snapshot_download
+
+
+pipe = WanVideoPipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="local_model.safetensors"),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="Wan2.1_VAE.pth"),
+        ModelConfig(model_id="Wan-AI/WanToDance-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
+)
+pipe.load_lora(pipe.dit, "models/train/WanToDance-14B-local_lora/epoch-4.safetensors", alpha=1)
+dataset_snapshot_download(
+    "DiffSynth-Studio/diffsynth_example_dataset",
+    local_dir="data/diffsynth_example_dataset",
+    allow_file_pattern="wanvideo/WanToDance-14B-local/*"
+)
+# This is a specialized model with the following constraints on its input parameters:
+# * The model renders and outputs video based on a sequence of keyframes; therefore, `wantodance_keyframes` must be provided correctly.
+# * If you need to generate a long video, please generate it in segments, and ensure that `wantodance_music_path`, `wantodance_keyframes`, and `wantodance_keyframes_mask` are properly split accordingly.
+# * The audio file specified by `wantodance_music_path` must match the video duration, calculated as (`num_frames` / 30) seconds.
+# * The width and height of `wantodance_reference_image` must be multiples of 16.
+# * `wantodance_fps` is configurable, but since the model appears to have been trained exclusively at 30 FPS, setting it to other values is not recommended.
+# * In `wantodance_keyframes`, frames that are not keyframes should be solid black.
+# * `wantodance_keyframes_mask` indicates the positions of valid frames within `wantodance_keyframes`.
+wantodance_keyframes = VideoData("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/keyframes.mp4")
+wantodance_keyframes = [wantodance_keyframes[i] for i in range(149)]
+video = pipe(
+    prompt="一个人正在跳舞,舞蹈种类是古典舞,图像清晰程度高,人物动作平均幅度中等,人物动作最大幅度中等。, 帧率是30fps。",
+    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+    seed=0, tiled=True,
+    height=1280, width=720, num_frames=149,
+    num_inference_steps=24,
+    wantodance_music_path="data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/music.wav",
+    wantodance_reference_image=Image.open("data/diffsynth_example_dataset/wanvideo/WanToDance-14B-local/refimage.jpg"),
+    wantodance_fps=30,
+    wantodance_keyframes=wantodance_keyframes,
+    wantodance_keyframes_mask=[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+                               1],
+)
+save_video(video, "video_WanToDance-14B-local.mp4", fps=30, quality=5)
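+# The mask above marks keyframes at indices 0, 24, 47, 71, 94, 117, and 148
+# out of 149 frames; an equivalent construction (illustrative only) is:
+#     wantodance_keyframes_mask = [0] * 149
+#     for i in (0, 24, 47, 71, 94, 117, 148):
+#         wantodance_keyframes_mask[i] = 1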