fix wan vace load mask video

This commit is contained in:
longredzhong
2025-07-04 10:22:34 +08:00
committed by GitHub
parent 22d28665fe
commit 63b2c51e11

View File

@@ -858,36 +858,36 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
class WanVideoUnit_VACE(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(
self,
pipe: WanVideoPipeline,
vace_video, vace_mask, vace_reference_image, vace_scale,
vace_video, vace_video_mask, vace_reference_image, vace_scale,
height, width, num_frames,
tiled, tile_size, tile_stride
):
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:
pipe.load_models_to_device(["vae"])
if vace_video is None:
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
else:
vace_video = pipe.preprocess_video(vace_video)
if vace_mask is None:
vace_mask = torch.ones_like(vace_video)
if vace_video_mask is None:
vace_video_mask = torch.ones_like(vace_video)
else:
vace_mask = pipe.preprocess_video(vace_mask)
vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_video_latents = torch.concat((inactive, reactive), dim=1)
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
if vace_reference_image is None: