This commit is contained in:
Zhongjie Duan
2025-04-17 15:23:46 +08:00
committed by GitHub
parent a4105d2c0e
commit b161bd6dfd

View File

@@ -203,7 +203,7 @@ class WanVideoPipeline(BasePipeline):
return {"context": prompt_emb}
def encode_image(self, image, end_image, num_frames, height, width):
def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
image = self.preprocess_image(image.resize((width, height))).to(self.device)
clip_context = self.image_encoder.encode_image([image])
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
@@ -221,7 +221,8 @@ class WanVideoPipeline(BasePipeline):
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
msk = msk.transpose(1, 2)[0]
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
y = y.to(dtype=self.torch_dtype, device=self.device)
y = torch.concat([msk, y])
y = y.unsqueeze(0)
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
@@ -393,7 +394,7 @@ class WanVideoPipeline(BasePipeline):
# Encode image
if input_image is not None and self.image_encoder is not None:
self.load_models_to_device(["image_encoder", "vae"])
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
else:
image_emb = {}