support wan-flf2v

This commit is contained in:
Artiprocher
2025-04-17 14:47:55 +08:00
parent e9e24b8cf1
commit 553b341f5f
5 changed files with 86 additions and 2 deletions

View File

@@ -211,6 +211,8 @@ class WanVideoPipeline(BasePipeline):
if end_image is not None:
end_image = self.preprocess_image(end_image.resize((width, height))).to(self.device)
vae_input = torch.concat([image.transpose(0,1), torch.zeros(3, num_frames-2, height, width).to(image.device), end_image.transpose(0,1)],dim=1)
if self.dit.has_image_pos_emb:
clip_context = torch.concat([clip_context, self.image_encoder.encode_image([end_image])], dim=1)
msk[:, -1:] = 1
else:
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)