diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index c2e050d..b60c200 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -107,7 +107,7 @@ model_loader_configs = [
     (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
-    (None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
+    (None, "0629116fce1472503a66992f96f3eb1a", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
     (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
     (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
     (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
diff --git a/diffsynth/models/flux_value_control.py b/diffsynth/models/flux_value_control.py
index 0ff68d3..6981344 100644
--- a/diffsynth/models/flux_value_control.py
+++ b/diffsynth/models/flux_value_control.py
@@ -18,7 +18,7 @@ class MultiValueEncoder(torch.nn.Module):
 
 
 class SingleValueEncoder(torch.nn.Module):
-    def __init__(self, dim_in=256, dim_out=3072, prefer_len=32, computation_device=None):
+    def __init__(self, dim_in=256, dim_out=4096, prefer_len=32, computation_device=None):
         super().__init__()
         self.prefer_len = prefer_len
         self.prefer_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py
index 811b119..330667c 100644
--- a/diffsynth/pipelines/flux_image_new.py
+++ b/diffsynth/pipelines/flux_image_new.py
@@ -855,18 +855,28 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
 
 
 class FluxImageUnit_ValueControl(PipelineUnit):
     def __init__(self):
         super().__init__(
+            seperate_cfg=True,
+            input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
+            input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
             input_params=("value_controller_inputs",),
             onload_model_names=("value_controller",)
         )
+
+    def add_to_text_embedding(self, prompt_emb, text_ids, value_emb):
+        prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
+        extra_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
+        text_ids = torch.concat([text_ids, extra_text_ids], dim=1)
+        return prompt_emb, text_ids
 
-    def process(self, pipe: FluxImagePipeline, value_controller_inputs):
+    def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
         if value_controller_inputs is None:
             return {}
         value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
         pipe.load_models_to_device(["value_controller"])
         value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
         value_emb = value_emb.unsqueeze(0)
-        return {"value_emb": value_emb}
+        prompt_emb, text_ids = self.add_to_text_embedding(prompt_emb, text_ids, value_emb)
+        return {"prompt_emb": prompt_emb, "text_ids": text_ids}
 
@@ -1049,7 +1059,6 @@ def model_fn_flux_image(
     flex_condition=None,
     flex_uncondition=None,
     flex_control_stop_timestep=None,
-    value_emb=None,
     step1x_llm_embedding=None,
     step1x_mask=None,
     step1x_reference_latents=None,
@@ -1155,12 +1164,6 @@ def model_fn_flux_image(
         prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
     else:
         prompt_emb = dit.context_embedder(prompt_emb)
-        # Value Control
-        if value_emb is not None:
-            prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
-            value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
-            text_ids = torch.concat([text_ids, value_text_ids], dim=1)
-
     # Original FLUX inference
     image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
     attention_mask = None
diff --git a/examples/flux/model_inference/FLUX.1-dev-ValueControl.py b/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py
similarity index 60%
rename from examples/flux/model_inference/FLUX.1-dev-ValueControl.py
rename to examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py
index 0bb3ed0..7dd4574 100644
--- a/examples/flux/model_inference/FLUX.1-dev-ValueControl.py
+++ b/examples/flux/model_inference/FLUX.1-dev-AttriCtrl.py
@@ -10,11 +10,10 @@ pipe = FluxImagePipeline.from_pretrained(
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
         ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
-        ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-ValueController", origin_file_pattern="single/prefer_embed/value.ckpt")
+        ModelConfig(model_id="DiffSynth-Studio/AttriCtrl-FLUX.1-Dev", origin_file_pattern="models/brightness.safetensors")
     ],
 )
 
-pipe.load_lora(pipe.dit, ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-ValueController", origin_file_pattern="single/dit_lora/dit_value.ckpt"))
-for i in range(10):
-    image = pipe(prompt="a cat", seed=0, value_controller_inputs=[i/10])
-    image.save(f"value_control_{i}.jpg")
\ No newline at end of file
+for i in [0.1, 0.3, 0.5, 0.7, 0.9]:
+    image = pipe(prompt="A woman.", seed=602, value_controller_inputs=[i], rand_device="cuda")
+    image.save(f"value_control_{i}.jpg")