update template framework

This commit is contained in:
Artiprocher
2026-04-15 14:07:51 +08:00
parent 9f8c352a15
commit 59b4bbb62c
7 changed files with 85 additions and 24 deletions

View File

@@ -100,6 +100,9 @@ class Flux2ImagePipeline(BasePipeline):
# LoRA
lora = None,
negative_lora = None,
# Text Embedding
extra_text_embedding = None,
negative_extra_text_embedding = None,
# Inpaint
inpaint_mask: Image.Image = None,
inpaint_blur_size: int = None,
@@ -113,10 +116,12 @@ class Flux2ImagePipeline(BasePipeline):
inputs_posi = {
"prompt": prompt,
"kv_cache": kv_cache,
"extra_text_embedding": extra_text_embedding,
}
inputs_nega = {
"negative_prompt": negative_prompt,
"kv_cache": negative_kv_cache,
"extra_text_embedding": negative_extra_text_embedding,
}
inputs_shared = {
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
@@ -607,6 +612,7 @@ def model_fn_flux2(
edit_latents=None,
edit_image_ids=None,
kv_cache=None,
extra_text_embedding=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
@@ -617,6 +623,11 @@ def model_fn_flux2(
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
if extra_text_embedding is not None:
extra_text_ids = torch.zeros((1, extra_text_embedding.shape[1], 4), dtype=text_ids.dtype, device=text_ids.device)
extra_text_ids[:, :, -1] = torch.arange(prompt_embeds.shape[1], prompt_embeds.shape[1] + extra_text_embedding.shape[1])
prompt_embeds = torch.concat([prompt_embeds, extra_text_embedding], dim=1)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
model_output = dit(
hidden_states=latents,
timestep=timestep / 1000,