Fix num_frames in i2v (#339)

* Fix num_frames in i2v

* Remove print in flash_attention
This commit is contained in:
Kohaku-Blueleaf
2025-02-26 10:05:51 +08:00
committed by GitHub
parent af7d305f00
commit 020560d2b5
2 changed files with 4 additions and 6 deletions

View File

@@ -112,7 +112,6 @@ def flash_attention(
causal=causal,
deterministic=deterministic)[0].unflatten(0, (b, lq))
elif FLASH_ATTN_2_AVAILABLE:
print(q_lens, lq, k_lens, lk, causal, window_size)
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
@@ -128,7 +127,6 @@ def flash_attention(
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
print(x.shape)
else:
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
k = k.unsqueeze(0).transpose(1, 2).to(dtype)

View File

@@ -150,16 +150,16 @@ class WanVideoPipeline(BasePipeline):
return {"context": prompt_emb}
def encode_image(self, image, height, width):
def encode_image(self, image, num_frames, height, width):
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, 81, height//8, width//8, device=self.device)
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, 80, height, width).to(image.device)], dim=1)], device=self.device)[0]
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
y = torch.concat([msk, y])
return {"clip_fea": clip_context, "y": [y]}
@@ -234,7 +234,7 @@ class WanVideoPipeline(BasePipeline):
# Encode image
if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, height, width)
image_emb = self.encode_image(input_image, num_frames, height, width)
else:
image_emb = {}