From 63b2c51e1109828c065558638199e619811aec70 Mon Sep 17 00:00:00 2001 From: longredzhong Date: Fri, 4 Jul 2025 10:22:34 +0800 Subject: [PATCH] fix wan vace load mask video --- diffsynth/pipelines/wan_video_new.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 14e564c..9f52ddc 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -858,36 +858,36 @@ class WanVideoUnit_SpeedControl(PipelineUnit): class WanVideoUnit_VACE(PipelineUnit): def __init__(self): super().__init__( - input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), + input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae",) ) def process( self, pipe: WanVideoPipeline, - vace_video, vace_mask, vace_reference_image, vace_scale, + vace_video, vace_video_mask, vace_reference_image, vace_scale, height, width, num_frames, tiled, tile_size, tile_stride ): - if vace_video is not None or vace_mask is not None or vace_reference_image is not None: + if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None: pipe.load_models_to_device(["vae"]) if vace_video is None: vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device) else: vace_video = pipe.preprocess_video(vace_video) - if vace_mask is None: - vace_mask = torch.ones_like(vace_video) + if vace_video_mask is None: + vace_video_mask = torch.ones_like(vace_video) else: - vace_mask = pipe.preprocess_video(vace_mask) + vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1) - inactive = vace_video * (1 - vace_mask) + 0 * vace_mask - reactive = vace_video * vace_mask + 0 * (1 - vace_mask) + inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask + reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask) inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) vace_video_latents = torch.concat((inactive, reactive), dim=1) - vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) + vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8) vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact') if vace_reference_image is None: