mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support qwen-image-layered
This commit is contained in:
@@ -48,6 +48,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
QwenImageUnit_InputImageEmbedder(),
|
||||
QwenImageUnit_Inpaint(),
|
||||
QwenImageUnit_EditImageEmbedder(),
|
||||
QwenImageUnit_LayerInputImageEmbedder(),
|
||||
QwenImageUnit_ContextImageEmbedder(),
|
||||
QwenImageUnit_PromptEmbedder(),
|
||||
QwenImageUnit_EntityControl(),
|
||||
@@ -128,6 +129,9 @@ class QwenImagePipeline(BasePipeline):
|
||||
edit_rope_interpolation: bool = False,
|
||||
# Qwen-Image-Edit-2511
|
||||
zero_cond_t: bool = False,
|
||||
# Qwen-Image-Layered
|
||||
layer_input_image: Image.Image = None,
|
||||
layer_num: int = None,
|
||||
# In-context control
|
||||
context_image: Image.Image = None,
|
||||
# Tile
|
||||
@@ -160,6 +164,8 @@ class QwenImagePipeline(BasePipeline):
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize, "edit_rope_interpolation": edit_rope_interpolation,
|
||||
"context_image": context_image,
|
||||
"zero_cond_t": zero_cond_t,
|
||||
"layer_input_image": layer_input_image,
|
||||
"layer_num": layer_num,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -179,7 +185,10 @@ class QwenImagePipeline(BasePipeline):
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"], device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.vae_output_to_image(image)
|
||||
if layer_num is None:
|
||||
image = self.vae_output_to_image(image)
|
||||
else:
|
||||
image = [self.vae_output_to_image(i, pattern="C H W") for i in image]
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
@@ -230,12 +239,15 @@ class QwenImageUnit_ShapeChecker(PipelineUnit):
|
||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
input_params=("height", "width", "seed", "rand_device", "layer_num"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device, layer_num):
|
||||
if layer_num is None:
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
else:
|
||||
noise = pipe.generate_noise((layer_num + 1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
@@ -252,8 +264,15 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if isinstance(input_image, list):
|
||||
input_latents = []
|
||||
for image in input_image:
|
||||
image = pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents.append(pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride))
|
||||
input_latents = torch.concat(input_latents, dim=0)
|
||||
else:
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
@@ -261,6 +280,22 @@ class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class QwenImageUnit_LayerInputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("layer_input_image", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("layer_input_latents",),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, layer_input_image, tiled, tile_size, tile_stride):
|
||||
if layer_input_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(layer_input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return {"layer_input_latents": latents}
|
||||
|
||||
|
||||
class QwenImageUnit_Inpaint(PipelineUnit):
|
||||
def __init__(self):
|
||||
@@ -677,6 +712,8 @@ def model_fn_qwen_image(
|
||||
entity_prompt_emb_mask=None,
|
||||
entity_masks=None,
|
||||
edit_latents=None,
|
||||
layer_input_latents=None,
|
||||
layer_num=None,
|
||||
context_latents=None,
|
||||
enable_fp8_attention=False,
|
||||
use_gradient_checkpointing=False,
|
||||
@@ -685,11 +722,16 @@ def model_fn_qwen_image(
|
||||
zero_cond_t=False,
|
||||
**kwargs
|
||||
):
|
||||
img_shapes = [(latents.shape[0], latents.shape[2]//2, latents.shape[3]//2)]
|
||||
if layer_num is None:
|
||||
layer_num = 1
|
||||
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)]
|
||||
else:
|
||||
layer_num = layer_num + 1
|
||||
img_shapes = [(1, latents.shape[2]//2, latents.shape[3]//2)] * layer_num
|
||||
txt_seq_lens = prompt_emb_mask.sum(dim=1).tolist()
|
||||
timestep = timestep / 1000
|
||||
|
||||
image = rearrange(latents, "B C (H P) (W Q) -> B (H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
image = rearrange(latents, "(B N) C (H P) (W Q) -> B (N H W) (C P Q)", H=height//16, W=width//16, P=2, Q=2, N=layer_num)
|
||||
image_seq_len = image.shape[1]
|
||||
|
||||
if context_latents is not None:
|
||||
@@ -701,6 +743,11 @@ def model_fn_qwen_image(
|
||||
img_shapes += [(e.shape[0], e.shape[2]//2, e.shape[3]//2) for e in edit_latents_list]
|
||||
edit_image = [rearrange(e, "B C (H P) (W Q) -> B (H W) (C P Q)", H=e.shape[2]//2, W=e.shape[3]//2, P=2, Q=2) for e in edit_latents_list]
|
||||
image = torch.cat([image] + edit_image, dim=1)
|
||||
if layer_input_latents is not None:
|
||||
layer_num = layer_num + 1
|
||||
img_shapes += [(layer_input_latents.shape[0], layer_input_latents.shape[2]//2, layer_input_latents.shape[3]//2)]
|
||||
layer_input_latents = rearrange(layer_input_latents, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
|
||||
image = torch.cat([image, layer_input_latents], dim=1)
|
||||
|
||||
image = dit.img_in(image)
|
||||
if zero_cond_t:
|
||||
@@ -712,7 +759,11 @@ def model_fn_qwen_image(
|
||||
)
|
||||
else:
|
||||
modulate_index = None
|
||||
conditioning = dit.time_text_embed(timestep, image.dtype)
|
||||
conditioning = dit.time_text_embed(
|
||||
timestep,
|
||||
image.dtype,
|
||||
addition_t_cond=None if layer_num is None else torch.tensor([0]).to(device=image.device, dtype=torch.long)
|
||||
)
|
||||
|
||||
if entity_prompt_emb is not None:
|
||||
text, image_rotary_emb, attention_mask = dit.process_entity_masks(
|
||||
@@ -759,5 +810,5 @@ def model_fn_qwen_image(
|
||||
image = dit.proj_out(image)
|
||||
image = image[:, :image_seq_len]
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
latents = rearrange(image, "B (N H W) (C P Q) -> (B N) C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2, B=1)
|
||||
return latents
|
||||
|
||||
Reference in New Issue
Block a user