Update train_wan_t2v.py

在应用itv的管道处理数据时有bug,提交修复
This commit is contained in:
mohui37
2025-04-11 17:05:40 +08:00
committed by GitHub
parent b925b402e2
commit 0dc56d9dcc

View File

@@ -140,7 +140,7 @@ class LightningModelForDataProcess(pl.LightningModule):
if "first_frame" in batch:
first_frame = Image.fromarray(batch["first_frame"][0].cpu().numpy())
_, _, num_frames, height, width = video.shape
image_emb = self.pipe.encode_image(first_frame, num_frames, height, width)
image_emb = self.pipe.encode_image(first_frame, None, num_frames, height, width)
else:
image_emb = {}
data = {"latents": latents, "prompt_emb": prompt_emb, "image_emb": image_emb}