fix qwen_rope

This commit is contained in:
mi804
2025-08-18 17:31:18 +08:00
parent ad1da43476
commit d93de98a21
8 changed files with 36 additions and 54 deletions

View File

@@ -565,7 +565,6 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride):
if edit_image is None:
return {}
edit_image = edit_image.resize((width, height))
pipe.load_models_to_device(['vae'])
edit_image = pipe.preprocess_image(edit_image).to(device=pipe.device, dtype=pipe.torch_dtype)
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
@@ -601,8 +600,8 @@ def model_fn_qwen_image(
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
if edit_latents is not None:
img_shapes[0] = (img_shapes[0][0] + edit_latents.shape[0], img_shapes[0][1], img_shapes[0][2])
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
image_seq_len = image.shape[1]
image = torch.cat([image, edit_image], dim=1)