qwen_image layercontrol v2

This commit is contained in:
mi804
2026-02-24 15:19:16 +08:00
parent 288bbc7128
commit ee73a29885
10 changed files with 171 additions and 3 deletions

View File

@@ -682,14 +682,16 @@ class QwenImageUnit_Image2LoRADecode(PipelineUnit):
class QwenImageUnit_ContextImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride"),
input_params=("context_image", "height", "width", "tiled", "tile_size", "tile_stride", "layer_input_image"),
output_params=("context_latents",),
onload_model_names=("vae",)
)
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride):
def process(self, pipe: QwenImagePipeline, context_image, height, width, tiled, tile_size, tile_stride, layer_input_image=None):
if context_image is None:
return {}
if layer_input_image is not None:
context_image = context_image.convert("RGBA")
pipe.load_models_to_device(self.onload_model_names)
context_image = pipe.preprocess_image(context_image.resize((width, height))).to(device=pipe.device, dtype=pipe.torch_dtype)
context_latents = pipe.vae.encode(context_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)