mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
update value controller
This commit is contained in:
@@ -107,7 +107,7 @@ model_loader_configs = [
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||
(None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
|
||||
@@ -18,7 +18,7 @@ class MultiValueEncoder(torch.nn.Module):
|
||||
|
||||
|
||||
class SingleValueEncoder(torch.nn.Module):
|
||||
def __init__(self, dim_in=256, dim_out=3072, prefer_len=32, computation_device=None):
|
||||
def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
|
||||
super().__init__()
|
||||
self.prefer_len = prefer_len
|
||||
self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
|
||||
@@ -855,18 +855,28 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
class FluxImageUnit_ValueControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
|
||||
input_params=("value_controller_inputs",),
|
||||
onload_model_names=("value_controller",)
|
||||
)
|
||||
|
||||
def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
|
||||
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
|
||||
extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
|
||||
return prompt_emb, text_ids
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, value_controller_inputs):
|
||||
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
|
||||
if value_controller_inputs is None:
|
||||
return {}
|
||||
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
pipe.load_models_to_device(["value_controller"])
|
||||
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
|
||||
value_emb = value_emb.unsqueeze(0)
|
||||
return {"value_emb": value_emb}
|
||||
prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
|
||||
return {"prompt_emb": prompt_emb, "text_ids": text_ids}
|
||||
|
||||
|
||||
|
||||
@@ -1049,7 +1059,6 @@ def model_fn_flux_image(
|
||||
flex_condition=None,
|
||||
flex_uncondition=None,
|
||||
flex_control_stop_timestep=None,
|
||||
value_emb=None,
|
||||
step1x_llm_embedding=None,
|
||||
step1x_mask=None,
|
||||
step1x_reference_latents=None,
|
||||
@@ -1155,12 +1164,6 @@ def model_fn_flux_image(
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
# Value Control
|
||||
if value_emb is not None:
|
||||
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
|
||||
value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, value_text_ids], dim=1)
|
||||
# Original FLUX inference
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user