mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge pull request #684 from modelscope/value_controller
support flux value controller
This commit is contained in:
@@ -64,6 +64,8 @@ from ..models.wan_video_vace import VaceWanModel
|
||||
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
|
||||
from ..models.flux_value_control import SingleValueEncoder
|
||||
|
||||
from ..lora.flux_lora import FluxLoraPatcher
|
||||
|
||||
|
||||
@@ -104,6 +106,7 @@ model_loader_configs = [
|
||||
(None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"),
|
||||
(None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"),
|
||||
(None, "3ede90c44b2c161240b659f3b8393c9d", ["flux_value_controller"], [SingleValueEncoder], "civitai"),
|
||||
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
|
||||
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
|
||||
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
|
||||
|
||||
59
diffsynth/models/flux_value_control.py
Normal file
59
diffsynth/models/flux_value_control.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
    """Apply one encoder per control value and stack the resulting embeddings.

    Args:
        encoders: Iterable of modules, registered as a ``ModuleList``. Each is
            called as ``encoder(value, dtype)``.
    """

    def __init__(self, encoders=()):
        super().__init__()
        self.encoders = torch.nn.ModuleList(encoders)

    # Implemented as ``forward`` (not ``__call__``) so that invoking the module
    # goes through ``torch.nn.Module.__call__`` and registered hooks still run.
    # ``module(values, dtype)`` keeps working exactly as before for callers.
    def forward(self, values, dtype):
        """Encode every non-``None`` value with its positionally matched encoder.

        Args:
            values: Iterable paired with ``self.encoders`` through ``zip``;
                pairing stops at the shorter of the two. Entries that are
                ``None`` are skipped. Every other entry must support
                ``unsqueeze(0)``.
            dtype: Forwarded unchanged as the second argument of each encoder.

        Returns:
            The encoder outputs concatenated along dimension 0.
        """
        embeddings = [
            # Each value gets a leading dimension before it is encoded.
            encoder(value.unsqueeze(0), dtype)
            for encoder, value in zip(self.encoders, values)
            if value is not None
        ]
        return torch.concat(embeddings, dim=0)
|
||||
|
||||
|
||||
class SingleValueEncoder(torch.nn.Module):
    """Turn one scalar control value into ``prefer_len`` embedding rows.

    Args:
        dim_in: Channel count handed to ``TemporalTimesteps`` and input width
            of the first linear layer.
        dim_out: Width of the produced embeddings.
        prefer_len: Number of rows in the output (and in the learned
            positional embedding).
        computation_device: Forwarded to ``TemporalTimesteps``.
    """

    def __init__(self, dim_in=256, dim_out=3072, prefer_len=32, computation_device=None):
        super().__init__()
        self.prefer_len = prefer_len
        self.prefer_proj = TemporalTimesteps(
            num_channels=dim_in,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
            computation_device=computation_device,
        )
        # Layers are created in the same order as they appear in the
        # Sequential so random initialisation is consumed identically.
        input_layer = torch.nn.Linear(dim_in, dim_out)
        output_layer = torch.nn.Linear(dim_out, dim_out)
        self.prefer_value_embedder = torch.nn.Sequential(
            input_layer,
            torch.nn.SiLU(),
            output_layer,
        )
        positions = torch.randn(self.prefer_len, dim_out)
        self.positional_embedding = torch.nn.Parameter(positions)
        self._initialize_weights()

    def _initialize_weights(self):
        """Zero the weight and bias of the last linear layer of the embedder."""
        output_layer = self.prefer_value_embedder[-1]
        for parameter in (output_layer.weight, output_layer.bias):
            torch.nn.init.zeros_(parameter)

    def forward(self, value, dtype):
        """Embed ``value`` and add the learned positional embedding.

        Args:
            value: Tensor holding the control value; it is multiplied by 1000
                before being passed to ``self.prefer_proj``.
            dtype: dtype the projected features are cast to before the MLP.

        Returns:
            The embedding expanded to ``prefer_len`` rows plus
            ``self.positional_embedding``.
        """
        scaled = value * 1000
        features = self.prefer_proj(scaled).to(dtype)
        projected = self.prefer_value_embedder(features).squeeze(0)
        return projected.expand(self.prefer_len, -1) + self.positional_embedding

    @staticmethod
    def state_dict_converter():
        """Return the converter used when loading checkpoints for this model."""
        return SingleValueEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SingleValueEncoderStateDictConverter:
    """Pass-through state-dict converter: both formats are returned unchanged."""

    def __init__(self):
        """The converter holds no state, so there is nothing to set up."""

    def from_diffusers(self, state_dict):
        """Return the given ``state_dict`` object as-is (no key renaming)."""
        converted = state_dict
        return converted

    def from_civitai(self, state_dict):
        """Return the given ``state_dict`` object as-is (no key renaming)."""
        converted = state_dict
        return converted
|
||||
@@ -18,6 +18,7 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
|
||||
from ..models.step1x_connector import Qwen2Connector
|
||||
from ..models.flux_controlnet import FluxControlNet
|
||||
from ..models.flux_ipadapter import FluxIpAdapter
|
||||
from ..models.flux_value_control import MultiValueEncoder
|
||||
from ..models.flux_infiniteyou import InfiniteYouImageProjector
|
||||
from ..models.tiler import FastTileWorker
|
||||
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
|
||||
@@ -93,6 +94,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
self.ipadapter_image_encoder = None
|
||||
self.qwenvl = None
|
||||
self.step1x_connector: Qwen2Connector = None
|
||||
self.value_controller: MultiValueEncoder = None
|
||||
self.infinityou_processor: InfinitYou = None
|
||||
self.image_proj_model: InfiniteYouImageProjector = None
|
||||
self.lora_patcher: FluxLoraPatcher = None
|
||||
@@ -113,6 +115,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
FluxImageUnit_TeaCache(),
|
||||
FluxImageUnit_Flex(),
|
||||
FluxImageUnit_Step1x(),
|
||||
FluxImageUnit_ValueControl(),
|
||||
]
|
||||
self.model_fn = model_fn_flux_image
|
||||
|
||||
@@ -341,7 +344,16 @@ class FluxImagePipeline(BasePipeline):
|
||||
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||
if model_name == "flux_controlnet":
|
||||
controlnets.append(model)
|
||||
pipe.controlnet = MultiControlNet(controlnets)
|
||||
if len(controlnets) > 0:
|
||||
pipe.controlnet = MultiControlNet(controlnets)
|
||||
|
||||
# Value Controller
|
||||
value_controllers = []
|
||||
for model_name, model in zip(model_manager.model_name, model_manager.model):
|
||||
if model_name == "flux_value_controller":
|
||||
value_controllers.append(model)
|
||||
if len(value_controllers) > 0:
|
||||
pipe.value_controller = MultiValueEncoder(value_controllers)
|
||||
|
||||
return pipe
|
||||
|
||||
@@ -393,6 +405,8 @@ class FluxImagePipeline(BasePipeline):
|
||||
flex_control_image: Image.Image = None,
|
||||
flex_control_strength: float = 0.5,
|
||||
flex_control_stop: float = 0.5,
|
||||
# Value Controller
|
||||
value_controller_inputs: list[float] = None,
|
||||
# Step1x
|
||||
step1x_reference_image: Image.Image = None,
|
||||
# TeaCache
|
||||
@@ -426,6 +440,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
|
||||
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
|
||||
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
|
||||
"value_controller_inputs": value_controller_inputs,
|
||||
"step1x_reference_image": step1x_reference_image,
|
||||
"tea_cache_l1_thresh": tea_cache_l1_thresh,
|
||||
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
|
||||
@@ -724,7 +739,7 @@ class FluxImageUnit_Flex(PipelineUnit):
|
||||
super().__init__(
|
||||
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
)
|
||||
|
||||
def process(self, pipe: FluxImagePipeline, latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, flex_control_strength, flex_control_stop, tiled, tile_size, tile_stride):
|
||||
if pipe.dit.input_dim == 196:
|
||||
@@ -769,6 +784,24 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
|
||||
|
||||
|
||||
|
||||
class FluxImageUnit_ValueControl(PipelineUnit):
    """Pipeline unit that turns ``value_controller_inputs`` into ``value_emb``."""

    def __init__(self):
        super().__init__(
            input_params=("value_controller_inputs",),
            onload_model_names=("value_controller",),
        )

    def process(self, pipe: FluxImagePipeline, value_controller_inputs):
        """Encode the control values with ``pipe.value_controller``.

        Returns an empty dict when ``value_controller_inputs`` is ``None``;
        otherwise a dict whose ``"value_emb"`` entry is the controller output
        with a leading dimension added.
        """
        if value_controller_inputs is None:
            return {}

        # Build the tensor first, then move it to the pipeline dtype/device.
        values = torch.tensor(value_controller_inputs)
        values = values.to(dtype=pipe.torch_dtype, device=pipe.device)

        pipe.load_models_to_device(["value_controller"])
        embedding = pipe.value_controller(values, pipe.torch_dtype)
        return {"value_emb": embedding.unsqueeze(0)}
|
||||
|
||||
|
||||
|
||||
class InfinitYou(torch.nn.Module):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__()
|
||||
@@ -888,6 +921,7 @@ def model_fn_flux_image(
|
||||
flex_condition=None,
|
||||
flex_uncondition=None,
|
||||
flex_control_stop_timestep=None,
|
||||
value_emb=None,
|
||||
step1x_llm_embedding=None,
|
||||
step1x_mask=None,
|
||||
step1x_reference_latents=None,
|
||||
@@ -988,10 +1022,17 @@ def model_fn_flux_image(
|
||||
|
||||
hidden_states = dit.x_embedder(hidden_states)
|
||||
|
||||
# EliGen
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = dit.context_embedder(prompt_emb)
|
||||
# Value Control
|
||||
if value_emb is not None:
|
||||
prompt_emb = torch.concat([prompt_emb, value_emb], dim=1)
|
||||
value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
|
||||
text_ids = torch.concat([text_ids, value_text_ids], dim=1)
|
||||
# Original FLUX inference
|
||||
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user