update value controller

Artiprocher
2025-07-21 16:30:06 +08:00
parent 1384de0353
commit 22705a44b4
4 changed files with 18 additions and 16 deletions

View File

@@ -107,7 +107,7 @@ model_loader_configs = [
     (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
-    (None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
+    (None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
     (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
     (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
     (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),

View File

@@ -18,7 +18,7 @@ class MultiValueEncoder(torch.nn.Module):
 class SingleValueEncoder(torch.nn.Module):
-    def __init__(self, dim_in=256, dim_out=3072, prefer_len=32, computation_device=None):
+    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
         super().__init__()
         self.prefer_len = prefer_len
         self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
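dim_out grows from 3072 to 4096, which matches the pipeline change below: the value tokens are now concatenated to the raw T5 prompt embedding (width 4096) before dit.context_embedder projects it to the transformer width, instead of to the already-projected sequence (width 3072) inside model_fn_flux_image. A hypothetical stand-in for SingleValueEncoder, since TemporalTimesteps and the real projection stack are not shown in this diff:

import math
import torch

class ScalarValueEncoder(torch.nn.Module):
    # Sketch only: mirrors the interface implied above (one scalar in,
    # prefer_len tokens of width dim_out out); the real layer stack differs.
    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32):
        super().__init__()
        self.dim_in, self.prefer_len = dim_in, prefer_len
        self.proj = torch.nn.Linear(dim_in, dim_out * prefer_len)

    def sinusoidal(self, values):
        # Sinusoidal features of the scalar, in the spirit of
        # TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, ...).
        half = self.dim_in // 2
        freqs = torch.exp(torch.arange(half, dtype=values.dtype) * (-math.log(10000.0) / half))
        angles = values[:, None] * freqs[None, :]
        return torch.cat([angles.cos(), angles.sin()], dim=-1)  # (1, dim_in)

    def forward(self, value, dtype=torch.float32):
        emb = self.sinusoidal(value.reshape(1))               # (1, dim_in)
        tokens = self.proj(emb).reshape(self.prefer_len, -1)  # (prefer_len, dim_out)
        return tokens.to(dtype)  # caller adds the batch dim via unsqueeze(0)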

View File

@@ -855,18 +855,28 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
 class FluxImageUnit_ValueControl(PipelineUnit):
     def __init__(self):
         super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
+            input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
             input_params=("value_controller_inputs",),
             onload_model_names=("value_controller",)
         )
 
+    def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
+        prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
+        extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
+        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
+        return prompt_emb, text_ids
+
-    def process(self, pipe: FluxImagePipeline, value_controller_inputs):
+    def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
         if value_controller_inputs is None:
             return {}
         value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
         pipe.load_models_to_device(["value_controller"])
         value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
         value_emb = value_emb.unsqueeze(0)
-        return {"value_emb": value_emb}
+        prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
+        return {"prompt_emb": prompt_emb, "text_ids": text_ids}
@@ -1049,7 +1059,6 @@ def model_fn_flux_image(
     flex_condition=None,
     flex_uncondition=None,
     flex_control_stop_timestep=None,
-    value_emb=None,
     step1x_llm_embedding=None,
     step1x_mask=None,
     step1x_reference_latents=None,
@@ -1155,12 +1164,6 @@ def model_fn_flux_image(
         prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
-        # Value Control
-        if value_emb is not None:
-            prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
-            value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
-            text_ids = torch.concat([text_ids, value_text_ids], dim=1)
         # Original FLUX inference
         image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
         attention_mask = None
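After this change, value conditioning rides entirely on the text-embedding path: the pipeline unit appends the value tokens before dit.context_embedder and dit.pos_embedder run, so model_fn_flux_image no longer needs a dedicated value_emb input. A hypothetical invocation; only value_controller_inputs is named by this commit, and the other arguments are assumptions about FluxImagePipeline's call signature:

# Hypothetical usage; argument names other than value_controller_inputs
# are assumptions about FluxImagePipeline's call signature.
image = pipe(
    prompt="a cat sitting on a windowsill",
    value_controller_inputs=[0.7],  # scalar condition encoded by SingleValueEncoder
    cfg_scale=3.0,  # seperate_cfg=True: value tokens join both CFG branches
)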