mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update value controller
This commit is contained in:
@@ -855,18 +855,28 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
class FluxImageUnit_ValueControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params=("value_controller_inputs",),
|
||||
onload_model_names=("value_controller",)
|
||||
)
|
||||
|
||||
def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
|
||||
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
|
||||
extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
|
||||
return prompt_emb, text_ids
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, value_controller_inputs):
|
||||
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
|
||||
if value_controller_inputs is None:
|
||||
return {}
|
||||
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
pipe.load_models_to_device(["value_controller"])
|
||||
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
|
||||
value_emb = value_emb.unsqueeze(0)
|
||||
return {"value_emb": value_emb}
|
||||
prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
|
||||
return {"prompt_emb": prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
|
||||
@@ -1049,7 +1059,6 @@ def model_fn_flux_image(
|
||||
flex_condition=None,
|
||||
flex_uncondition=None,
|
||||
flex_control_stop_timestep=None,
|
||||
value_emb=None,
|
||||
step1x_llm_embedding=None,
|
||||
step1x_mask=None,
|
||||
step1x_reference_latents=None,
|
||||
@@ -1155,12 +1164,6 @@ def model_fn_flux_image(
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
# Value Control
|
||||
if value_emb is not None:
|
||||
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
|
||||
value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, value_text_ids], dim=1)
|
||||
# Original FLUX inference
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user