update value controller

This commit is contained in:
Artiprocher
2025-07-21 16:30:06 +08:00
parent 1384de0353
commit 22705a44b4
4 changed files with 18 additions and 16 deletions

View File

@@ -855,18 +855,28 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
class FluxImageUnit_ValueControl(PipelineUnit):
def __init__(self):
super().__init__(
seperate_cfg=True,
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params=("value_controller_inputs",),
onload_model_names=("value_controller",)
)
def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
return prompt_emb, text_ids
def process(self, pipe: FluxImagePipeline, value_controller_inputs):
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
if value_controller_inputs is None:
return {}
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
pipe.load_models_to_device(["value_controller"])
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
value_emb = value_emb.unsqueeze(0)
return {"value_emb": value_emb}
prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
return {"prompt_emb": prompt_emb, "text_ids": text_ids}
@@ -1049,7 +1059,6 @@ def model_fn_flux_image(
flex_condition=None,
flex_uncondition=None,
flex_control_stop_timestep=None,
value_emb=None,
step1x_llm_embedding=None,
step1x_mask=None,
step1x_reference_latents=None,
@@ -1155,12 +1164,6 @@ def model_fn_flux_image(
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
# Value Control
if value_emb is not None:
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
text_ids = torch.concat([text_ids, value_text_ids], dim=1)
# Original FLUX inference
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None