mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
fix wan vace load mask video
This commit is contained in:
@@ -858,36 +858,36 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
|
||||
class WanVideoUnit_VACE(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
|
||||
input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(
|
||||
self,
|
||||
pipe: WanVideoPipeline,
|
||||
vace_video, vace_mask, vace_reference_image, vace_scale,
|
||||
vace_video, vace_video_mask, vace_reference_image, vace_scale,
|
||||
height, width, num_frames,
|
||||
tiled, tile_size, tile_stride
|
||||
):
|
||||
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
|
||||
if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:
|
||||
pipe.load_models_to_device(["vae"])
|
||||
if vace_video is None:
|
||||
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
else:
|
||||
vace_video = pipe.preprocess_video(vace_video)
|
||||
|
||||
if vace_mask is None:
|
||||
vace_mask = torch.ones_like(vace_video)
|
||||
if vace_video_mask is None:
|
||||
vace_video_mask = torch.ones_like(vace_video)
|
||||
else:
|
||||
vace_mask = pipe.preprocess_video(vace_mask)
|
||||
vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)
|
||||
|
||||
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
||||
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
||||
inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
|
||||
reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
|
||||
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||
|
||||
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||
vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
||||
|
||||
if vace_reference_image is None:
|
||||
|
||||
Reference in New Issue
Block a user