From 76335e0fe5f6a1a9a3cbea1f11d565626577241d Mon Sep 17 00:00:00 2001 From: lzws <2538048363@qq.com> Date: Mon, 22 Sep 2025 02:14:20 +0800 Subject: [PATCH] uodate wan2.2-VACE-Fun --- diffsynth/pipelines/wan_video_new.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/diffsynth/pipelines/wan_video_new.py b/diffsynth/pipelines/wan_video_new.py index 660a38e..2649bae 100644 --- a/diffsynth/pipelines/wan_video_new.py +++ b/diffsynth/pipelines/wan_video_new.py @@ -44,8 +44,9 @@ class WanVideoPipeline(BasePipeline): self.vae: WanVideoVAE = None self.motion_controller: WanMotionControllerModel = None self.vace: VaceWanModel = None + self.vace2: VaceWanModel = None self.in_iteration_models = ("dit", "motion_controller", "vace") - self.in_iteration_models_2 = ("dit2", "motion_controller", "vace") + self.in_iteration_models_2 = ("dit2", "motion_controller", "vace2") self.unit_runner = PipelineUnitRunner() self.units = [ WanVideoUnit_ShapeChecker(), @@ -359,6 +360,10 @@ class WanVideoPipeline(BasePipeline): pipe.image_encoder = model_manager.fetch_model("wan_video_image_encoder") pipe.motion_controller = model_manager.fetch_model("wan_video_motion_controller") pipe.vace = model_manager.fetch_model("wan_video_vace") + if isinstance(vace, list): + pipe.vace, pipe.vace2 = vace + else: + pipe.vace = vace pipe.audio_encoder = model_manager.fetch_model("wans2v_audio_encoder") # Size division factor @@ -481,6 +486,7 @@ class WanVideoPipeline(BasePipeline): if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2: self.load_models_to_device(self.in_iteration_models_2) models["dit"] = self.dit2 + models["vace"] = self.vace2 # Timestep timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) @@ -534,11 +540,12 @@ class WanVideoUnit_NoiseInitializer(PipelineUnit): def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): length = (num_frames - 1) // 4 + 1 if vace_reference_image is not None: - length += 1 + f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 + length += f shape = (1, pipe.vae.model.z_dim, length, height // pipe.vae.upsampling_factor, width // pipe.vae.upsampling_factor) noise = pipe.generate_noise(shape, seed=seed, rand_device=rand_device) if vace_reference_image is not None: - noise = torch.concat((noise[:, :, -1:], noise[:, :, :-1]), dim=2) + noise = torch.concat((noise[:, :, -f:], noise[:, :, :-f]), dim=2) return {"noise": noise} @@ -849,11 +856,22 @@ class WanVideoUnit_VACE(PipelineUnit): if vace_reference_image is None: pass else: + if not isinstance(vace_reference_image,list): + vace_reference_image = [vace_reference_image] + + vace_reference_image = pipe.preprocess_video(vace_reference_image) + + bs, c, f, h, w = vace_reference_image.shape + new_vace_ref_images = [] + for j in range(f): + new_vace_ref_images.append(vace_reference_image[0, :, j:j+1]) + vace_reference_image = new_vace_ref_images + vace_reference_image = pipe.preprocess_video([vace_reference_image]) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device) - vace_reference_latents = torch.concat((vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) + vace_reference_latents = torch.concat((*vace_reference_latents, torch.zeros_like(vace_reference_latents)), dim=1) vace_video_latents = torch.concat((vace_reference_latents, vace_video_latents), dim=2) - vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :1]), vace_mask_latents), dim=2) + vace_mask_latents = torch.concat((torch.zeros_like(vace_mask_latents[:, :, :f]), vace_mask_latents), dim=2) vace_context = torch.concat((vace_video_latents, vace_mask_latents), dim=1) return {"vace_context": vace_context, "vace_scale": vace_scale}