import torch
from diffsynth.models.svd_unet import TemporalTimesteps


class MultiValueEncoder(torch.nn.Module):
    """Wraps several per-value encoders and concatenates their embeddings.

    Each encoder maps one scalar control value to a ``(prefer_len, dim)``
    embedding; entries whose value is ``None`` are skipped entirely.
    """

    def __init__(self, encoders=()):
        super().__init__()
        self.encoders = torch.nn.ModuleList(encoders)

    def forward(self, values, dtype):
        # Defined as ``forward`` (not ``__call__``) so that nn.Module hooks,
        # dtype/device casting and scripting keep working; calling the module
        # instance behaves exactly as before.
        emb = []
        for encoder, value in zip(self.encoders, values):
            if value is not None:
                value = value.unsqueeze(0)  # scalar tensor -> shape (1,)
                emb.append(encoder(value, dtype))
        if not emb:
            # torch.concat on an empty list raises an opaque RuntimeError;
            # fail with an actionable message instead.
            raise ValueError("MultiValueEncoder received no non-None values.")
        return torch.concat(emb, dim=0)


class SingleValueEncoder(torch.nn.Module):
    """Encodes one scalar control value into a sequence of token embeddings.

    The value is projected with a sinusoidal timestep embedding, lifted to
    ``dim_out`` by a two-layer MLP, broadcast to ``prefer_len`` tokens, and
    offset by a learned positional embedding.
    """

    def __init__(self, dim_in=256, dim_out=3072, prefer_len=32, computation_device=None):
        super().__init__()
        self.prefer_len = prefer_len
        self.prefer_proj = TemporalTimesteps(
            num_channels=dim_in,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
            computation_device=computation_device,
        )
        self.prefer_value_embedder = torch.nn.Sequential(
            torch.nn.Linear(dim_in, dim_out),
            torch.nn.SiLU(),
            torch.nn.Linear(dim_out, dim_out),
        )
        self.positional_embedding = torch.nn.Parameter(
            torch.randn(self.prefer_len, dim_out)
        )
        self._initialize_weights()

    def _initialize_weights(self):
        # Zero-init the final projection so the controller starts as a no-op
        # (output equals the positional embedding alone) — presumably for
        # stable fine-tuning; the checkpoint overwrites these weights anyway.
        last_linear = self.prefer_value_embedder[-1]
        torch.nn.init.zeros_(last_linear.weight)
        torch.nn.init.zeros_(last_linear.bias)

    def forward(self, value, dtype):
        # Scale into the [0, 1000] range conventionally used by the
        # sinusoidal timestep projection (value is expected in [0, 1]
        # by the example caller — TODO confirm for other callers).
        value = value * 1000
        emb = self.prefer_proj(value).to(dtype)
        emb = self.prefer_value_embedder(emb).squeeze(0)   # (dim_out,)
        base_embeddings = emb.expand(self.prefer_len, -1)  # view, (prefer_len, dim_out)
        # The addition allocates a fresh tensor, so the expanded view is safe.
        learned_embeddings = base_embeddings + self.positional_embedding
        return learned_embeddings

    @staticmethod
    def state_dict_converter():
        return SingleValueEncoderStateDictConverter()


class SingleValueEncoderStateDictConverter:
    """Identity converter: checkpoints already use this module's key names."""

    def from_diffusers(self, state_dict):
        return state_dict

    def from_civitai(self, state_dict):
        return state_dict
FluxTextEnc from ..models.step1x_connector import Qwen2Connector from ..models.flux_controlnet import FluxControlNet from ..models.flux_ipadapter import FluxIpAdapter +from ..models.flux_value_control import MultiValueEncoder from ..models.flux_infiniteyou import InfiniteYouImageProjector from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit @@ -93,6 +94,7 @@ class FluxImagePipeline(BasePipeline): self.ipadapter_image_encoder = None self.qwenvl = None self.step1x_connector: Qwen2Connector = None + self.value_controller: MultiValueEncoder = None self.infinityou_processor: InfinitYou = None self.image_proj_model: InfiniteYouImageProjector = None self.lora_patcher: FluxLoraPatcher = None @@ -113,6 +115,7 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_TeaCache(), FluxImageUnit_Flex(), FluxImageUnit_Step1x(), + FluxImageUnit_ValueControl(), ] self.model_fn = model_fn_flux_image @@ -341,7 +344,16 @@ class FluxImagePipeline(BasePipeline): for model_name, model in zip(model_manager.model_name, model_manager.model): if model_name == "flux_controlnet": controlnets.append(model) - pipe.controlnet = MultiControlNet(controlnets) + if len(controlnets) > 0: + pipe.controlnet = MultiControlNet(controlnets) + + # Value Controller + value_controllers = [] + for model_name, model in zip(model_manager.model_name, model_manager.model): + if model_name == "flux_value_controller": + value_controllers.append(model) + if len(value_controllers) > 0: + pipe.value_controller = MultiValueEncoder(value_controllers) return pipe @@ -393,6 +405,8 @@ class FluxImagePipeline(BasePipeline): flex_control_image: Image.Image = None, flex_control_strength: float = 0.5, flex_control_stop: float = 0.5, + # Value Controller + value_controller_inputs: list[float] = None, # Step1x step1x_reference_image: Image.Image = None, # TeaCache @@ -426,6 +440,7 @@ class FluxImagePipeline(BasePipeline): "eligen_entity_prompts": 
eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint, "infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance, "flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop, + "value_controller_inputs": value_controller_inputs, "step1x_reference_image": step1x_reference_image, "tea_cache_l1_thresh": tea_cache_l1_thresh, "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride, @@ -724,7 +739,7 @@ class FluxImageUnit_Flex(PipelineUnit): super().__init__( input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"), onload_model_names=("vae_encoder",) - ) + ) def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride): if pipe.dit.input_dim == 196: @@ -769,6 +784,24 @@ class FluxImageUnit_InfiniteYou(PipelineUnit): +class FluxImageUnit_ValueControl(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("value_controller_inputs",), + onload_model_names=("value_controller",) + ) + + def process(self, pipe: FluxImagePipeline, value_controller_inputs): + if value_controller_inputs is None: + return {} + value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device) + pipe.load_models_to_device(["value_controller"]) + value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype) + value_emb = value_emb.unsqueeze(0) + return {"value_emb": value_emb} + + + class InfinitYou(torch.nn.Module): def __init__(self, device="cuda", torch_dtype=torch.bfloat16): 
import torch
from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig


# Build a FLUX.1-dev pipeline with the value-controller checkpoint attached.
pipe = FluxImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"),
        ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
        ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-ValueController", origin_file_pattern="single/prefer_embed/value.ckpt"),
    ],
)

# The matching DiT LoRA must be loaded alongside the value encoder.
pipe.load_lora(pipe.dit, ModelConfig(model_id="DiffSynth-Studio/FLUX.1-dev-ValueController", origin_file_pattern="single/dit_lora/dit_value.ckpt"))

# Sweep the control value from 0.0 to 0.9 and save one image per step.
for i in range(10):
    image = pipe(prompt="a cat", seed=0, value_controller_inputs=[i / 10])
    image.save(f"value_control_{i}.jpg")