Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
Support Wan2.2-Animate-14B
@@ -207,6 +207,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
 | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
 |-|-|-|-|-|-|-|
+|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
 |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
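The new row documents the extra keyword arguments that `WanVideoPipeline` accepts for Wan-AI/Wan2.2-Animate-14B; the same names appear in the pipeline signature later in this diff. The sketch below shows how they might be passed. It is a minimal sketch only: the pipeline construction is omitted, and the prompt text and file paths are placeholder assumptions — the linked `examples/wanvideo/model_inference/Wan2.2-Animate-14B.py` is the authoritative example.

```python
# Hedged sketch: parameter names come from the table and pipeline signature in this diff;
# pipeline construction, prompt, and all file paths are placeholder assumptions.
from PIL import Image
from diffsynth import save_video  # save_video is the helper used in the README snippet this hunk extends

# pipe = WanVideoPipeline built for Wan-AI/Wan2.2-Animate-14B, as in the linked example script

reference_image = Image.open("person.png")                          # input_image: identity reference
pose_video = [Image.open(f"pose/{i:04d}.png") for i in range(81)]   # animate_pose_video: driving pose frames
face_video = [Image.open(f"face/{i:04d}.png") for i in range(81)]   # animate_face_video: driving face crops

video = pipe(
    prompt="a person dancing in a studio",
    input_image=reference_image,
    animate_pose_video=pose_video,
    animate_face_video=face_video,
    # optionally: animate_inpaint_video=..., animate_mask_video=... for replacement-style editing
    seed=0,
)
save_video(video, "animate.mp4", fps=15, quality=5)
```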
The same row is added to the Chinese-language README (headers translated here):

@@ -207,6 +207,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)
 | Model ID | Extra Parameters | Inference | Full Training | Validate After Full Training | LoRA Training | Validate After LoRA Training |
 |-|-|-|-|-|-|-|
+|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
 |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
@@ -63,6 +63,7 @@ from ..models.wan_video_vae import WanVideoVAE, WanVideoVAE38
 from ..models.wan_video_motion_controller import WanMotionControllerModel
 from ..models.wan_video_vace import VaceWanModel
 from ..models.wav2vec import WanS2VAudioEncoder
+from ..models.wan_video_animate_adapter import WanAnimateAdapter
 from ..models.step1x_connector import Qwen2Connector
@@ -142,7 +143,6 @@ model_loader_configs = [
     (None, "aafcfd9672c3a2456dc46e1cb6e52c70", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "6d6ccde6845b95ad9114ab993d917893", ["wan_video_dit"], [WanModel], "civitai"),
-    (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
     (None, "3ef3b1f8e1dab83d5b71fd7b617f859f", ["wan_video_dit"], [WanModel], "civitai"),
@@ -176,6 +176,7 @@ model_loader_configs = [
     (None, "073bce9cf969e317e5662cd570c3e79c", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
     (None, "a9e54e480a628f0b956a688a81c33bab", ["qwen_image_blockwise_controlnet"], [QwenImageBlockWiseControlNet], "civitai"),
     (None, "06be60f3a4526586d8431cd038a71486", ["wans2v_audio_encoder"], [WanS2VAudioEncoder], "civitai"),
+    (None, "31fa352acb8a1b1d33cd8764273d80a2", ["wan_video_dit", "wan_video_animate_adapter"], [WanModel, WanAnimateAdapter], "civitai"),
 ]
 huggingface_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
diffsynth/models/wan_video_animate_adapter.py — new file, 670 lines
@@ -0,0 +1,670 @@
import torch
import torch.nn as nn
from torch.nn import functional as F
import math
from typing import Tuple, Optional, List
from einops import rearrange


MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    max_seqlen_q=None,
    batch_size=1,
):
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
class CausalConv1d(nn.Module):

    def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
        super().__init__()

        self.pad_mode = pad_mode
        padding = (kernel_size - 1, 0)  # T
        self.time_causal_padding = padding

        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        return self.conv(x)


class FaceEncoder(nn.Module):
    # Encodes per-frame face motion features into multi-head tokens matched to the DiT hidden size.
    def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()

        self.num_heads = num_heads
        self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
        self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.act = nn.SiLU()
        self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
        self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)

        self.out_proj = nn.Linear(1024, hidden_dim)
        self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))

    def forward(self, x):
        x = rearrange(x, "b t c -> b c t")
        b, c, t = x.shape

        x = self.conv1_local(x)
        x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)

        x = self.norm1(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv2(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm2(x)
        x = self.act(x)
        x = rearrange(x, "b t c -> b c t")
        x = self.conv3(x)
        x = rearrange(x, "b c t -> b t c")
        x = self.norm3(x)
        x = self.act(x)
        x = self.out_proj(x)
        x = rearrange(x, "(b n) t c -> b t n c", b=b)
        padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
        x = torch.cat([x, padding], dim=-2)
        x_local = x.clone()

        return x_local
class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
class FaceAdapter(nn.Module):
    def __init__(
        self,
        hidden_dim: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        num_adapter_layers: int = 1,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.hidden_size = hidden_dim
        self.heads_num = heads_num
        self.fuser_blocks = nn.ModuleList(
            [
                FaceBlock(
                    self.hidden_size,
                    self.heads_num,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    **factory_kwargs,
                )
                for _ in range(num_adapter_layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        motion_embed: torch.Tensor,
        idx: int,
        freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
        freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
class FaceBlock(nn.Module):
    # Cross-attention block: video tokens (queries) attend to face-motion tokens (keys/values).
    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        self.scale = qk_scale or head_dim**-0.5

        self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
        self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )

        self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

    def forward(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    ) -> torch.Tensor:
        B, T, N, C = motion_vec.shape
        T_comp = T

        x_motion = self.pre_norm_motion(motion_vec)
        x_feat = self.pre_norm_feat(x)

        kv = self.linear1_kv(x_motion)
        q = self.linear1_q(x_feat)

        k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
        q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        k = rearrange(k, "B L N H D -> (B L) H N D")
        v = rearrange(v, "B L N H D -> (B L) H N D")

        q = rearrange(q, "B (L S) H D -> (B L) H S D", L=T_comp)
        # Compute attention.
        attn = F.scaled_dot_product_attention(q, k, v)

        attn = rearrange(attn, "(B L) H S D -> B (L S) (H D)", L=T_comp)

        output = self.linear2(attn)

        if motion_mask is not None:
            output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

        return output
def custom_qr(input_tensor):
    original_dtype = input_tensor.dtype
    if original_dtype == torch.bfloat16:
        q, r = torch.linalg.qr(input_tensor.to(torch.float32))
        return q.to(original_dtype), r.to(original_dtype)
    return torch.linalg.qr(input_tensor)


def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
    return F.leaky_relu(input + bias, negative_slope) * scale


def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
    _, minor, in_h, in_w = input.shape
    kernel_h, kernel_w = kernel.shape

    out = input.view(-1, minor, in_h, 1, in_w, 1)
    out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
    out = out.view(-1, minor, in_h * up_y, in_w * up_x)

    out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
    out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
              max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]

    out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
    w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
    out = F.conv2d(out, w)
    out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
                      in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
    return out[:, :, ::down_y, ::down_x]


def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
    return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])


def make_kernel(k):
    k = torch.tensor(k, dtype=torch.float32)
    if k.ndim == 1:
        k = k[None, :] * k[:, None]
    k /= k.sum()
    return k
class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
        return out


class Blur(nn.Module):
    def __init__(self, kernel, pad, upsample_factor=1):
        super().__init__()

        kernel = make_kernel(kernel)

        if upsample_factor > 1:
            kernel = kernel * (upsample_factor ** 2)

        self.register_buffer('kernel', kernel)

        self.pad = pad

    def forward(self, input):
        return upfirdn2d(input, self.kernel, pad=self.pad)


class ScaledLeakyReLU(nn.Module):
    def __init__(self, negative_slope=0.2):
        super().__init__()

        self.negative_slope = negative_slope

    def forward(self, input):
        return F.leaky_relu(input, negative_slope=self.negative_slope)


class EqualConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
        self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)

        self.stride = stride
        self.padding = padding

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channel))
        else:
            self.bias = None

    def forward(self, input):
        return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

    def __repr__(self):
        return (
            f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
            f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
        )


class EqualLinear(nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
        super().__init__()

        self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
        else:
            self.bias = None

        self.activation = activation

        self.scale = (1 / math.sqrt(in_dim)) * lr_mul
        self.lr_mul = lr_mul

    def forward(self, input):
        if self.activation:
            out = F.linear(input, self.weight * self.scale)
            out = fused_leaky_relu(out, self.bias * self.lr_mul)
        else:
            out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)

        return out

    def __repr__(self):
        return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
class ConvLayer(nn.Sequential):
    def __init__(
        self,
        in_channel,
        out_channel,
        kernel_size,
        downsample=False,
        blur_kernel=[1, 3, 3, 1],
        bias=True,
        activate=True,
    ):
        layers = []

        if downsample:
            factor = 2
            p = (len(blur_kernel) - factor) + (kernel_size - 1)
            pad0 = (p + 1) // 2
            pad1 = p // 2

            layers.append(Blur(blur_kernel, pad=(pad0, pad1)))

            stride = 2
            self.padding = 0

        else:
            stride = 1
            self.padding = kernel_size // 2

        layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
                                  bias=bias and not activate))

        if activate:
            if bias:
                layers.append(FusedLeakyReLU(out_channel))
            else:
                layers.append(ScaledLeakyReLU(0.2))

        super().__init__(*layers)


class ResBlock(nn.Module):
    def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
        super().__init__()

        self.conv1 = ConvLayer(in_channel, in_channel, 3)
        self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)

        self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)

        skip = self.skip(input)
        out = (out + skip) / math.sqrt(2)

        return out
class EncoderApp(nn.Module):
    def __init__(self, size, w_dim=512):
        super(EncoderApp, self).__init__()

        channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256,
            128: 128,
            256: 64,
            512: 32,
            1024: 16
        }

        self.w_dim = w_dim
        log_size = int(math.log(size, 2))

        self.convs = nn.ModuleList()
        self.convs.append(ConvLayer(3, channels[size], 1))

        in_channel = channels[size]
        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            self.convs.append(ResBlock(in_channel, out_channel))
            in_channel = out_channel

        self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))

    def forward(self, x):
        res = []
        h = x
        for conv in self.convs:
            h = conv(h)
            res.append(h)

        return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]


class Encoder(nn.Module):
    def __init__(self, size, dim=512, dim_motion=20):
        super(Encoder, self).__init__()

        # appearance network
        self.net_app = EncoderApp(size, dim)

        # motion network
        fc = [EqualLinear(dim, dim)]
        for i in range(3):
            fc.append(EqualLinear(dim, dim))

        fc.append(EqualLinear(dim, dim_motion))
        self.fc = nn.Sequential(*fc)

    def enc_app(self, x):
        h_source = self.net_app(x)
        return h_source

    def enc_motion(self, x):
        h, _ = self.net_app(x)
        h_motion = self.fc(h)
        return h_motion
class Direction(nn.Module):
    def __init__(self, motion_dim):
        super(Direction, self).__init__()
        self.weight = nn.Parameter(torch.randn(512, motion_dim))

    def forward(self, input):
        weight = self.weight + 1e-8
        Q, R = custom_qr(weight)
        if input is None:
            return Q
        else:
            input_diag = torch.diag_embed(input)  # alpha, diagonal matrix
            out = torch.matmul(input_diag, Q.T)
            out = torch.sum(out, dim=1)
            return out


class Synthesis(nn.Module):
    def __init__(self, motion_dim):
        super(Synthesis, self).__init__()
        self.direction = Direction(motion_dim)


class Generator(nn.Module):
    def __init__(self, size, style_dim=512, motion_dim=20):
        super().__init__()

        self.enc = Encoder(size, style_dim, motion_dim)
        self.dec = Synthesis(motion_dim)

    def get_motion(self, img):
        # motion_feat = self.enc.enc_motion(img)
        motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
        motion = self.dec.direction(motion_feat)
        return motion
class WanAnimateAdapter(torch.nn.Module):
    # Adapter for Wan2.2-Animate: injects pose latents after patch embedding and fuses
    # face-motion tokens into the DiT via a FaceBlock every 5 transformer blocks.
    def __init__(self):
        super().__init__()
        self.pose_patch_embedding = torch.nn.Conv3d(16, 5120, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
        self.face_adapter = FaceAdapter(heads_num=40, hidden_dim=5120, num_adapter_layers=40 // 5)
        self.face_encoder = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)

    def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
        pose_latents = self.pose_patch_embedding(pose_latents)
        x[:, :, 1:] += pose_latents

        b, c, T, h, w = face_pixel_values.shape
        face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")

        encode_bs = 8
        face_pixel_values_tmp = []
        for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
            face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i * encode_bs:(i + 1) * encode_bs]))

        motion_vec = torch.cat(face_pixel_values_tmp)

        motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
        motion_vec = self.face_encoder(motion_vec)

        B, L, H, C = motion_vec.shape
        pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
        motion_vec = torch.cat([pad_face, motion_vec], dim=1)
        return x, motion_vec

    def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
        if block_idx % 5 == 0:
            adapter_args = [x, motion_vec, motion_masks, False]
            residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
            x = residual_out + x
        return x

    @staticmethod
    def state_dict_converter():
        return WanAnimateAdapterStateDictConverter()


class WanAnimateAdapterStateDictConverter:
    def __init__(self):
        pass

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        # Keep only the adapter's own parameters; the DiT weights are handled separately.
        state_dict_ = {}
        for name, param in state_dict.items():
            if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"):
                state_dict_[name] = param
        return state_dict_
@@ -342,9 +342,7 @@ class WanModel(torch.nn.Module):
             y_camera = self.control_adapter(control_camera_latents_input)
             x = [u + v for u, v in zip(x, y_camera)]
             x = x[0].unsqueeze(0)
-        grid_size = x.shape[2:]
-        x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
-        return x, grid_size  # x, grid_size: (f, h, w)
+        return x

     def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
         return rearrange(
@@ -496,6 +494,7 @@ class WanModelStateDictConverter:

     def from_civitai(self, state_dict):
         state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
+        state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]}
         if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
             config = {
                 "has_image_input": False,
@@ -552,20 +551,6 @@ class WanModelStateDictConverter:
                 "num_layers": 30,
                 "eps": 1e-6
             }
-        elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
-            config = {
-                "has_image_input": True,
-                "patch_size": [1, 2, 2],
-                "in_dim": 36,
-                "dim": 5120,
-                "ffn_dim": 13824,
-                "freq_dim": 256,
-                "text_dim": 4096,
-                "out_dim": 16,
-                "num_heads": 40,
-                "num_layers": 40,
-                "eps": 1e-6
-            }
         elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
             # 1.3B PAI control
             config = {
@@ -21,6 +21,7 @@ from ..models.wan_video_vae import WanVideoVAE, RMS_norm, CausalConv3d, Upsample
 from ..models.wan_video_image_encoder import WanImageEncoder
 from ..models.wan_video_vace import VaceWanModel
 from ..models.wan_video_motion_controller import WanMotionControllerModel
+from ..models.wan_video_animate_adapter import WanAnimateAdapter
 from ..schedulers.flow_match import FlowMatchScheduler
 from ..prompters import WanPrompter
 from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear, WanAutoCastLayerNorm
@@ -45,8 +46,9 @@ class WanVideoPipeline(BasePipeline):
         self.motion_controller: WanMotionControllerModel = None
         self.vace: VaceWanModel = None
         self.vace2: VaceWanModel = None
-        self.in_iteration_models = ("dit", "motion_controller", "vace")
-        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2")
+        self.animate_adapter: WanAnimateAdapter = None
+        self.in_iteration_models = ("dit", "motion_controller", "vace", "animate_adapter")
+        self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2", "animate_adapter")
         self.unit_runner = PipelineUnitRunner()
         self.units = [
             WanVideoUnit_ShapeChecker(),
@@ -62,6 +64,10 @@ class WanVideoPipeline(BasePipeline):
             WanVideoUnit_FunCameraControl(),
             WanVideoUnit_SpeedControl(),
             WanVideoUnit_VACE(),
+            WanVideoPostUnit_AnimateVideoSplit(),
+            WanVideoPostUnit_AnimatePoseLatents(),
+            WanVideoPostUnit_AnimateFacePixelValues(),
+            WanVideoPostUnit_AnimateInpaint(),
             WanVideoUnit_UnifiedSequenceParallel(),
             WanVideoUnit_TeaCache(),
             WanVideoUnit_CfgMerger(),
@@ -70,13 +76,34 @@ class WanVideoPipeline(BasePipeline):
             WanVideoPostUnit_S2V(),
         ]
         self.model_fn = model_fn_wan_video

-    def load_lora(self, module, path, alpha=1):
-        loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
-        lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device)
-        loader.load(module, lora, alpha=alpha)
+    def load_lora(
+        self,
+        module: torch.nn.Module,
+        lora_config: Union[ModelConfig, str] = None,
+        alpha=1,
+        hotload=False,
+        state_dict=None,
+    ):
+        if state_dict is None:
+            if isinstance(lora_config, str):
+                lora = load_state_dict(lora_config, torch_dtype=self.torch_dtype, device=self.device)
+            else:
+                lora_config.download_if_necessary()
+                lora = load_state_dict(lora_config.path, torch_dtype=self.torch_dtype, device=self.device)
+        else:
+            lora = state_dict
+        if hotload:
+            for name, module in module.named_modules():
+                if isinstance(module, AutoWrappedLinear):
+                    lora_a_name = f'{name}.lora_A.default.weight'
+                    lora_b_name = f'{name}.lora_B.default.weight'
+                    if lora_a_name in lora and lora_b_name in lora:
+                        module.lora_A_weights.append(lora[lora_a_name] * alpha)
+                        module.lora_B_weights.append(lora[lora_b_name])
+        else:
+            loader = GeneralLoRALoader(torch_dtype=self.torch_dtype, device=self.device)
+            loader.load(module, lora, alpha=alpha)

     def training_loss(self, **inputs):
         max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * self.scheduler.num_train_timesteps)
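The reworked `load_lora` keeps the old merge path (via `GeneralLoRALoader`) and adds a hot-load path that appends LoRA A/B matrices to `AutoWrappedLinear` layers, so adapters can be applied without merging into the base weights. A hedged usage sketch follows; the import location of `ModelConfig` and the repo/file names are assumptions, not taken from this diff.

```python
# Sketch only: the two call styles mirror the new load_lora signature above.
from diffsynth import ModelConfig  # assumption: importable from the package root, as in other examples

# 1) Path-based loading, merged into the module weights (GeneralLoRALoader path).
pipe.load_lora(pipe.dit, "path/to/my_lora.safetensors", alpha=1.0)

# 2) Hot-loadable LoRA: weights are attached to AutoWrappedLinear layers instead of merged,
#    so they can be swapped per request without reloading the base model.
pipe.load_lora(
    pipe.dit,
    lora_config=ModelConfig(model_id="someone/some-wan-lora", origin_file_pattern="lora.safetensors"),  # placeholder IDs
    alpha=1.0,
    hotload=True,
)
```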
@@ -359,12 +386,13 @@ class WanVideoPipeline(BasePipeline):
         pipe.vae = model_manager.fetch_model("wan_video_vae")
         pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
         pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
-        pipe.vace = model_manager.fetch_model("wan_video_vace")
+        vace = model_manager.fetch_model("wan_video_vace", index=2)
         if isinstance(vace, list):
             pipe.vace, pipe.vace2 = vace
         else:
             pipe.vace = vace
         pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder")
+        pipe.animate_adapter = model_manager.fetch_model("wan_video_animate_adapter")

         # Size division factor
         if pipe.vae is not None:
@@ -417,6 +445,11 @@ class WanVideoPipeline(BasePipeline):
         vace_video_mask: Optional[Image.Image] = None,
         vace_reference_image: Optional[Image.Image] = None,
         vace_scale: Optional[float] = 1.0,
+        # Animate
+        animate_pose_video: Optional[list[Image.Image]] = None,
+        animate_face_video: Optional[list[Image.Image]] = None,
+        animate_inpaint_video: Optional[list[Image.Image]] = None,
+        animate_mask_video: Optional[list[Image.Image]] = None,
         # Randomness
         seed: Optional[int] = None,
         rand_device: Optional[str] = "cpu",
@@ -474,6 +507,7 @@ class WanVideoPipeline(BasePipeline):
             "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
             "sliding_window_size": sliding_window_size, "sliding_window_stride": sliding_window_stride,
             "input_audio": input_audio, "audio_sample_rate": audio_sample_rate, "s2v_pose_video": s2v_pose_video, "audio_embeds": audio_embeds, "s2v_pose_latents": s2v_pose_latents, "motion_video": motion_video,
+            "animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -508,7 +542,7 @@ class WanVideoPipeline(BasePipeline):
             inputs_shared["latents"][:, :, 0:1] = inputs_shared["first_frame_latents"]

         # VACE (TODO: remove it)
-        if vace_reference_image is not None:
+        if vace_reference_image is not None or (animate_pose_video is not None and animate_face_video is not None):
             inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
         # post-denoising, pre-decoding processing logic
         for unit in self.post_units:
@@ -1021,6 +1055,95 @@ class WanVideoPostUnit_S2V(PipelineUnit):
         return {"latents": latents}


+class WanVideoPostUnit_AnimateVideoSplit(PipelineUnit):
+    def __init__(self):
+        super().__init__(input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"))
+
+    def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video):
+        if input_video is None:
+            return {}
+        if animate_pose_video is not None:
+            animate_pose_video = animate_pose_video[:len(input_video) - 4]
+        if animate_face_video is not None:
+            animate_face_video = animate_face_video[:len(input_video) - 4]
+        if animate_inpaint_video is not None:
+            animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4]
+        if animate_mask_video is not None:
+            animate_mask_video = animate_mask_video[:len(input_video) - 4]
+        return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video}
+
+
+class WanVideoPostUnit_AnimatePoseLatents(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("animate_pose_video", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def process(self, pipe: WanVideoPipeline, animate_pose_video, tiled, tile_size, tile_stride):
+        if animate_pose_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+        animate_pose_video = pipe.preprocess_video(animate_pose_video)
+        pose_latents = pipe.vae.encode(animate_pose_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        return {"pose_latents": pose_latents}
+
+
+class WanVideoPostUnit_AnimateFacePixelValues(PipelineUnit):
+    def __init__(self):
+        super().__init__(take_over=True)
+
+    def process(self, pipe: WanVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+        if inputs_shared.get("animate_face_video", None) is None:
+            return {}
+        inputs_posi["face_pixel_values"] = pipe.preprocess_video(inputs_shared["animate_face_video"])
+        inputs_nega["face_pixel_values"] = torch.zeros_like(inputs_posi["face_pixel_values"]) - 1
+        return inputs_shared, inputs_posi, inputs_nega
+
+
+class WanVideoPostUnit_AnimateInpaint(PipelineUnit):
+    def __init__(self):
+        super().__init__(
+            input_params=("animate_inpaint_video", "animate_mask_video", "input_image", "tiled", "tile_size", "tile_stride"),
+            onload_model_names=("vae",)
+        )
+
+    def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
+        if mask_pixel_values is None:
+            msk = torch.zeros(1, (lat_t - 1) * 4 + 1, lat_h, lat_w, device=device)
+        else:
+            msk = mask_pixel_values.clone()
+        msk[:, :mask_len] = 1
+        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
+        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
+        msk = msk.transpose(1, 2)[0]
+        return msk
+
+    def process(self, pipe: WanVideoPipeline, animate_inpaint_video, animate_mask_video, input_image, tiled, tile_size, tile_stride):
+        if animate_inpaint_video is None or animate_mask_video is None:
+            return {}
+        pipe.load_models_to_device(self.onload_model_names)
+
+        bg_pixel_values = pipe.preprocess_video(animate_inpaint_video)
+        y_reft = pipe.vae.encode(bg_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0].to(dtype=pipe.torch_dtype, device=pipe.device)
+        _, lat_t, lat_h, lat_w = y_reft.shape
+
+        ref_pixel_values = pipe.preprocess_video([input_image])
+        ref_latents = pipe.vae.encode(ref_pixel_values, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
+        mask_ref = self.get_i2v_mask(1, lat_h, lat_w, 1, device=pipe.device)
+        y_ref = torch.concat([mask_ref, ref_latents[0]]).to(dtype=torch.bfloat16, device=pipe.device)
+
+        mask_pixel_values = 1 - pipe.preprocess_video(animate_mask_video, max_value=1, min_value=0)
+        mask_pixel_values = rearrange(mask_pixel_values, "b c t h w -> (b t) c h w")
+        mask_pixel_values = torch.nn.functional.interpolate(mask_pixel_values, size=(lat_h, lat_w), mode='nearest')
+        mask_pixel_values = rearrange(mask_pixel_values, "(b t) c h w -> b t c h w", b=1)[:, :, 0]
+        msk_reft = self.get_i2v_mask(lat_t, lat_h, lat_w, 0, mask_pixel_values=mask_pixel_values, device=pipe.device)
+
+        y_reft = torch.concat([msk_reft, y_reft]).to(dtype=torch.bfloat16, device=pipe.device)
+        y = torch.concat([y_ref, y_reft], dim=1).unsqueeze(0)
+        return {"y": y}
+
+
 class TeaCache:
     def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
         self.num_inference_steps = num_inference_steps
@@ -1131,6 +1254,7 @@ def model_fn_wan_video(
     dit: WanModel,
     motion_controller: WanMotionControllerModel = None,
     vace: VaceWanModel = None,
+    animate_adapter: WanAnimateAdapter = None,
     latents: torch.Tensor = None,
     timestep: torch.Tensor = None,
     context: torch.Tensor = None,
@@ -1146,6 +1270,8 @@ def model_fn_wan_video(
     tea_cache: TeaCache = None,
    use_unified_sequence_parallel: bool = False,
     motion_bucket_id: Optional[torch.Tensor] = None,
+    pose_latents=None,
+    face_pixel_values=None,
     sliding_window_size: Optional[int] = None,
     sliding_window_stride: Optional[int] = None,
     cfg_merge: bool = False,
@@ -1236,9 +1362,16 @@ def model_fn_wan_video(
     if clip_feature is not None and dit.require_clip_embedding:
         clip_embdding = dit.img_emb(clip_feature)
         context = torch.cat([clip_embdding, context], dim=1)

-    # Add camera control
-    x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
+    # Camera control
+    x = dit.patchify(x, control_camera_latents_input)
+
+    # Animate
+    x, motion_vec = animate_adapter.after_patch_embedding(x, pose_latents, face_pixel_values)
+
+    # Patchify
+    f, h, w = x.shape[2:]
+    x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()

     # Reference image
     if reference_latents is not None:
@@ -1283,6 +1416,7 @@ def model_fn_wan_video(
         return custom_forward

     for block_id, block in enumerate(dit.blocks):
+        # Block
         if use_gradient_checkpointing_offload:
             with torch.autograd.graph.save_on_cpu():
                 x = torch.utils.checkpoint.checkpoint(
@@ -1298,12 +1432,18 @@ def model_fn_wan_video(
                 )
         else:
             x = block(x, context, t_mod, freqs)
+
+        # VACE
         if vace_context is not None and block_id in vace.vace_layers_mapping:
             current_vace_hint = vace_hints[vace.vace_layers_mapping[block_id]]
             if use_unified_sequence_parallel and dist.is_initialized() and dist.get_world_size() > 1:
                 current_vace_hint = torch.chunk(current_vace_hint, get_sequence_parallel_world_size(), dim=1)[get_sequence_parallel_rank()]
                 current_vace_hint = torch.nn.functional.pad(current_vace_hint, (0, 0, 0, chunks[0].shape[1] - current_vace_hint.shape[1]), value=0)
             x = x + current_vace_hint * vace_scale
+
+        # Animate
+        if pose_latents is not None and face_pixel_values is not None:
+            x = animate_adapter.after_transformer_block(block_id, x, motion_vec)
         if tea_cache is not None:
             tea_cache.store(x)
@@ -316,7 +316,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
         for key in self.data_file_keys:
             if key in data:
                 if key in self.special_operator_map:
-                    data[key] = self.special_operator_map[key]
+                    data[key] = self.special_operator_map[key](data[key])
                 elif key in self.data_file_keys:
                     data[key] = self.main_data_operator(data[key])
         return data
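The one-line change above applies the registered operator to the value rather than storing the callable itself. A self-contained, purely hypothetical illustration of the difference:

```python
# Hypothetical stand-ins; only the call-vs-store distinction matters here.
special_operator_map = {"caption": str.strip}
data = {"caption": "  a dancing person  "}

before_fix = special_operator_map["caption"]                  # a function object, not processed data
after_fix = special_operator_map["caption"](data["caption"])  # "a dancing person"
```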
@@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)

 | Model ID | Extra Parameters | Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
 |-|-|-|-|-|-|-|
+|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
 |[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
 |[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
 |[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|
@@ -48,6 +48,7 @@ save_video(video, "video1.mp4", fps=15, quality=5)

|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-|-|
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](./model_inference/Wan2.2-Animate-14B.py)|[code](./model_training/full/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_full/Wan2.2-Animate-14B.py)|[code](./model_training/lora/Wan2.2-Animate-14B.sh)|[code](./model_training/validate_lora/Wan2.2-Animate-14B.py)|
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](./model_inference/Wan2.2-S2V-14B_multi_clips.py)|-|-|-|-|
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](./model_inference/Wan2.2-I2V-A14B.py)|[code](./model_training/full/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](./model_training/lora/Wan2.2-I2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](./model_inference/Wan2.2-T2V-A14B.py)|[code](./model_training/full/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](./model_training/lora/Wan2.2-T2V-A14B.sh)|[code](./model_training/validate_lora/Wan2.2-T2V-A14B.py)|

examples/wanvideo/model_inference/Wan2.2-Animate-14B.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from modelscope import dataset_snapshot_download, snapshot_download


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/animate/*",
)

# Animate
input_image = Image.open("data/examples/wan/animate/animate_input_image.png")
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=720, width=1280,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video1.mp4", fps=15, quality=5)

# Replace
snapshot_download("Wan-AI/Wan2.2-Animate-14B", allow_file_pattern="relighting_lora.ckpt", local_dir="models/Wan-AI/Wan2.2-Animate-14B")
lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.float32, device="cuda")["state_dict"]
pipe.load_lora(pipe.dit, state_dict=lora_state_dict)
input_image = Image.open("data/examples/wan/animate/replace_input_image.png")
animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4]
animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4]
animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    animate_inpaint_video=animate_inpaint_video,
    animate_mask_video=animate_mask_video,
    num_frames=81, height=720, width=1280,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video2.mp4", fps=15, quality=5)

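Note: both calls above use `num_inference_steps=20, cfg_scale=1`. With a guidance scale of 1, the standard classifier-free-guidance combination reduces to the conditional prediction, so no extra unconditional forward pass is needed. As a generic illustration of that combination (not DiffSynth's internal code):

import torch

def apply_cfg(noise_pred_cond: torch.Tensor, noise_pred_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Generic classifier-free guidance mix; at cfg_scale == 1 this equals noise_pred_cond.
    return noise_pred_uncond + cfg_scale * (noise_pred_cond - noise_pred_uncond)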
examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh (new file, 16 lines)
@@ -0,0 +1,16 @@
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \
  --data_file_keys "video,animate_pose_video,animate_face_video" \
  --height 480 \
  --width 832 \
  --num_frames 81 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-5 \
  --num_epochs 2 \
  --remove_prefix_in_ckpt "pipe.animate_adapter." \
  --output_path "./models/train/Wan2.2-Animate-14B_full" \
  --trainable_models "animate_adapter" \
  --extra_inputs "input_image,animate_pose_video,animate_face_video" \
  --use_gradient_checkpointing_offload

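Note: this run trains only the Animate adapter (`--trainable_models "animate_adapter"`) and strips the `pipe.animate_adapter.` prefix from checkpoint keys, so the saved file can later be loaded directly into the adapter (as the validation script further below does with `pipe.animate_adapter.load_state_dict`). Roughly, the prefix stripping amounts to the following hypothetical helper (not the trainer's actual implementation):

def strip_prefix(state_dict: dict, prefix: str = "pipe.animate_adapter.") -> dict:
    # Keep only keys under the prefix and drop the prefix, so names match the sub-module.
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}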
examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh (new file, 20 lines)
@@ -0,0 +1,20 @@
# A single 80G GPU cannot train a Wan2.2-Animate-14B LoRA.
# We tested on 8*80G GPUs.
accelerate launch --config_file examples/wanvideo/model_training/full/accelerate_config_14B.yaml examples/wanvideo/model_training/train.py \
  --dataset_base_path data/example_video_dataset \
  --dataset_metadata_path data/example_video_dataset/metadata_animate.csv \
  --data_file_keys "video,animate_pose_video,animate_face_video" \
  --height 480 \
  --width 832 \
  --num_frames 81 \
  --dataset_repeat 100 \
  --model_id_with_origin_paths "Wan-AI/Wan2.2-Animate-14B:diffusion_pytorch_model*.safetensors,Wan-AI/Wan2.2-Animate-14B:models_t5_umt5-xxl-enc-bf16.pth,Wan-AI/Wan2.2-Animate-14B:Wan2.1_VAE.pth,Wan-AI/Wan2.2-Animate-14B:models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth" \
  --learning_rate 1e-4 \
  --num_epochs 5 \
  --remove_prefix_in_ckpt "pipe.dit." \
  --output_path "./models/train/Wan2.2-Animate-14B_lora" \
  --lora_base_model "dit" \
  --lora_target_modules "q,k,v,o,ffn.0,ffn.2" \
  --lora_rank 32 \
  --extra_inputs "input_image,animate_pose_video,animate_face_video" \
  --use_gradient_checkpointing_offload

examples/wanvideo/model_training/train.py
@@ -2,7 +2,7 @@ import torch, os, json
from diffsynth import load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser
-from diffsynth.trainers.unified_dataset import UnifiedDataset
+from diffsynth.trainers.unified_dataset import UnifiedDataset, LoadVideo, ImageCropAndResize, ToAbsolutePath
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@@ -108,6 +108,9 @@ if __name__ == "__main__":
            time_division_factor=4,
            time_division_remainder=1,
        ),
        special_operator_map={
            "animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16))
        }
    )
    model = WanTrainingModule(
        model_paths=args.model_paths,

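Note: the added `special_operator_map` routes `animate_face_video` through a chained operator, `ToAbsolutePath(...) >> LoadVideo(..., frame_processor=ImageCropAndResize(512, 512, None, 16, 16))`, so face frames are resolved relative to the dataset root and loaded at 512x512 regardless of the main `--height`/`--width`. The `>>` chaining could be implemented roughly as below; this is an assumed sketch of the composition pattern, not DiffSynth's actual operator classes:

import os

class ChainableOperator:
    # Assumed pattern: `a >> b` builds an operator that applies a first, then b.
    def __call__(self, value):
        raise NotImplementedError

    def __rshift__(self, other):
        first = self
        class _Chained(ChainableOperator):
            def __call__(self, value):
                return other(first(value))
        return _Chained()

class ToAbsolutePathSketch(ChainableOperator):
    # Hypothetical counterpart of ToAbsolutePath: joins a base directory onto relative paths.
    def __init__(self, base_path):
        self.base_path = base_path

    def __call__(self, relative_path):
        return os.path.join(self.base_path, relative_path)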
examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py (new file, 33 lines)
@@ -0,0 +1,33 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
state_dict = load_state_dict("models/train/Wan2.2-Animate-14B_full/epoch-1.safetensors")
pipe.animate_adapter.load_state_dict(state_dict, strict=False)
pipe.enable_vram_management()

input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0]
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=480, width=832,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5)

examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py (new file, 32 lines)
@@ -0,0 +1,32 @@
import torch
from PIL import Image
from diffsynth import save_video, VideoData, load_state_dict
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig


pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
        ModelConfig(model_id="Wan-AI/Wan2.2-Animate-14B", origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", offload_device="cpu"),
    ],
)
pipe.load_lora(pipe.dit, "models/train/Wan2.2-Animate-14B_lora/epoch-4.safetensors", alpha=1)
pipe.enable_vram_management()

input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0]
animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4]
animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4]
video = pipe(
    prompt="视频中的人在做动作",
    seed=0, tiled=True,
    input_image=input_image,
    animate_pose_video=animate_pose_video,
    animate_face_video=animate_face_video,
    num_frames=81, height=480, width=832,
    num_inference_steps=20, cfg_scale=1,
)
save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5)
