mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -60,6 +60,7 @@ from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_motion_controller import WanMotionControllerModel
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
@@ -125,6 +126,7 @@ model_loader_configs = [
|
||||
(None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"),
|
||||
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
|
||||
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),
|
||||
(None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"),
|
||||
(None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"),
|
||||
|
||||
@@ -451,6 +451,7 @@ class WanModelStateDictConverter:
|
||||
return state_dict_, config
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
|
||||
77
diffsynth/models/wan_video_vace.py
Normal file
77
diffsynth/models/wan_video_vace.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock
|
||||
|
||||
|
||||
class VaceWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = torch.nn.Linear(self.dim, self.dim)
|
||||
self.after_proj = torch.nn.Linear(self.dim, self.dim)
|
||||
|
||||
def forward(self, c, x, context, t_mod, freqs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
c = super().forward(c, context, t_mod, freqs)
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class VaceWanModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
||||
vace_in_dim=96,
|
||||
patch_size=(1, 2, 2),
|
||||
has_image_input=False,
|
||||
dim=1536,
|
||||
num_heads=12,
|
||||
ffn_dim=8960,
|
||||
eps=1e-6,
|
||||
):
|
||||
super().__init__()
|
||||
self.vace_layers = vace_layers
|
||||
self.vace_in_dim = vace_in_dim
|
||||
self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
|
||||
|
||||
# vace blocks
|
||||
self.vace_blocks = torch.nn.ModuleList([
|
||||
VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i)
|
||||
for i in self.vace_layers
|
||||
])
|
||||
|
||||
# vace patch embeddings
|
||||
self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x, vace_context, context, t_mod, freqs):
|
||||
c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
|
||||
c = [u.flatten(2).transpose(1, 2) for u in c]
|
||||
c = torch.cat([
|
||||
torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))],
|
||||
dim=1) for u in c
|
||||
])
|
||||
|
||||
for block in self.vace_blocks:
|
||||
c = block(c, x, context, t_mod, freqs)
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return VaceWanModelDictConverter()
|
||||
|
||||
|
||||
class VaceWanModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
|
||||
return state_dict_
|
||||
@@ -4,6 +4,7 @@ from ..models.wan_video_dit import WanModel
|
||||
from ..models.wan_video_text_encoder import WanTextEncoder
|
||||
from ..models.wan_video_vae import WanVideoVAE
|
||||
from ..models.wan_video_image_encoder import WanImageEncoder
|
||||
from ..models.wan_video_vace import VaceWanModel
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import WanPrompter
|
||||
@@ -33,7 +34,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.dit: WanModel = None
|
||||
self.vae: WanVideoVAE = None
|
||||
self.motion_controller: WanMotionControllerModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller']
|
||||
self.vace: VaceWanModel = None
|
||||
self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace']
|
||||
self.height_division_factor = 16
|
||||
self.width_division_factor = 16
|
||||
self.use_unified_sequence_parallel = False
|
||||
@@ -153,6 +155,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
self.vae = model_manager.fetch_model("wan_video_vae")
|
||||
self.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
|
||||
self.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
|
||||
self.vace = model_manager.fetch_model("wan_video_vace")
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -253,6 +256,57 @@ class WanVideoPipeline(BasePipeline):
|
||||
def prepare_motion_bucket_id(self, motion_bucket_id):
|
||||
motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device)
|
||||
return {"motion_bucket_id": motion_bucket_id}
|
||||
|
||||
|
||||
def prepare_vace_kwargs(
|
||||
self,
|
||||
latents,
|
||||
vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0,
|
||||
height=480, width=832, num_frames=81,
|
||||
seed=None, rand_device="cpu",
|
||||
tiled=True, tile_size=(34, 34), tile_stride=(18, 16)
|
||||
):
|
||||
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
|
||||
self.load_models_to_device(["vae"])
|
||||
if vace_video is None:
|
||||
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device)
|
||||
else:
|
||||
vace_video = self.preprocess_images(vace_video)
|
||||
vace_video = torch.stack(vace_video, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
if vace_mask is None:
|
||||
vace_mask = torch.ones_like(vace_video)
|
||||
else:
|
||||
vace_mask = self.preprocess_images(vace_mask)
|
||||
vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
||||
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
||||
inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||
|
||||
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
||||
|
||||
if vace_reference_image is None:
|
||||
pass
|
||||
else:
|
||||
vace_reference_image = self.preprocess_images([vace_reference_image])
|
||||
vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device)
|
||||
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
|
||||
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
|
||||
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2)
|
||||
|
||||
noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32)
|
||||
noise = noise.to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = torch.concat((noise, latents), dim=2)
|
||||
|
||||
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)
|
||||
return latents, {"vace_context": vace_context, "vace_scale": vace_scale}
|
||||
else:
|
||||
return latents, {"vace_context": None, "vace_scale": vace_scale}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -264,6 +318,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
end_image=None,
|
||||
input_video=None,
|
||||
control_video=None,
|
||||
vace_video=None,
|
||||
vace_video_mask=None,
|
||||
vace_reference_image=None,
|
||||
vace_scale=1.0,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device="cpu",
|
||||
@@ -333,6 +391,12 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# VACE
|
||||
latents, vace_kwargs = self.prepare_vace_kwargs(
|
||||
latents, vace_video, vace_video_mask, vace_reference_image, vace_scale,
|
||||
height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs
|
||||
)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
@@ -341,23 +405,23 @@ class WanVideoPipeline(BasePipeline):
|
||||
usp_kwargs = self.prepare_unified_sequence_parallel()
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit", "motion_controller"])
|
||||
self.load_models_to_device(["dit", "motion_controller", "vace"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_posi, **image_emb, **extra_input,
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs
|
||||
**tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = model_fn_wan_video(
|
||||
self.dit, motion_controller=self.motion_controller,
|
||||
self.dit, motion_controller=self.motion_controller, vace=self.vace,
|
||||
x=latents, timestep=timestep,
|
||||
**prompt_emb_nega, **image_emb, **extra_input,
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs
|
||||
**tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
@@ -365,6 +429,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
if vace_reference_image is not None:
|
||||
latents = latents[:, :, 1:]
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
@@ -432,11 +499,14 @@ class TeaCache:
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
motion_controller: WanMotionControllerModel = None,
|
||||
vace: VaceWanModel = None,
|
||||
x: torch.Tensor = None,
|
||||
timestep: torch.Tensor = None,
|
||||
context: torch.Tensor = None,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
vace_context = None,
|
||||
vace_scale = 1.0,
|
||||
tea_cache: TeaCache = None,
|
||||
use_unified_sequence_parallel: bool = False,
|
||||
motion_bucket_id: Optional[torch.Tensor] = None,
|
||||
@@ -472,6 +542,9 @@ def model_fn_wan_video(
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if vace_context is not None:
|
||||
vace_hints = vace(x, vace_context, context, t_mod, freqs)
|
||||
|
||||
# blocks
|
||||
if use_unified_sequence_parallel:
|
||||
@@ -480,8 +553,10 @@ def model_fn_wan_video(
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
for block in dit.blocks:
|
||||
for block_id, block in enumerate(dit.blocks):
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if vace_context is not None and block_id in vace.vace_layers_mapping:
|
||||
x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
|
||||
@@ -26,28 +26,30 @@ pip install -e .
|
||||
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)|
|
||||
|PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)|
|
||||
|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)|
|
||||
|
||||
Base model features
|
||||
|
||||
||Text-to-video|Image-to-video|End frame|Control|
|
||||
|-|-|-|-|-|
|
||||
|1.3B text-to-video|✅||||
|
||||
|14B text-to-video|✅||||
|
||||
|14B image-to-video 480P||✅|||
|
||||
|14B image-to-video 720P||✅|||
|
||||
|1.3B InP||✅|✅||
|
||||
|14B InP||✅|✅||
|
||||
|1.3B Control||||✅|
|
||||
|14B Control||||✅|
|
||||
||Text-to-video|Image-to-video|End frame|Control|Reference image|
|
||||
|-|-|-|-|-|-|
|
||||
|1.3B text-to-video|✅|||||
|
||||
|14B text-to-video|✅|||||
|
||||
|14B image-to-video 480P||✅||||
|
||||
|14B image-to-video 720P||✅||||
|
||||
|1.3B InP||✅|✅|||
|
||||
|14B InP||✅|✅|||
|
||||
|1.3B Control||||✅||
|
||||
|14B Control||||✅||
|
||||
|1.3B VACE||||✅|✅|
|
||||
|
||||
Adapter model compatibility
|
||||
|
||||
||1.3B text-to-video|1.3B InP|
|
||||
|-|-|-|
|
||||
|1.3B aesthetics LoRA|✅||
|
||||
|1.3B Highres-fix LoRA|✅||
|
||||
|1.3B ExVideo LoRA|✅||
|
||||
|1.3B Speed Control adapter|✅|✅|
|
||||
||1.3B text-to-video|1.3B InP|1.3B VACE|
|
||||
|-|-|-|-|
|
||||
|1.3B aesthetics LoRA|✅||✅|
|
||||
|1.3B Highres-fix LoRA|✅||✅|
|
||||
|1.3B ExVideo LoRA|✅||✅|
|
||||
|1.3B Speed Control adapter|✅|✅|✅|
|
||||
|
||||
## VRAM Usage
|
||||
|
||||
|
||||
63
examples/wanvideo/wan_1.3b_vace.py
Normal file
63
examples/wanvideo/wan_1.3b_vace.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download, dataset_snapshot_download
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cuda")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors",
|
||||
"models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Download example video
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"]
|
||||
)
|
||||
|
||||
# Depth video -> Video
|
||||
control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
height=480, width=832, num_frames=81,
|
||||
vace_video=control_video,
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
# Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
height=480, width=832, num_frames=81,
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video2.mp4", fps=15, quality=5)
|
||||
|
||||
# Depth video + Reference image -> Video
|
||||
video = pipe(
|
||||
prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
height=480, width=832, num_frames=81,
|
||||
vace_video=control_video,
|
||||
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)),
|
||||
seed=1, tiled=True
|
||||
)
|
||||
save_video(video, "video3.mp4", fps=15, quality=5)
|
||||
Reference in New Issue
Block a user