Merge pull request #533 from modelscope/wan-vace

vace
This commit is contained in:
Zhongjie Duan
2025-04-15 18:47:36 +08:00
committed by GitHub
6 changed files with 243 additions and 23 deletions

View File

@@ -60,6 +60,7 @@ from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vae import WanVideoVAE from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_motion_controller import WanMotionControllerModel from ..models.wan_video_motion_controller import WanMotionControllerModel
from ..models.wan_video_vace import VaceWanModel
model_loader_configs = [ model_loader_configs = [
@@ -125,6 +126,7 @@ model_loader_configs = [
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),

View File

@@ -451,6 +451,7 @@ class WanModelStateDictConverter:
return state_dict_, config return state_dict_, config
def from_civitai(self, state_dict): def from_civitai(self, state_dict):
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
config = { config = {
"has_image_input": False, "has_image_input": False,

View File

@@ -0,0 +1,77 @@
import torch
from .wan_video_dit import DiTBlock
class VaceWanAttentionBlock(DiTBlock):
    """DiT block variant used by the VACE branch.

    Runs the base DiT block on the control stream and accumulates a
    per-layer skip projection ("hint") that the main branch later adds in.
    """

    def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
        super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
        self.block_id = block_id
        # Only the first VACE block fuses the control stream with the main
        # hidden states, so it alone owns a pre-projection layer.
        if block_id == 0:
            self.before_proj = torch.nn.Linear(self.dim, self.dim)
        self.after_proj = torch.nn.Linear(self.dim, self.dim)

    def forward(self, c, x, context, t_mod, freqs):
        if self.block_id == 0:
            # First block: mix the projected control signal into the main
            # branch hidden states and start a fresh hint list.
            hidden = self.before_proj(c) + x
            collected = []
        else:
            # Later blocks receive a stacked tensor: the previously collected
            # hints followed by the running control state as the last entry.
            collected = list(torch.unbind(c))
            hidden = collected.pop(-1)
        hidden = super().forward(hidden, context, t_mod, freqs)
        skip = self.after_proj(hidden)
        collected.extend((skip, hidden))
        # Re-stack (hints..., running state) so the next block can unbind.
        return torch.stack(collected)
class VaceWanModel(torch.nn.Module):
    """VACE control branch for the Wan video DiT.

    Patchifies a conditioning video into tokens and runs them through a
    stack of `VaceWanAttentionBlock`s, producing one additive hint per
    hooked DiT layer (see `vace_layers`).
    """

    def __init__(
        self,
        vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
        vace_in_dim=96,
        patch_size=(1, 2, 2),
        has_image_input=False,
        dim=1536,
        num_heads=12,
        ffn_dim=8960,
        eps=1e-6,
    ):
        super().__init__()
        self.vace_layers = vace_layers
        self.vace_in_dim = vace_in_dim
        # Maps a DiT block index -> position of its hint in the hint tuple.
        self.vace_layers_mapping = {layer: pos for pos, layer in enumerate(self.vace_layers)}
        # One VACE block per hooked DiT layer; block_id 0 also fuses x.
        self.vace_blocks = torch.nn.ModuleList(
            VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=layer)
            for layer in self.vace_layers
        )
        # Patchify the VACE conditioning input into token embeddings.
        self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, vace_context, context, t_mod, freqs):
        tokens = []
        for clip in vace_context:
            emb = self.vace_patch_embedding(clip.unsqueeze(0))
            tokens.append(emb.flatten(2).transpose(1, 2))
        # Zero-pad each token sequence up to the main-branch sequence length
        # (assumes x is at least as long — TODO confirm), then batch them.
        padded = [
            torch.cat([t, t.new_zeros(1, x.shape[1] - t.size(1), t.size(2))], dim=1)
            for t in tokens
        ]
        c = torch.cat(padded)
        for block in self.vace_blocks:
            c = block(c, x, context, t_mod, freqs)
        # Every stacked entry except the last is a hint; the final entry is
        # the running control state and is discarded.
        return torch.unbind(c)[:-1]

    @staticmethod
    def state_dict_converter():
        return VaceWanModelDictConverter()
class VaceWanModelDictConverter:
    """Selects the VACE-specific weights out of a combined checkpoint."""

    def __init__(self):
        pass

    def from_civitai(self, state_dict):
        """Return only the entries whose names start with "vace"."""
        vace_weights = {}
        for name, param in state_dict.items():
            if name.startswith("vace"):
                vace_weights[name] = param
        return vace_weights

View File

@@ -4,6 +4,7 @@ from ..models.wan_video_dit import WanModel
from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_text_encoder import WanTextEncoder
from ..models.wan_video_vae import WanVideoVAE from ..models.wan_video_vae import WanVideoVAE
from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_image_encoder import WanImageEncoder
from ..models.wan_video_vace import VaceWanModel
from ..schedulers.flow_match import FlowMatchScheduler from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline from .base import BasePipeline
from ..prompters import WanPrompter from ..prompters import WanPrompter
@@ -33,7 +34,8 @@ class WanVideoPipeline(BasePipeline):
self.dit: WanModel = None self.dit: WanModel = None
self.vae: WanVideoVAE = None self.vae: WanVideoVAE = None
self.motion_controller: WanMotionControllerModel = None self.motion_controller: WanMotionControllerModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller'] self.vace: VaceWanModel = None
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
self.height_division_factor = 16 self.height_division_factor = 16
self.width_division_factor = 16 self.width_division_factor = 16
self.use_unified_sequence_parallel = False self.use_unified_sequence_parallel = False
@@ -153,6 +155,7 @@ class WanVideoPipeline(BasePipeline):
self.vae = model_manager.fetch_model("wan_video_vae") self.vae = model_manager.fetch_model("wan_video_vae")
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder") self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller") self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
self.vace = model_manager.fetch_model("wan_video_vace")
@staticmethod @staticmethod
@@ -253,6 +256,57 @@ class WanVideoPipeline(BasePipeline):
def prepare_motion_bucket_id(self, motion_bucket_id): def prepare_motion_bucket_id(self, motion_bucket_id):
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device) motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
return {"motion_bucket_id": motion_bucket_id} return {"motion_bucket_id": motion_bucket_id}
def prepare_vace_kwargs(
    self,
    latents,
    vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
    height=480, width=832, num_frames=81,
    seed=None, rand_device="cpu",
    tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
):
    """Encode VACE conditioning (control video, mask, reference image) into
    the latent-space context consumed by the VACE branch of the DiT.

    Returns `(latents, kwargs)` where kwargs carries "vace_context" and
    "vace_scale". When no VACE input is given, "vace_context" is None and
    `latents` is returned unchanged. When a reference image is given,
    `latents` gains one extra noise frame at the front (the caller strips
    the corresponding frame after denoising).
    """
    if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
        self.load_models_to_device(["vae"])
        # Default control video: all-zero (black) frames.
        if vace_video is None:
            vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
        else:
            vace_video = self.preprocess_images(vace_video)
            vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
        # Default mask: all ones, i.e. the whole video is "reactive".
        if vace_mask is None:
            vace_mask = torch.ones_like(vace_video)
        else:
            vace_mask = self.preprocess_images(vace_mask)
            vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)

        # Split into masked-out (inactive) and masked-in (reactive) parts,
        # encode each with the VAE, and concatenate on the channel axis.
        inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
        reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
        inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
        reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
        vace_video_latents = torch.concat((inactive, reactive), dim=1)

        # Bring the mask to latent resolution: each 8x8 spatial patch becomes
        # 64 channels, and the time axis is shrunk to ceil(T/4) frames
        # (matching the VAE's temporal compression — TODO confirm factor).
        vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
        vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')

        if vace_reference_image is None:
            pass
        else:
            # Prepend the encoded reference image as an extra leading "frame":
            # its latents are paired with zeros on the reactive half and a
            # zero mask slice so it is treated as fixed context.
            vace_reference_image = self.preprocess_images([vace_reference_image])
            vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
            vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
            vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
            vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
            vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
            # Extend the main latents by one noise frame so their temporal
            # length matches the extended VACE context.
            noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
            noise = noise.to(dtype=self.torch_dtype, device=self.device)
            latents = torch.concat((noise, latents), dim=2)

        # Final context = encoded video latents + downsampled mask, on channels.
        vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
        return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
    else:
        return latents, {"vace_context": None, "vace_scale": vace_scale}
