From b7316281126a9130674ffe64af570aa167550520 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 15 Apr 2025 17:52:25 +0800 Subject: [PATCH] vace --- diffsynth/configs/model_config.py | 2 + diffsynth/models/wan_video_dit.py | 1 + diffsynth/models/wan_video_vace.py | 77 ++++++++++++++++++++++++++ diffsynth/pipelines/wan_video.py | 89 +++++++++++++++++++++++++++--- examples/wanvideo/README.md | 34 ++++++------ examples/wanvideo/wan_1.3b_vace.py | 63 +++++++++++++++++++++ 6 files changed, 243 insertions(+), 23 deletions(-) create mode 100644 diffsynth/models/wan_video_vace.py create mode 100644 examples/wanvideo/wan_1.3b_vace.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index 8fdb50a..de09d16 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -60,6 +60,7 @@ from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_image_encoder import WanImageEncoder from ..models.wan_video_vae import WanVideoVAE from ..models.wan_video_motion_controller import WanMotionControllerModel +from ..models.wan_video_vace import VaceWanModel model_loader_configs = [ @@ -125,6 +126,7 @@ model_loader_configs = [ (None, "6bfcfb3b342cb286ce886889d519a77e", ["wan_video_dit"], [WanModel], "civitai"), (None, "349723183fc063b2bfc10bb2835cf677", ["wan_video_dit"], [WanModel], "civitai"), (None, "efa44cddf936c70abd0ea28b6cbe946c", ["wan_video_dit"], [WanModel], "civitai"), + (None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"), (None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"), (None, "9c8818c2cbea55eca56c7b447df170da", ["wan_video_text_encoder"], [WanTextEncoder], "civitai"), (None, "5941c53e207d62f20f9025686193c40b", ["wan_video_image_encoder"], [WanImageEncoder], "civitai"), diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index c999596..93d108a 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -451,6 +451,7 @@ class WanModelStateDictConverter: return state_dict_, config def from_civitai(self, state_dict): + state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")} if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814": config = { "has_image_input": False, diff --git a/diffsynth/models/wan_video_vace.py b/diffsynth/models/wan_video_vace.py new file mode 100644 index 0000000..0c9c2d7 --- /dev/null +++ b/diffsynth/models/wan_video_vace.py @@ -0,0 +1,77 @@ +import torch +from .wan_video_dit import DiTBlock + + +class VaceWanAttentionBlock(DiTBlock): + def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0): + super().__init__(has_image_input, dim, num_heads, ffn_dim, eps=eps) + self.block_id = block_id + if block_id == 0: + self.before_proj = torch.nn.Linear(self.dim, self.dim) + self.after_proj = torch.nn.Linear(self.dim, self.dim) + + def forward(self, c, x, context, t_mod, freqs): + if self.block_id == 0: + c = self.before_proj(c) + x + all_c = [] + else: + all_c = list(torch.unbind(c)) + c = all_c.pop(-1) + c = super().forward(c, context, t_mod, freqs) + c_skip = self.after_proj(c) + all_c += [c_skip, c] + c = torch.stack(all_c) + return c + + +class VaceWanModel(torch.nn.Module): + def __init__( + self, + vace_layers=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28), + vace_in_dim=96, + patch_size=(1, 2, 2), + has_image_input=False, + dim=1536, + 
num_heads=12, + ffn_dim=8960, + eps=1e-6, + ): + super().__init__() + self.vace_layers = vace_layers + self.vace_in_dim = vace_in_dim + self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)} + + # vace blocks + self.vace_blocks = torch.nn.ModuleList([ + VaceWanAttentionBlock(has_image_input, dim, num_heads, ffn_dim, eps, block_id=i) + for i in self.vace_layers + ]) + + # vace patch embeddings + self.vace_patch_embedding = torch.nn.Conv3d(vace_in_dim, dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x, vace_context, context, t_mod, freqs): + c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context] + c = [u.flatten(2).transpose(1, 2) for u in c] + c = torch.cat([ + torch.cat([u, u.new_zeros(1, x.shape[1] - u.size(1), u.size(2))], + dim=1) for u in c + ]) + + for block in self.vace_blocks: + c = block(c, x, context, t_mod, freqs) + hints = torch.unbind(c)[:-1] + return hints + + @staticmethod + def state_dict_converter(): + return VaceWanModelDictConverter() + + +class VaceWanModelDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")} + return state_dict_ diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 9a80f78..780d74b 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -4,6 +4,7 @@ from ..models.wan_video_dit import WanModel from ..models.wan_video_text_encoder import WanTextEncoder from ..models.wan_video_vae import WanVideoVAE from ..models.wan_video_image_encoder import WanImageEncoder +from ..models.wan_video_vace import VaceWanModel from ..schedulers.flow_match import FlowMatchScheduler from .base import BasePipeline from ..prompters import WanPrompter @@ -33,7 +34,8 @@ class WanVideoPipeline(BasePipeline): self.dit: WanModel = None self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None - self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller'] + self.vace: VaceWanModel = None + self.model_names = ['text_encoder', 'dit', 'vae', 'image_encoder', 'motion_controller', 'vace'] self.height_division_factor = 16 self.width_division_factor = 16 self.use_unified_sequence_parallel = False @@ -153,6 +155,7 @@ class WanVideoPipeline(BasePipeline): self.vae = model_manager.fetch_model("wan_video_vae") self.image_encoder = model_manager.fetch_model("wan_video_image_encoder") self.motion_controller = model_manager.fetch_model("wan_video_motion_controller") + self.vace = model_manager.fetch_model("wan_video_vace") @staticmethod @@ -253,6 +256,57 @@ class WanVideoPipeline(BasePipeline): def prepare_motion_bucket_id(self, motion_bucket_id): motion_bucket_id = torch.Tensor((motion_bucket_id,)).to(dtype=self.torch_dtype, device=self.device) return {"motion_bucket_id": motion_bucket_id} + + + def prepare_vace_kwargs( + self, + latents, + vace_video=None, vace_mask=None, vace_reference_image=None, vace_scale=1.0, + height=480, width=832, num_frames=81, + seed=None, rand_device="cpu", + tiled=True, tile_size=(34, 34), tile_stride=(18, 16) + ): + if vace_video is not None or vace_mask is not None or vace_reference_image is not None: + self.load_models_to_device(["vae"]) + if vace_video is None: + vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=self.torch_dtype, device=self.device) + else: + vace_video = self.preprocess_images(vace_video) + vace_video = torch.stack(vace_video, 
dim=2).to(dtype=self.torch_dtype, device=self.device) + + if vace_mask is None: + vace_mask = torch.ones_like(vace_video) + else: + vace_mask = self.preprocess_images(vace_mask) + vace_mask = torch.stack(vace_mask, dim=2).to(dtype=self.torch_dtype, device=self.device) + + inactive = vace_video * (1 - vace_mask) + 0 * vace_mask + reactive = vace_video * vace_mask + 0 * (1 - vace_mask) + inactive = self.encode_video(inactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device) + reactive = self.encode_video(reactive, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device) + vace_video_latents = torch.concat((inactive, reactive), dim=1) + + vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') + + if vace_reference_image is None: + pass + else: + vace_reference_image = self.preprocess_images([vace_reference_image]) + vace_reference_image = torch.stack(vace_reference_image, dim=2).to(dtype=self.torch_dtype, device=self.device) + vace_reference_latents = self.encode_video(vace_reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=self.torch_dtype, device=self.device) + vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) + + noise = self.generate_noise((1, 16, 1, latents.shape[3], latents.shape[4]), seed=seed, device=rand_device, dtype=torch.float32) + noise = noise.to(dtype=self.torch_dtype, device=self.device) + latents = torch.concat((noise, latents), dim=2) + + vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) + return latents, {"vace_context": vace_context, "vace_scale": vace_scale} + else: + return latents, {"vace_context": None, "vace_scale": vace_scale} @torch.no_grad() @@ -264,6 +318,10 @@ class WanVideoPipeline(BasePipeline): end_image=None, input_video=None, control_video=None, + vace_video=None, + vace_video_mask=None, + vace_reference_image=None, + vace_scale=1.0, denoising_strength=1.0, seed=None, rand_device="cpu", @@ -333,6 +391,12 @@ class WanVideoPipeline(BasePipeline): # Extra input extra_input = self.prepare_extra_input(latents) + # VACE + latents, vace_kwargs = self.prepare_vace_kwargs( + latents, vace_video, vace_video_mask, vace_reference_image, vace_scale, + height=height, width=width, num_frames=num_frames, seed=seed, rand_device=rand_device, **tiler_kwargs + ) + # TeaCache tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} @@ -341,23 +405,23 @@ class WanVideoPipeline(BasePipeline): usp_kwargs = self.prepare_unified_sequence_parallel() # Denoise - self.load_models_to_device(["dit", "motion_controller"]) + self.load_models_to_device(["dit", "motion_controller", "vace"]) for progress_id, timestep in 
enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) # Inference noise_pred_posi = model_fn_wan_video( - self.dit, motion_controller=self.motion_controller, + self.dit, motion_controller=self.motion_controller, vace=self.vace, x=latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, - **tea_cache_posi, **usp_kwargs, **motion_kwargs + **tea_cache_posi, **usp_kwargs, **motion_kwargs, **vace_kwargs, ) if cfg_scale != 1.0: noise_pred_nega = model_fn_wan_video( - self.dit, motion_controller=self.motion_controller, + self.dit, motion_controller=self.motion_controller, vace=self.vace, x=latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, - **tea_cache_nega, **usp_kwargs, **motion_kwargs + **tea_cache_nega, **usp_kwargs, **motion_kwargs, **vace_kwargs, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -365,6 +429,9 @@ class WanVideoPipeline(BasePipeline): # Scheduler latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) + + if vace_reference_image is not None: + latents = latents[:, :, 1:] # Decode self.load_models_to_device(['vae']) @@ -432,11 +499,14 @@ class TeaCache: def model_fn_wan_video( dit: WanModel, motion_controller: WanMotionControllerModel = None, + vace: VaceWanModel = None, x: torch.Tensor = None, timestep: torch.Tensor = None, context: torch.Tensor = None, clip_feature: Optional[torch.Tensor] = None, y: Optional[torch.Tensor] = None, + vace_context = None, + vace_scale = 1.0, tea_cache: TeaCache = None, use_unified_sequence_parallel: bool = False, motion_bucket_id: Optional[torch.Tensor] = None, @@ -472,6 +542,9 @@ def model_fn_wan_video( tea_cache_update = tea_cache.check(dit, x, t_mod) else: tea_cache_update = False + + if vace_context is not None: + vace_hints = vace(x, vace_context, context, t_mod, freqs) # blocks if use_unified_sequence_parallel: @@ -480,8 +553,10 @@ def model_fn_wan_video( if tea_cache_update: x = tea_cache.update(x) else: - for block in dit.blocks: + for block_id, block in enumerate(dit.blocks): x = block(x, context, t_mod, freqs) + if vace_context is not None and block_id in vace.vace_layers_mapping: + x = x + vace_hints[vace.vace_layers_mapping[block_id]] * vace_scale if tea_cache is not None: tea_cache.store(x) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index d3a22bf..3b03f86 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -26,28 +26,30 @@ pip install -e . 
|PAI Team|14B InP|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|[wan_fun_InP.py](./wan_fun_InP.py)| |PAI Team|1.3B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|[wan_fun_control.py](./wan_fun_control.py)| |PAI Team|14B Control|[Link](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|[wan_fun_control.py](./wan_fun_control.py)| +|IIC Team|1.3B VACE|[Link](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|[wan_1.3b_vace.py](./wan_1.3b_vace.py)| Base model features -||Text-to-video|Image-to-video|End frame|Control| -|-|-|-|-|-| -|1.3B text-to-video|✅|||| -|14B text-to-video|✅|||| -|14B image-to-video 480P||✅||| -|14B image-to-video 720P||✅||| -|1.3B InP||✅|✅|| -|14B InP||✅|✅|| -|1.3B Control||||✅| -|14B Control||||✅| +||Text-to-video|Image-to-video|End frame|Control|Reference image| +|-|-|-|-|-|-| +|1.3B text-to-video|✅||||| +|14B text-to-video|✅||||| +|14B image-to-video 480P||✅|||| +|14B image-to-video 720P||✅|||| +|1.3B InP||✅|✅||| +|14B InP||✅|✅||| +|1.3B Control||||✅|| +|14B Control||||✅|| +|1.3B VACE||||✅|✅| Adapter model compatibility -||1.3B text-to-video|1.3B InP| -|-|-|-| -|1.3B aesthetics LoRA|✅|| -|1.3B Highres-fix LoRA|✅|| -|1.3B ExVideo LoRA|✅|| -|1.3B Speed Control adapter|✅|✅| +||1.3B text-to-video|1.3B InP|1.3B VACE| +|-|-|-|-| +|1.3B aesthetics LoRA|✅||✅| +|1.3B Highres-fix LoRA|✅||✅| +|1.3B ExVideo LoRA|✅||✅| +|1.3B Speed Control adapter|✅|✅|✅| ## VRAM Usage diff --git a/examples/wanvideo/wan_1.3b_vace.py b/examples/wanvideo/wan_1.3b_vace.py new file mode 100644 index 0000000..01d1370 --- /dev/null +++ b/examples/wanvideo/wan_1.3b_vace.py @@ -0,0 +1,63 @@ +import torch +from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData +from modelscope import snapshot_download, dataset_snapshot_download +from PIL import Image + + +# Download models +snapshot_download("iic/VACE-Wan2.1-1.3B-Preview", local_dir="models/iic/VACE-Wan2.1-1.3B-Preview") + +# Load models +model_manager = ModelManager(device="cuda") +model_manager.load_models( + [ + "models/iic/VACE-Wan2.1-1.3B-Preview/diffusion_pytorch_model.safetensors", + "models/iic/VACE-Wan2.1-1.3B-Preview/models_t5_umt5-xxl-enc-bf16.pth", + "models/iic/VACE-Wan2.1-1.3B-Preview/Wan2.1_VAE.pth", + ], + torch_dtype=torch.bfloat16, +) +pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") +pipe.enable_vram_management(num_persistent_param_in_dit=None) + +# Download example video +dataset_snapshot_download( + dataset_id="DiffSynth-Studio/examples_in_diffsynth", + local_dir="./", + allow_file_pattern=["data/examples/wan/depth_video.mp4", "data/examples/wan/cat_fightning.jpg"] +) + +# Depth video -> Video +control_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832) +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + height=480, width=832, num_frames=81, + vace_video=control_video, + seed=1, tiled=True +) +save_video(video, "video1.mp4", fps=15, quality=5) + +# Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + height=480, width=832, num_frames=81, + 
vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video2.mp4", fps=15, quality=5) + +# Depth video + Reference image -> Video +video = pipe( + prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", + negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + num_inference_steps=50, + height=480, width=832, num_frames=81, + vace_video=control_video, + vace_reference_image=Image.open("data/examples/wan/cat_fightning.jpg").resize((832, 480)), + seed=1, tiled=True +) +save_video(video, "video3.mp4", fps=15, quality=5)
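
Note on the VACE conditioning added by this patch: the `vace_context` tensor assembled in `prepare_vace_kwargs` is 96 channels wide, matching `vace_in_dim` — 16 latent channels each for the inactive stream (`vace_video * (1 - vace_mask)`) and the reactive stream (`vace_video * vace_mask`), plus 64 channels from the 8x8 spatial rearrangement of the mask. During denoising, `VaceWanModel` turns this context into per-layer hints that are added to the output of each DiT block listed in `vace_layers`, scaled by `vace_scale`. The example script above exercises `vace_video` and `vace_reference_image` but not `vace_video_mask`; the sketch below shows how a masked edit could be wired through the same pipeline call. It is a minimal sketch, not part of the patch: "data/examples/wan/mask_video.mp4" is a hypothetical placeholder, and it assumes a mask video that is white where content should be regenerated and black where the input video should be kept.

# Hypothetical masked-editing call, continuing from wan_1.3b_vace.py above
# (reuses the already-constructed `pipe`). The mask video path is a placeholder;
# supply your own mask clip with the same resolution and frame count as the input.
from diffsynth import VideoData, save_video

input_video = VideoData("data/examples/wan/depth_video.mp4", height=480, width=832)
mask_video = VideoData("data/examples/wan/mask_video.mp4", height=480, width=832)  # hypothetical path

video = pipe(
    prompt="Two cute orange cats put on boxing gloves and fight on a boxing ring.",  # English rendering of the prompt used above
    num_inference_steps=50,
    height=480, width=832, num_frames=81,
    vace_video=input_video,          # frames to edit
    vace_video_mask=mask_video,      # white = regenerate, black = keep (assumed convention)
    vace_scale=1.0,                  # strength of the VACE hints added to the DiT blocks
    seed=1, tiled=True,
)
save_video(video, "video_masked_edit.mp4", fps=15, quality=5)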