mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
support eligenv2 and context_control
This commit is contained in:
@@ -69,6 +69,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_Inpaint(),
|
||||
QwenImageUnit_EditImageEmbedder(),
|
||||
QwenImageUnit_ContextImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
QwenImageUnit_BlockwiseControlNet(),
|
||||
@@ -284,6 +285,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
edit_image: Image.Image = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
edit_rope_interpolation: bool = False,
|
||||
# Qwen-Image-Context-Control
|
||||
context_image: Image.Image = None,
|
||||
# FP8
|
||||
enable_fp8_attention: bool = False,
|
||||
# Tile
|
||||
@@ -315,6 +318,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
|
||||
"context_image": context_image,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -615,6 +619,21 @@ class QwenImageUnit_EditImageEmbedder(PipelineUnit):
|
||||
edit_latents = pipe.vae.encode(edit_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"edit_latents": edit_latents, "edit_image": resized_edit_image}
|
||||
|
||||
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
|
||||
if context_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"context_latents": context_latents}
|
||||
|
||||
|
||||
def model_fn_qwen_image(
|
||||
dit: QwenImageDiT = None,
|
||||
@@ -633,6 +652,7 @@ def model_fn_qwen_image(
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
edit_latents=None,
|
||||
context_latents=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
@@ -646,6 +666,10 @@ def model_fn_qwen_image(
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image_seq_len = image.shape[1]
|
||||
|
||||
if context_latents is not None:
|
||||
img_shapes += [(context_latents.shape[0], context_latents.shape[2]//2, context_latents.shape[3]//2)]
|
||||
context_image = rearrange(context_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=context_latents.shape[2]//2, W=context_latents.shape[3]//2, P=2, Q=2)
|
||||
image = torch.cat([image, context_image], dim=1)
|
||||
if edit_latents is not None:
|
||||
img_shapes += [(edit_latents.shape[0], edit_latents.shape[2]//2, edit_latents.shape[3]//2)]
|
||||
edit_image = rearrange(edit_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=edit_latents.shape[2]//2, W=edit_latents.shape[3]//2, P=2, Q=2)
|
||||
@@ -692,7 +716,7 @@ def model_fn_qwen_image(
|
||||
|
||||
image = dit.norm_out(image, conditioning)
|
||||
image = dit.proj_out(image)
|
||||
if edit_latents is not None:
|
||||
if edit_latents is not None or context_latents is not None:
|
||||
image = image[:, :image_seq_len]
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
|
||||
Reference in New Issue
Block a user