support qwen-image-edit lowres fix

This commit is contained in:
mi804
2025-08-19 20:15:36 +08:00
parent 3a9621f6da
commit 838b8109b1
8 changed files with 133 additions and 6 deletions

View File

@@ -68,10 +68,10 @@ class QwenImagePipeline(BasePipeline):
QwenImageUnit_NoiseInitializer(),
QwenImageUnit_InputImageEmbedder(),
QwenImageUnit_Inpaint(),
QwenImageUnit_EditImageEmbedder(),
QwenImageUnit_PromptEmbedder(),
QwenImageUnit_EntityControl(),
QwenImageUnit_BlockwiseControlNet(),
QwenImageUnit_EditImageEmbedder(),
]
self.model_fn = model_fn_qwen_image
@@ -280,6 +280,7 @@ class QwenImagePipeline(BasePipeline):
eligen_enable_on_negative: bool = False,
# Edit Image
edit_image: Image.Image = None,
edit_rope_interpolation: bool = False,
# FP8
enable_fp8_attention: bool = False,
# Tile
@@ -310,7 +311,7 @@ class QwenImagePipeline(BasePipeline):
"blockwise_controlnet_inputs": blockwise_controlnet_inputs,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
"edit_image": edit_image,
"edit_image": edit_image, "edit_rope_interpolation": edit_rope_interpolation,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -583,11 +584,11 @@ class QwenImageUnit_BlockwiseControlNet(PipelineUnit):
class QwenImageUnit_EditImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("edit_image", "height", "width", "tiled", "tile_size", "tile_stride"),
input_params=("edit_image", "tiled", "tile_size", "tile_stride"),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, edit_image, height, width, tiled, tile_size, tile_stride):
def process(self, pipe: QwenImagePipeline, edit_image, tiled, tile_size, tile_stride):
if edit_image is None:
return {}
pipe.load_models_to_device(['vae'])
@@ -616,6 +617,7 @@ def model_fn_qwen_image(
enable_fp8_attention=False,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
edit_rope_interpolation=False,
**kwargs
):
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
@@ -640,7 +642,10 @@ def model_fn_qwen_image(
)
else:
text = dit.txt_in(dit.txt_norm(prompt_emb))
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
if edit_rope_interpolation:
image_rotary_emb = dit.pos_embed.forward_sampling(img_shapes, txt_seq_lens, device=latents.device)
else:
image_rotary_emb = dit.pos_embed(img_shapes, txt_seq_lens, device=latents.device)
attention_mask = None
if blockwise_controlnet_conditioning is not None: