From a98700feb24329b3b8d9defac41de844c10b9a7a Mon Sep 17 00:00:00 2001 From: "lzw478614@alibaba-inc.com" Date: Sun, 6 Apr 2025 22:55:42 +0800 Subject: [PATCH] support wan-fun-inp generating --- diffsynth/pipelines/wan_video.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index b6f2c74..6b95a69 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -163,16 +163,22 @@ class WanVideoPipeline(BasePipeline): return {"context": prompt_emb} - def encode_image(self, image, num_frames, height, width): + def encode_image(self, image, end_image, num_frames, height, width): image = self.preprocess_image(image.resize((width, height))).to(self.device) clip_context = self.image_encoder.encode_image([image]) msk = torch.ones(1, num_frames, height//8, width//8, device=self.device) msk[:, 1:] = 0 + if end_image is not None: + end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device) + vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1) + msk[:, -1:] = 1 + else: + vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) + msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1) msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8) msk = msk.transpose(1, 2)[0] - vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1) y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0] y = torch.concat([msk, y]) y = y.unsqueeze(0) @@ -212,6 +218,7 @@ class WanVideoPipeline(BasePipeline): prompt, negative_prompt="", input_image=None, + end_image=None, input_video=None, denoising_strength=1.0, seed=None, @@ -263,7 +270,7 @@ class WanVideoPipeline(BasePipeline): # Encode image if input_image is not None and self.image_encoder is not None: self.load_models_to_device(["image_encoder", "vae"]) - image_emb = self.encode_image(input_image, num_frames, height, width) + image_emb = self.encode_image(input_image, end_image, num_frames, height, width) else: image_emb = {}