fix wan vace bug (#960)

* fix wan vace bug
This commit is contained in:
lzws
2025-09-26 13:49:27 +08:00
committed by GitHub
parent 15079a6cb8
commit ed256ef8be
6 changed files with 15 additions and 7 deletions

View File

@@ -359,7 +359,7 @@ class WanVideoPipeline(BasePipeline):
pipe.vae = model_manager.fetch_model("wan_video_vae")
pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder")
pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller")
vace = model_manager.fetch_model("wan_video_vace")
vace = model_manager.fetch_model("wan_video_vace", index=2)
if isinstance(vace, list):
pipe.vace, pipe.vace2 = vace
else:
@@ -509,7 +509,8 @@ class WanVideoPipeline(BasePipeline):
# VACE (TODO: remove it)
if vace_reference_image is not None:
inputs_shared["latents"] = inputs_shared["latents"][:, :, 1:]
f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1
inputs_shared["latents"] = inputs_shared["latents"][:, :, f:]
# post-denoising, pre-decoding processing logic
for unit in self.post_units:
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -564,7 +565,9 @@ class WanVideoUnit_InputVideoEmbedder(PipelineUnit):
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.vae.encode(input_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
if vace_reference_image is not None:
vace_reference_image = pipe.preprocess_video([vace_reference_image])
if not isinstance(vace_reference_image, list):
vace_reference_image = [vace_reference_image]
vace_reference_image = pipe.preprocess_video(vace_reference_image)
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device)
input_latents = torch.concat([vace_reference_latents, input_latents], dim=2)
if pipe.scheduler.training:
@@ -866,11 +869,12 @@ class WanVideoUnit_VACE(PipelineUnit):
for j in range(f):
new_vace_ref_images.append(vace_reference_image[0, :, j:j+1])
vace_reference_image = new_vace_ref_images
vace_reference_image = pipe.preprocess_video([vace_reference_image])
vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
vace_reference_latents = torch.concat((*vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2)
vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1)
vace_reference_latents = [u.unsqueeze(0) for u in vace_reference_latents]
vace_video_latents = torch.concat((*vace_reference_latents, vace_video_latents), dim=2)
vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2)
vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1)