@torch.no_grad() @torch.no_grad()
@@ -264,6 +318,10 @@ class WanVideoPipeline(BasePipeline):
end_image=None, end_image=None,
input_video=None, input_video=None,
control_video=None, control_video=None,
vace_video=None,
vace_video_mask=None,
vace_reference_image=None,
vace_scale=1.0,
denoising_strength=1.0, denoising_strength=1.0,
seed=None, seed=None,
rand_device="cpu", rand_device="cpu",
@@ -333,6 +391,12 @@ class WanVideoPipeline(BasePipeline):
# Extra input # Extra input
extra_input = self.prepare_extra_input(latents) extra_input = self.prepare_extra_input(latents)
# VACE
latents, vace_kwargs = self.prepare_vace_kwargs(
latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
)
# TeaCache # TeaCache
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
@@ -341,23 +405,23 @@ class WanVideoPipeline(BasePipeline):
usp_kwargs = self.prepare_unified_sequence_parallel() usp_kwargs = self.prepare_unified_sequence_parallel()
# Denoise # Denoise
self.load_models_to_device(["dit", "motion_controller"]) self.load_models_to_device(["dit", "motion_controller", "vace"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
# Inference # Inference
noise_pred_posi = model_fn_wan_video( noise_pred_posi = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, self.dit, motion_controller=self.motion_controller, vace=self.vace,
x=latents, timestep=timestep, x=latents, timestep=timestep,
**prompt_emb_posi, **image_emb, **extra_input, **prompt_emb_posi, **image_emb, **extra_input,
**tea_cache_posi, **usp_kwargs, **motion_kwargs **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
) )
if cfg_scale != 1.0: if cfg_scale != 1.0:
noise_pred_nega = model_fn_wan_video( noise_pred_nega = model_fn_wan_video(
self.dit, motion_controller=self.motion_controller, self.dit, motion_controller=self.motion_controller, vace=self.vace,
x=latents, timestep=timestep, x=latents, timestep=timestep,
**prompt_emb_nega, **image_emb, **extra_input, **prompt_emb_nega, **image_emb, **extra_input,
**tea_cache_nega, **usp_kwargs, **motion_kwargs **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
) )
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else: else:
@@ -365,6 +429,9 @@ class WanVideoPipeline(BasePipeline):
# Scheduler # Scheduler
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
if vace_reference_image is not None:
latents = latents[:, :, 1:]
# Decode # Decode
self.load_models_to_device(['vae']) self.load_models_to_device(['vae'])
@@ -432,11 +499,14 @@ class TeaCache:
def model_fn_wan_video( def model_fn_wan_video(
dit: WanModel, dit: WanModel,
motion_controller: WanMotionControllerModel = None, motion_controller: WanMotionControllerModel = None,
vace: VaceWanModel = None,
x: torch.Tensor = None, x: torch.Tensor = None,
timestep: torch.Tensor = None, timestep: torch.Tensor = None,
context: torch.Tensor = None, context: torch.Tensor = None,
clip_feature: Optional[torch.Tensor] = None, clip_feature: Optional[torch.Tensor] = None,
y: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None,
vace_context = None,
vace_scale = 1.0,
tea_cache: TeaCache = None, tea_cache: TeaCache = None,
use_unified_sequence_parallel: bool = False, use_unified_sequence_parallel: bool = False,
motion_bucket_id: Optional[torch.Tensor] = None, motion_bucket_id: Optional[torch.Tensor] = None,
@@ -472,6 +542,9 @@ def model_fn_wan_video(
tea_cache_update = tea_cache.check(dit, x, t_mod) tea_cache_update = tea_cache.check(dit, x, t_mod)
else: else:
tea_cache_update = False tea_cache_update = False
if vace_context is not None:
vace_hints = vace(x, vace_context, context, t_mod, freqs)
# blocks # blocks
if use_unified_sequence_parallel: if use_unified_sequence_parallel:
@@ -480,8 +553,10 @@ def model_fn_wan_video(
if tea_cache_update: if tea_cache_update:
x = tea_cache.update(x) x = tea_cache.update(x)
else: else:
for block in dit.blocks: for block_id, block in enumerate(dit.blocks):
x = block(x, context, t_mod, freqs) x = block(x, context, t_mod, freqs)
if vace_context is not None and block_id in vace.vace_layers_mapping:
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
if tea_cache is not None: if tea_cache is not None:
tea_cache.store(x) tea_cache.store(x)

View File

@@ -26,28 +26,30 @@ pip install -e .
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)| |PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)| |PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)| |PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)|
Base model features Base model features
||Text-to-video|Image-to-video|End frame|Control| ||Text-to-video|Image-to-video|End frame|Control|Reference image|
|-|-|-|-|-| |-|-|-|-|-|-|
|1.3B text-to-video|✅|||| |1.3B text-to-video|✅|||||
|14B text-to-video|✅|||| |14B text-to-video|✅|||||
|14B image-to-video 480P||✅||| |14B image-to-video 480P||✅||||
|14B image-to-video 720P||✅||| |14B image-to-video 720P||✅||||
|1.3B InP||✅|✅|| |1.3B InP||✅|✅|||
|14B InP||✅|✅|| |14B InP||✅|✅|||
|1.3B Control||||✅| |1.3B Control||||✅||
|14B Control||||✅| |14B Control||||✅||
|1.3B VACE||||✅|✅|
Adapter model compatibility Adapter model compatibility
||1.3B text-to-video|1.3B InP| ||1.3B text-to-video|1.3B InP|1.3B VACE|
|-|-|-| |-|-|-|-|
|1.3B aesthetics LoRA|✅|| |1.3B aesthetics LoRA|✅||✅|
|1.3B Highres-fix LoRA|✅|| |1.3B Highres-fix LoRA|✅||✅|
|1.3B ExVideo LoRA|✅|| |1.3B ExVideo LoRA|✅||✅|
|1.3B Speed Control adapter|✅|✅| |1.3B Speed Control adapter|✅|✅|✅|
## VRAM Usage ## VRAM Usage

View File

@@ -0,0 +1,63 @@
# Example: Wan2.1 1.3B VACE — controllable video generation from a depth
# control video and/or a reference image.
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download, dataset_snapshot_download
from PIL import Image

# Download model weights (DiT + VACE checkpoint, T5 text encoder, VAE)
snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview")

# Load models and build the pipeline
model_manager = ModelManager(device="cuda")
model_manager.load_models(
    [
        "models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors",
        "models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth",
        "models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth",
    ],
    torch_dtype=torch.bfloat16,
)
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
# Offload as many DiT parameters as possible to save VRAM
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Download example assets (depth control video and reference image)
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
)

# 1) Depth video -> Video (structure-controlled generation)
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
    num_inference_steps=50,
    height=480, width=832, num_frames=81,
    vace_video=control_video,
    seed=1, tiled=True
)
save_video(video, "video1.mp4", fps=15, quality=5)

# 2) Reference image -> Video (subject-driven generation)
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
    num_inference_steps=50,
    height=480, width=832, num_frames=81,
    vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, "video2.mp4", fps=15, quality=5)

# 3) Depth video + Reference image -> Video (both controls combined)
video = pipe(
    prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
    negative_prompt="色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
    num_inference_steps=50,
    height=480, width=832, num_frames=81,
    vace_video=control_video,
    vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
    seed=1, tiled=True
)
save_video(video, "video3.mp4", fps=15, quality=5)