From 020560d2b519bb9131ec44816604c772a192b153 Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 26 Feb 2025 10:05:51 +0800
Subject: [PATCH] Fix num_frames in i2v (#339)

* Fix num_frames in i2v

* Remove print in flash_attention
---
 diffsynth/models/wan_video_dit.py | 2 --
 diffsynth/pipelines/wan_video.py  | 8 ++++----
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py
index 7c31b55..395642a 100644
--- a/diffsynth/models/wan_video_dit.py
+++ b/diffsynth/models/wan_video_dit.py
@@ -112,7 +112,6 @@ def flash_attention(
             causal=causal,
             deterministic=deterministic)[0].unflatten(0, (b, lq))
     elif FLASH_ATTN_2_AVAILABLE:
-        print(q_lens, lq, k_lens, lk, causal, window_size)
         x = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
@@ -128,7 +127,6 @@ def flash_attention(
             causal=causal,
             window_size=window_size,
             deterministic=deterministic).unflatten(0, (b, lq))
-        print(x.shape)
     else:
         q = q.unsqueeze(0).transpose(1, 2).to(dtype)
         k = k.unsqueeze(0).transpose(1, 2).to(dtype)
diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index 8d50bad..f43d559 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -150,16 +150,16 @@ class WanVideoPipeline(BasePipeline):
 
         return {"context": prompt_emb}
 
-    def encode_image(self, image, height, width):
+    def encode_image(self, image, num_frames, height, width):
         with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
             image = self.preprocess_image(image.resize((width, height))).to(self.device)
             clip_context = self.image_encoder.encode_image([image])
-            msk = torch.ones(1, 81, height//8, width//8, device=self.device)
+            msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
             msk[:, 1:] = 0
             msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
             msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
             msk = msk.transpose(1, 2)[0]
-            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, 80, height, width).to(image.device)], dim=1)], device=self.device)[0]
+            y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
             y = torch.concat([msk, y])
             return {"clip_fea": clip_context, "y": [y]}
 
@@ -234,7 +234,7 @@ class WanVideoPipeline(BasePipeline):
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, height, width)
+            image_emb = self.encode_image(input_image, num_frames, height, width)
         else:
             image_emb = {}
 