mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
Fix num_frames in i2v (#339)
* Fix num_frames in i2v * Remove print in flash_attention
This commit is contained in:
@@ -112,7 +112,6 @@ def flash_attention(
|
|||||||
causal=causal,
|
causal=causal,
|
||||||
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
deterministic=deterministic)[0].unflatten(0, (b, lq))
|
||||||
elif FLASH_ATTN_2_AVAILABLE:
|
elif FLASH_ATTN_2_AVAILABLE:
|
||||||
print(q_lens, lq, k_lens, lk, causal, window_size)
|
|
||||||
x = flash_attn.flash_attn_varlen_func(
|
x = flash_attn.flash_attn_varlen_func(
|
||||||
q=q,
|
q=q,
|
||||||
k=k,
|
k=k,
|
||||||
@@ -128,7 +127,6 @@ def flash_attention(
|
|||||||
causal=causal,
|
causal=causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
deterministic=deterministic).unflatten(0, (b, lq))
|
deterministic=deterministic).unflatten(0, (b, lq))
|
||||||
print(x.shape)
|
|
||||||
else:
|
else:
|
||||||
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
|
q = q.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||||
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
|
k = k.unsqueeze(0).transpose(1, 2).to(dtype)
|
||||||
|
|||||||
@@ -150,16 +150,16 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
return {"context": prompt_emb}
|
return {"context": prompt_emb}
|
||||||
|
|
||||||
|
|
||||||
def encode_image(self, image, height, width):
|
def encode_image(self, image, num_frames, height, width):
|
||||||
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
with torch.amp.autocast(dtype=torch.bfloat16, device_type=torch.device(self.device).type):
|
||||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||||
clip_context = self.image_encoder.encode_image([image])
|
clip_context = self.image_encoder.encode_image([image])
|
||||||
msk = torch.ones(1, 81, height//8, width//8, device=self.device)
|
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||||
msk[:, 1:] = 0
|
msk[:, 1:] = 0
|
||||||
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
|
||||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||||
msk = msk.transpose(1, 2)[0]
|
msk = msk.transpose(1, 2)[0]
|
||||||
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, 80, height, width).to(image.device)], dim=1)], device=self.device)[0]
|
y = self.vae.encode([torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)], device=self.device)[0]
|
||||||
y = torch.concat([msk, y])
|
y = torch.concat([msk, y])
|
||||||
return {"clip_fea": clip_context, "y": [y]}
|
return {"clip_fea": clip_context, "y": [y]}
|
||||||
|
|
||||||
@@ -234,7 +234,7 @@ class WanVideoPipeline(BasePipeline):
|
|||||||
# Encode image
|
# Encode image
|
||||||
if input_image is not None and self.image_encoder is not None:
|
if input_image is not None and self.image_encoder is not None:
|
||||||
self.load_models_to_device(["image_encoder", "vae"])
|
self.load_models_to_device(["image_encoder", "vae"])
|
||||||
image_emb = self.encode_image(input_image, height, width)
|
image_emb = self.encode_image(input_image, num_frames, height, width)
|
||||||
else:
|
else:
|
||||||
image_emb = {}
|
image_emb = {}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user