mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
bugfix
This commit is contained in:
@@ -203,7 +203,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
return {"context": prompt_emb}
|
||||
|
||||
|
||||
def encode_image(self, image, end_image, num_frames, height, width):
|
||||
def encode_image(self, image, end_image, num_frames, height, width, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||
image = self.preprocess_image(image.resize((width, height))).to(self.device)
|
||||
clip_context = self.image_encoder.encode_image([image])
|
||||
msk = torch.ones(1, num_frames, height//8, width//8, device=self.device)
|
||||
@@ -221,7 +221,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
|
||||
msk = msk.transpose(1, 2)[0]
|
||||
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device)[0]
|
||||
y = self.vae.encode([vae_input.to(dtype=self.torch_dtype, device=self.device)], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
y = y.to(dtype=self.torch_dtype, device=self.device)
|
||||
y = torch.concat([msk, y])
|
||||
y = y.unsqueeze(0)
|
||||
clip_context = clip_context.to(dtype=self.torch_dtype, device=self.device)
|
||||
@@ -393,7 +394,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
# Encode image
|
||||
if input_image is not None and self.image_encoder is not None:
|
||||
self.load_models_to_device(["image_encoder", "vae"])
|
||||
image_emb = self.encode_image(input_image, end_image, num_frames, height, width)
|
||||
image_emb = self.encode_image(input_image, end_image, num_frames, height, width, **tiler_kwargs)
|
||||
else:
|
||||
image_emb = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user