mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-23 17:38:10 +00:00
Merge pull request #661 from longredzhong/main
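Fix loading of the Wan VACE mask video: rename the `vace_mask` input to `vace_video_mask` and preprocess the mask into the [0, 1] range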
This commit is contained in:
@@ -858,36 +858,36 @@ class WanVideoUnit_SpeedControl(PipelineUnit):
|
|||||||
class WanVideoUnit_VACE(PipelineUnit):
|
class WanVideoUnit_VACE(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("vace_video", "vace_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
|
input_params=("vace_video", "vace_video_mask", "vace_reference_image", "vace_scale", "height", "width", "num_frames", "tiled", "tile_size", "tile_stride"),
|
||||||
onload_model_names=("vae",)
|
onload_model_names=("vae",)
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(
|
def process(
|
||||||
self,
|
self,
|
||||||
pipe: WanVideoPipeline,
|
pipe: WanVideoPipeline,
|
||||||
vace_video, vace_mask, vace_reference_image, vace_scale,
|
vace_video, vace_video_mask, vace_reference_image, vace_scale,
|
||||||
height, width, num_frames,
|
height, width, num_frames,
|
||||||
tiled, tile_size, tile_stride
|
tiled, tile_size, tile_stride
|
||||||
):
|
):
|
||||||
if vace_video is not None or vace_mask is not None or vace_reference_image is not None:
|
if vace_video is not None or vace_video_mask is not None or vace_reference_image is not None:
|
||||||
pipe.load_models_to_device(["vae"])
|
pipe.load_models_to_device(["vae"])
|
||||||
if vace_video is None:
|
if vace_video is None:
|
||||||
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
vace_video = torch.zeros((1, 3, num_frames, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
else:
|
else:
|
||||||
vace_video = pipe.preprocess_video(vace_video)
|
vace_video = pipe.preprocess_video(vace_video)
|
||||||
|
|
||||||
if vace_mask is None:
|
if vace_video_mask is None:
|
||||||
vace_mask = torch.ones_like(vace_video)
|
vace_video_mask = torch.ones_like(vace_video)
|
||||||
else:
|
else:
|
||||||
vace_mask = pipe.preprocess_video(vace_mask)
|
vace_video_mask = pipe.preprocess_video(vace_video_mask, min_value=0, max_value=1)
|
||||||
|
|
||||||
inactive = vace_video * (1 - vace_mask) + 0 * vace_mask
|
inactive = vace_video * (1 - vace_video_mask) + 0 * vace_video_mask
|
||||||
reactive = vace_video * vace_mask + 0 * (1 - vace_mask)
|
reactive = vace_video * vace_video_mask + 0 * (1 - vace_video_mask)
|
||||||
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
inactive = pipe.vae.encode(inactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
reactive = pipe.vae.encode(reactive, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
vace_video_latents = torch.concat((inactive, reactive), dim=1)
|
||||||
|
|
||||||
vace_mask_latents = rearrange(vace_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
vace_mask_latents = rearrange(vace_video_mask[0,0], "T (H P) (W Q) -> 1 (P Q) T H W", P=8, Q=8)
|
||||||
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
vace_mask_latents = torch.nn.functional.interpolate(vace_mask_latents, size=((vace_mask_latents.shape[2] + 3) // 4, vace_mask_latents.shape[3], vace_mask_latents.shape[4]), mode='nearest-exact')
|
||||||
|
|
||||||
if vace_reference_image is None:
|
if vace_reference_image is None:
|
||||||
|
|||||||
Reference in New Issue
Block a user