diff --git a/README.md b/README.md index a4109f6..7093950 100644 --- a/README.md +++ b/README.md @@ -207,6 +207,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training | |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| diff --git a/README_zh.md b/README_zh.md index dfc52fe..a9bd0ca 100644 --- a/README_zh.md +++ b/README_zh.md @@ -207,6 +207,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)| |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)| diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 43fe84b..31a7d66 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -63,6 +63,7 @@ from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38 from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_vace import VaceWanModel from ..models.wav2vec import WanS2VAudioEncoder +from ..models.wan_video_animate_adapter import WanAnimateAdapter from ..models.step1x_connector import Qwen2Connector @@ -142,7 +143,6 @@ model_loader_configs = [ (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"), (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"), - (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"), @@ -176,6 +176,7 @@ model_loader_configs = [ (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"), (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"), + (None, "31fa352acb8a1b1d33cd8764273d80a2", ["wan_video_dit", "wan_video_animate_adapter"], [WanModel, WanAnimateAdapter], "civitai"), ] huggingface_model_loader_configs = [ # These configs are provided for detecting model type automatically. diff --git a/diffsynth/models/wan_video_animate_adapter.py b/diffsynth/models/wan_video_animate_adapter.py new file mode 100644 index 0000000..771280a --- /dev/null +++ b/diffsynth/models/wan_video_animate_adapter.py @@ -0,0 +1,670 @@ +import torch +import torch.nn as nn +from torch.nn import functional as F +import math +from typing import Tuple, Optional, List +from einops import rearrange + + + +MEMORY_LAYOUT = { + "flash": ( + lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]), + lambda x: x, + ), + "torch": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), + "vanilla": ( + lambda x: x.transpose(1, 2), + lambda x: x.transpose(1, 2), + ), +} + + +def attention( + q, + k, + v, + mode="torch", + drop_rate=0, + attn_mask=None, + causal=False, + max_seqlen_q=None, + batch_size=1, +): + pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode] + + if mode == "torch": + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(q.dtype) + x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal) + + x = post_attn_layout(x) + b, s, a, d = x.shape + out = x.reshape(b, s, -1) + return out + + +class CausalConv1d(nn.Module): + + def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs): + super().__init__() + + self.pad_mode = pad_mode + padding = (kernel_size - 1, 0) # T + self.time_causal_padding = padding + + self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + x = F.pad(x, self.time_causal_padding, mode=self.pad_mode) + return self.conv(x) + + + +class FaceEncoder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int, num_heads=int, dtype=None, device=None): + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + + self.num_heads = num_heads + self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1) + self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs) + self.act = nn.SiLU() + self.conv2 = CausalConv1d(1024, 1024, 3, stride=2) + self.conv3 = CausalConv1d(1024, 1024, 3, stride=2) + + self.out_proj = nn.Linear(1024, hidden_dim) + self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim)) + + def forward(self, x): + + x = rearrange(x, "b t c -> b c t") + b, c, t = x.shape + + x = self.conv1_local(x) + x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads) + + x = self.norm1(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv2(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm2(x) + x = self.act(x) + x = rearrange(x, "b t c -> b c t") + x = self.conv3(x) + x = rearrange(x, "b c t -> b t c") + x = self.norm3(x) + x = self.act(x) + x = self.out_proj(x) + x = rearrange(x, "(b n) t c -> b t n c", b=b) + padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1) + x = torch.cat([x, padding], dim=-2) + x_local = x.clone() + + return x_local + + + +class RMSNorm(nn.Module): + def __init__( + self, + dim: int, + elementwise_affine=True, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +def get_norm_layer(norm_layer): + """ + Get the normalization layer. + + Args: + norm_layer (str): The type of normalization layer. + + Returns: + norm_layer (nn.Module): The normalization layer. + """ + if norm_layer == "layer": + return nn.LayerNorm + elif norm_layer == "rms": + return RMSNorm + else: + raise NotImplementedError(f"Norm layer {norm_layer} is not implemented") + + +class FaceAdapter(nn.Module): + def __init__( + self, + hidden_dim: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + num_adapter_layers: int = 1, + dtype=None, + device=None, + ): + + factory_kwargs = {"dtype": dtype, "device": device} + super().__init__() + self.hidden_size = hidden_dim + self.heads_num = heads_num + self.fuser_blocks = nn.ModuleList( + [ + FaceBlock( + self.hidden_size, + self.heads_num, + qk_norm=qk_norm, + qk_norm_type=qk_norm_type, + **factory_kwargs, + ) + for _ in range(num_adapter_layers) + ] + ) + + def forward( + self, + x: torch.Tensor, + motion_embed: torch.Tensor, + idx: int, + freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None, + freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None, + ) -> torch.Tensor: + + return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k) + + + +class FaceBlock(nn.Module): + def __init__( + self, + hidden_size: int, + heads_num: int, + qk_norm: bool = True, + qk_norm_type: str = "rms", + qk_scale: float = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + + self.deterministic = False + self.hidden_size = hidden_size + self.heads_num = heads_num + head_dim = hidden_size // heads_num + self.scale = qk_scale or head_dim**-0.5 + + self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs) + self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs) + + qk_norm_layer = get_norm_layer(qk_norm_type) + self.q_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + self.k_norm = ( + qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity() + ) + + self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) + + def forward( + self, + x: torch.Tensor, + motion_vec: torch.Tensor, + motion_mask: Optional[torch.Tensor] = None, + use_context_parallel=False, + ) -> torch.Tensor: + + B, T, N, C = motion_vec.shape + T_comp = T + + x_motion = self.pre_norm_motion(motion_vec) + x_feat = self.pre_norm_feat(x) + + kv = self.linear1_kv(x_motion) + q = self.linear1_q(x_feat) + + k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) + q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) + + # Apply QK-Norm if needed. + q = self.q_norm(q).to(v) + k = self.k_norm(k).to(v) + + k = rearrange(k, "B L N H D -> (B L) H N D") + v = rearrange(v, "B L N H D -> (B L) H N D") + + q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp) + # Compute attention. + attn = F.scaled_dot_product_attention(q, k, v) + + attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp) + + output = self.linear2(attn) + + if motion_mask is not None: + output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) + + return output + + + +def custom_qr(input_tensor): + original_dtype = input_tensor.dtype + if original_dtype == torch.bfloat16: + q, r = torch.linalg.qr(input_tensor.to(torch.float32)) + return q.to(original_dtype), r.to(original_dtype) + return torch.linalg.qr(input_tensor) + +def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): + return F.leaky_relu(input + bias, negative_slope) * scale + + +def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): + _, minor, in_h, in_w = input.shape + kernel_h, kernel_w = kernel.shape + + out = input.view(-1, minor, in_h, 1, in_w, 1) + out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0]) + out = out.view(-1, minor, in_h * up_y, in_w * up_x) + + out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) + out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0), + max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ] + + out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) + w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) + out = F.conv2d(out, w) + out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, + in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, ) + return out[:, :, ::down_y, ::down_x] + + +def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): + return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]) + + +def make_kernel(k): + k = torch.tensor(k, dtype=torch.float32) + if k.ndim == 1: + k = k[None, :] * k[:, None] + k /= k.sum() + return k + + +class FusedLeakyReLU(nn.Module): + def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1)) + self.negative_slope = negative_slope + self.scale = scale + + def forward(self, input): + out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) + return out + + +class Blur(nn.Module): + def __init__(self, kernel, pad, upsample_factor=1): + super().__init__() + + kernel = make_kernel(kernel) + + if upsample_factor > 1: + kernel = kernel * (upsample_factor ** 2) + + self.register_buffer('kernel', kernel) + + self.pad = pad + + def forward(self, input): + return upfirdn2d(input, self.kernel, pad=self.pad) + + +class ScaledLeakyReLU(nn.Module): + def __init__(self, negative_slope=0.2): + super().__init__() + + self.negative_slope = negative_slope + + def forward(self, input): + return F.leaky_relu(input, negative_slope=self.negative_slope) + + +class EqualConv2d(nn.Module): + def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size)) + self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) + + self.stride = stride + self.padding = padding + + if bias: + self.bias = nn.Parameter(torch.zeros(out_channel)) + else: + self.bias = None + + def forward(self, input): + + return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding) + + def __repr__(self): + return ( + f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' + f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' + ) + + +class EqualLinear(nn.Module): + def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None): + super().__init__() + + self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) + + if bias: + self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) + else: + self.bias = None + + self.activation = activation + + self.scale = (1 / math.sqrt(in_dim)) * lr_mul + self.lr_mul = lr_mul + + def forward(self, input): + + if self.activation: + out = F.linear(input, self.weight * self.scale) + out = fused_leaky_relu(out, self.bias * self.lr_mul) + else: + out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul) + + return out + + def __repr__(self): + return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})') + + +class ConvLayer(nn.Sequential): + def __init__( + self, + in_channel, + out_channel, + kernel_size, + downsample=False, + blur_kernel=[1, 3, 3, 1], + bias=True, + activate=True, + ): + layers = [] + + if downsample: + factor = 2 + p = (len(blur_kernel) - factor) + (kernel_size - 1) + pad0 = (p + 1) // 2 + pad1 = p // 2 + + layers.append(Blur(blur_kernel, pad=(pad0, pad1))) + + stride = 2 + self.padding = 0 + + else: + stride = 1 + self.padding = kernel_size // 2 + + layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride, + bias=bias and not activate)) + + if activate: + if bias: + layers.append(FusedLeakyReLU(out_channel)) + else: + layers.append(ScaledLeakyReLU(0.2)) + + super().__init__(*layers) + + +class ResBlock(nn.Module): + def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): + super().__init__() + + self.conv1 = ConvLayer(in_channel, in_channel, 3) + self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) + + self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False) + + def forward(self, input): + out = self.conv1(input) + out = self.conv2(out) + + skip = self.skip(input) + out = (out + skip) / math.sqrt(2) + + return out + + +class EncoderApp(nn.Module): + def __init__(self, size, w_dim=512): + super(EncoderApp, self).__init__() + + channels = { + 4: 512, + 8: 512, + 16: 512, + 32: 512, + 64: 256, + 128: 128, + 256: 64, + 512: 32, + 1024: 16 + } + + self.w_dim = w_dim + log_size = int(math.log(size, 2)) + + self.convs = nn.ModuleList() + self.convs.append(ConvLayer(3, channels[size], 1)) + + in_channel = channels[size] + for i in range(log_size, 2, -1): + out_channel = channels[2 ** (i - 1)] + self.convs.append(ResBlock(in_channel, out_channel)) + in_channel = out_channel + + self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False)) + + def forward(self, x): + + res = [] + h = x + for conv in self.convs: + h = conv(h) + res.append(h) + + return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:] + + +class Encoder(nn.Module): + def __init__(self, size, dim=512, dim_motion=20): + super(Encoder, self).__init__() + + # appearance netmork + self.net_app = EncoderApp(size, dim) + + # motion network + fc = [EqualLinear(dim, dim)] + for i in range(3): + fc.append(EqualLinear(dim, dim)) + + fc.append(EqualLinear(dim, dim_motion)) + self.fc = nn.Sequential(*fc) + + def enc_app(self, x): + h_source = self.net_app(x) + return h_source + + def enc_motion(self, x): + h, _ = self.net_app(x) + h_motion = self.fc(h) + return h_motion + + +class Direction(nn.Module): + def __init__(self, motion_dim): + super(Direction, self).__init__() + self.weight = nn.Parameter(torch.randn(512, motion_dim)) + + def forward(self, input): + + weight = self.weight + 1e-8 + Q, R = custom_qr(weight) + if input is None: + return Q + else: + input_diag = torch.diag_embed(input) # alpha, diagonal matrix + out = torch.matmul(input_diag, Q.T) + out = torch.sum(out, dim=1) + return out + + +class Synthesis(nn.Module): + def __init__(self, motion_dim): + super(Synthesis, self).__init__() + self.direction = Direction(motion_dim) + + +class Generator(nn.Module): + def __init__(self, size, style_dim=512, motion_dim=20): + super().__init__() + + self.enc = Encoder(size, style_dim, motion_dim) + self.dec = Synthesis(motion_dim) + + def get_motion(self, img): + #motion_feat = self.enc.enc_motion(img) + motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True) + motion = self.dec.direction(motion_feat) + return motion + + +class WanAnimateAdapter(torch.nn.Module): + def __init__(self): + super().__init__() + self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2)) + self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20) + self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5) + self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4) + + def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values): + pose_latents = self.pose_patch_embedding(pose_latents) + x[:, :, 1:] += pose_latents + + b,c,T,h,w = face_pixel_values.shape + face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w") + + encode_bs = 8 + face_pixel_values_tmp = [] + for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)): + face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs])) + + motion_vec = torch.cat(face_pixel_values_tmp) + + motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T) + motion_vec = self.face_encoder(motion_vec) + + B, L, H, C = motion_vec.shape + pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec) + motion_vec = torch.cat([pad_face, motion_vec], dim=1) + return x, motion_vec + + def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None): + if block_idx % 5 == 0: + adapter_args = [x, motion_vec, motion_masks, False] + residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args) + x = residual_out + x + return x + + @staticmethod + def state_dict_converter(): + return WanAnimateAdapterStateDictConverter() + + +class WanAnimateAdapterStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): + state_dict_ = {} + for name, param in state_dict.items(): + if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"): + state_dict_[name] = param + return state_dict_ + diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index 1a54728..1c8543e 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -342,9 +342,7 @@ class WanModel(torch.nn.Module): y_camera = self.control_adapter(control_camera_latents_input) x = [u + v for u, v in zip(x, y_camera)] x = x[0].unsqueeze(0) - grid_size = x.shape[2:] - x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() - return x, grid_size # x, grid_size: (f, h, w) + return x def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor): return rearrange( @@ -496,6 +494,7 @@ class WanModelStateDictConverter: def from_civitai(self, state_dict): state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")} + state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]} if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": config = { "has_image_input": False, @@ -552,20 +551,6 @@ class WanModelStateDictConverter: "num_layers": 30, "eps": 1e-6 } - elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e": - config = { - "has_image_input": True, - "patch_size": [1, 2, 2], - "in_dim": 36, - "dim": 5120, - "ffn_dim": 13824, - "freq_dim": 256, - "text_dim": 4096, - "out_dim": 16, - "num_heads": 40, - "num_layers": 40, - "eps": 1e-6 - } elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677": # 1.3B PAI control config = { diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 8d447e8..4895247 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -21,6 +21,7 @@ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_vace import VaceWanModel from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_animate_adapter import WanAnimateAdapter from ..schedulers.flow_match import FlowMatchScheduler from ..prompters import WanPrompter from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm @@ -45,8 +46,9 @@ class WanVideoPipeline(BasePipeline): self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None self.vace2: VaceWanModel = None - self.in_iteration_models = ("dit", "motion_controller", "vace") - self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2") + self.animate_adapter: WanAnimateAdapter = None + self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), @@ -62,6 +64,10 @@ class WanVideoPipeline(BasePipeline): WanVideoUnit_FunCameraControl(), WanVideoUnit_SpeedControl(), WanVideoUnit_VACE(), + WanVideoPostUnit_AnimateVideoSplit(), + WanVideoPostUnit_AnimatePoseLatents(), + WanVideoPostUnit_AnimateFacePixelValues(), + WanVideoPostUnit_AnimateInpaint(), WanVideoUnit_UnifiedSequenceParallel(), WanVideoUnit_TeaCache(), WanVideoUnit_CfgMerger(), @@ -70,13 +76,34 @@ class WanVideoPipeline(BasePipeline): WanVideoPostUnit_S2V(), ] self.model_fn = model_fn_wan_video - - def load_lora(self, module, path, alpha=1): - loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) - lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) - loader.load(module, lora, alpha=alpha) - + def load_lora( + self, + module: torch.nn.Module, + lora_config: Union[ModelConfig, str] = None, + alpha=1, + hotload=False, + state_dict=None, + ): + if state_dict is None: + if isinstance(lora_config, str): + lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device) + else: + lora_config.download_if_necessary() + lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device) + else: + lora = state_dict + if hotload: + for name, module in module.named_modules(): + if isinstance(module, AutoWrappedLinear): + lora_a_name = f'{name}.lora_A.default.weight' + lora_b_name = f'{name}.lora_B.default.weight' + if lora_a_name in lora and lora_b_name in lora: + module.lora_A_weights.append(lora[lora_a_name] * alpha) + module.lora_B_weights.append(lora[lora_b_name]) + else: + loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device) + loader.load(module, lora, alpha=alpha) def training_loss(self, **inputs): max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps) @@ -359,12 +386,13 @@ class WanVideoPipeline(BasePipeline): pipe.vae = model_manager.fetch_model("wan_video_vae") pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") - pipe.vace = model_manager.fetch_model("wan_video_vace") + vace = model_manager.fetch_model("wan_video_vace", index=2) if isinstance(vace, list): pipe.vace, pipe.vace2 = vace else: pipe.vace = vace pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder") + pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter") # Size division factor if pipe.vae is not None: @@ -417,6 +445,11 @@ class WanVideoPipeline(BasePipeline): vace_video_mask: Optional[Image.Image] = None, vace_reference_image: Optional[Image.Image] = None, vace_scale: Optional[float] = 1.0, + # Animate + animate_pose_video: Optional[list[Image.Image]] = None, + animate_face_video: Optional[list[Image.Image]] = None, + animate_inpaint_video: Optional[list[Image.Image]] = None, + animate_mask_video: Optional[list[Image.Image]] = None, # Randomness seed: Optional[int] = None, rand_device: Optional[str] = "cpu", @@ -474,6 +507,7 @@ class WanVideoPipeline(BasePipeline): "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride, "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video, + "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) @@ -508,7 +542,7 @@ class WanVideoPipeline(BasePipeline): inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"] # VACE (TODO: remove it) - if vace_reference_image is not None: + if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None): inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:] # post-denoising, pre-decoding processing logic for unit in self.post_units: @@ -1021,6 +1055,95 @@ class WanVideoPostUnit_S2V(PipelineUnit): return {"latents": latents} +class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit): + def __init__(self): + super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video")) + + def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): + if input_video is None: + return {} + if animate_pose_video is not None: + animate_pose_video = animate_pose_video[:len(input_video) - 4] + if animate_face_video is not None: + animate_face_video = animate_face_video[:len(input_video) - 4] + if animate_inpaint_video is not None: + animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] + if animate_mask_video is not None: + animate_mask_video = animate_mask_video[:len(input_video) - 4] + return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video} + + +class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride): + if animate_pose_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + animate_pose_video = pipe.preprocess_video(animate_pose_video) + pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + return {"pose_latents": pose_latents} + + +class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit): + def __init__(self): + super().__init__(take_over=True) + + def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega): + if inputs_shared.get("animate_face_video", None) is None: + return {} + inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"]) + inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1 + return inputs_shared, inputs_posi, inputs_nega + + +class WanVideoPostUnit_AnimateInpaint(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"), + onload_model_names=("vae",) + ) + + def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"): + if mask_pixel_values is None: + msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device) + else: + msk = mask_pixel_values.clone() + msk[:, :mask_len] = 1 + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) + msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w) + msk = msk.transpose(1, 2)[0] + return msk + + def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride): + if animate_inpaint_video is None or animate_mask_video is None: + return {} + pipe.load_models_to_device(self.onload_model_names) + + bg_pixel_values = pipe.preprocess_video(animate_inpaint_video) + y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device) + _, lat_t, lat_h, lat_w = y_reft.shape + + ref_pixel_values = pipe.preprocess_video([input_image]) + ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) + mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device) + y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device) + + mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0) + mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w") + mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest') + mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:,:,0] + msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device) + + y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device) + y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0) + return {"y": y} + + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh, model_id): self.num_inference_steps = num_inference_steps @@ -1131,6 +1254,7 @@ def model_fn_wan_video( dit: WanModel, motion_controller: WanMotionControllerModel = None, vace: VaceWanModel = None, + animate_adapter: WanAnimateAdapter = None, latents: torch.Tensor = None, timestep: torch.Tensor = None, context: torch.Tensor = None, @@ -1146,6 +1270,8 @@ def model_fn_wan_video( tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, motion_bucket_id: Optional[torch.Tensor] = None, + pose_latents=None, + face_pixel_values=None, sliding_window_size: Optional[int] = None, sliding_window_stride: Optional[int] = None, cfg_merge: bool = False, @@ -1236,9 +1362,16 @@ def model_fn_wan_video( if clip_feature is not None and dit.require_clip_embedding: clip_embdding = dit.img_emb(clip_feature) context = torch.cat([clip_embdding, context], dim=1) - - # Add camera control - x, (f, h, w) = dit.patchify(x, control_camera_latents_input) + + # Camera control + x = dit.patchify(x, control_camera_latents_input) + + # Animate + x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values) + + # Patchify + f, h, w = x.shape[2:] + x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous() # Reference image if reference_latents is not None: @@ -1283,6 +1416,7 @@ def model_fn_wan_video( return custom_forward for block_id, block in enumerate(dit.blocks): + # Block if use_gradient_checkpointing_offload: with torch.autograd.graph.save_on_cpu(): x = torch.utils.checkpoint.checkpoint( @@ -1298,12 +1432,18 @@ def model_fn_wan_video( ) else: x = block(x, context, t_mod, freqs) + + # VACE if vace_context is not None and block_id in vace.vace_layers_mapping: current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]] if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1: current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()] current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0) x = x + current_vace_hint * vace_scale + + # Animate + if pose_latents is not None and face_pixel_values is not None: + x = animate_adapter.after_transformer_block(block_id, x, motion_vec) if tea_cache is not None: tea_cache.store(x) diff --git a/diffsynth/trainers/unified_dataset.py b/diffsynth/trainers/unified_dataset.py index 4e94ab8..0083f44 100644 --- a/diffsynth/trainers/unified_dataset.py +++ b/diffsynth/trainers/unified_dataset.py @@ -316,7 +316,7 @@ class UnifiedDataset(torch.utils.data.Dataset): for key in self.data_file_keys: if key in data: if key in self.special_operator_map: - data[key] = self.special_operator_map[key] + data[key] = self.special_operator_map[key](data[key]) elif key in self.data_file_keys: data[key] = self.main_data_operator(data[key]) return data diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index a893c0a..924becd 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) | Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)| |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 2ec2c48..fa238c8 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5) |模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-| +|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)| |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-| |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)| |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)| diff --git a/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py b/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py new file mode 100644 index 0000000..436eb17 --- /dev/null +++ b/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py @@ -0,0 +1,62 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig +from modelscope import dataset_snapshot_download, snapshot_download + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.enable_vram_management() + +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern="data/examples/wan/animate/*", +) + +# Animate +input_image = Image.open("data/examples/wan/animate/animate_input_image.png") +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + num_frames=81, height=720, width=1280, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Replace +snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B") +lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.float32, device="cuda")["state_dict"] +pipe.load_lora(pipe.dit, state_dict=lora_state_dict) +input_image = Image.open("data/examples/wan/animate/replace_input_image.png") +animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4] +animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4] +animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + animate_inpaint_video=animate_inpaint_video, + animate_mask_video=animate_mask_video, + num_frames=81, height=720, width=1280, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video2.mp4", fps=15, quality=5) + diff --git a/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh b/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh new file mode 100644 index 0000000..ab09a78 --- /dev/null +++ b/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh @@ -0,0 +1,16 @@ +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \ + --data_file_keys "video,animate_pose_video,animate_face_video" \ + --height 480 \ + --width 832 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.animate_adapter." \ + --output_path "./models/train/Wan2.2-Animate-14B_full" \ + --trainable_models "animate_adapter" \ + --extra_inputs "input_image,animate_pose_video,animate_face_video" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh b/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh new file mode 100644 index 0000000..0b6e571 --- /dev/null +++ b/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh @@ -0,0 +1,20 @@ +# 1*80G GPU cannot train Wan2.2-Animate-14B LoRA +# We tested on 8*80G GPUs +accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \ + --dataset_base_path data/example_video_dataset \ + --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \ + --data_file_keys "video,animate_pose_video,animate_face_video" \ + --height 480 \ + --width 832 \ + --num_frames 81 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Wan2.2-Animate-14B_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \ + --lora_rank 32 \ + --extra_inputs "input_image,animate_pose_video,animate_face_video" \ + --use_gradient_checkpointing_offload \ No newline at end of file diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 37494e7..f31ad69 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -2,7 +2,7 @@ import torch, os, json from diffsynth import load_state_dict from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser -from diffsynth.trainers.unified_dataset import UnifiedDataset +from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -108,6 +108,9 @@ if __name__ == "__main__": time_division_factor=4, time_division_remainder=1, ), + special_operator_map={ + "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)) + } ) model = WanTrainingModule( model_paths=args.model_paths, diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py new file mode 100644 index 0000000..d6fbfc1 --- /dev/null +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py @@ -0,0 +1,33 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +state_dict = load_state_dict("models/train/Wan2.2-Animate-14B_full/epoch-1.safetensors") +pipe.animate_adapter.load_state_dict(state_dict, strict=False) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0] +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + num_frames=81, height=480, width=832, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5) \ No newline at end of file diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py new file mode 100644 index 0000000..9f6d7c4 --- /dev/null +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py @@ -0,0 +1,32 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, load_state_dict +from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig + + +pipe = WanVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"), + ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"), + ], +) +pipe.load_lora(pipe.dit, "models/train/Wan2.2-Animate-14B_lora/epoch-4.safetensors", alpha=1) +pipe.enable_vram_management() + +input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0] +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4] +video = pipe( + prompt="视频中的人在做动作", + seed=0, tiled=True, + input_image=input_image, + animate_pose_video=animate_pose_video, + animate_face_video=animate_face_video, + num_frames=81, height=480, width=832, + num_inference_steps=20, cfg_scale=1, +) +save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5) \ No newline at end of file