diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py
index d9e535a..77835a4 100644
--- a/diffsynth/pipelines/wan_video.py
+++ b/diffsynth/pipelines/wan_video.py
@@ -203,7 +203,7 @@ class WanVideoPipeline(BasePipeline):
         return {"context": prompt_emb}
 
 
-    def encode_image(self, image, end_image, num_frames, height, width):
+    def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         image = self.preprocess_image(image.resize((width, height))).to(self.device)
         clip_context = self.image_encoder.encode_image([image])
         msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
@@ -221,7 +221,8 @@ class WanVideoPipeline(BasePipeline):
         msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
         msk = msk.transpose(1, 2)[0]
 
-        y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
+        y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
+        y = y.to(dtype=self.torch_dtype, device=self.device)
         y = torch.concat([msk, y])
         y = y.unsqueeze(0)
         clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
@@ -393,7 +394,7 @@ class WanVideoPipeline(BasePipeline):
 
         # Encode image
         if input_image is not None and self.image_encoder is not None:
             self.load_models_to_device(["image_encoder", "vae"])
-            image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
+            image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
         else:
             image_emb = {}
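
The change threads the existing tiler settings through to image conditioning: `encode_image` gains `tiled`/`tile_size`/`tile_stride` parameters and forwards them to `self.vae.encode`, and the call site passes `**tiler_kwargs` so image-to-video runs get the same tiled VAE behavior as the rest of the pipeline. The added `y = y.to(...)` cast keeps the encoded latent on the expected dtype/device before `torch.concat([msk, y])`, since the tiled encode path may return its result elsewhere. Below is a minimal usage sketch, not part of the diff: it assumes the pipeline's `__call__` collects `tiled`, `tile_size`, and `tile_stride` into the `tiler_kwargs` that the third hunk forwards (only the forwarding is visible here), and the prompt, image, and size values are placeholders.

```python
# Usage sketch (assumptions noted inline; values are placeholders).
from PIL import Image
from diffsynth import WanVideoPipeline  # import path assumed

pipe: WanVideoPipeline = ...             # constructed/loaded elsewhere
start_frame = Image.open("frame0.png")   # placeholder conditioning image

video = pipe(
    prompt="a cat surfing a wave",
    input_image=start_frame,   # non-None, so encode_image now receives **tiler_kwargs
    num_frames=81, height=480, width=832,
    tiled=True,                # assumed __call__ kwarg gathered into tiler_kwargs
    tile_size=(34, 34),        # latent-space tile size (defaults from the diff)
    tile_stride=(18, 16),      # stride < size leaves overlap between tiles
)
```

With `tiled=True`, the VAE encodes the conditioning frames tile by tile instead of in one pass, trading a little compute for a much lower peak-VRAM footprint at high resolutions.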