value control

This commit is contained in:
Artiprocher
2025-07-04 14:23:07 +08:00
parent 1363a0559f
commit 6c30a7f080
4 changed files with 118 additions and 1 deletions

View File

@@ -18,6 +18,7 @@ from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEnc
from ..models.step1x_connector import Qwen2Connector
from ..models.flux_controlnet import FluxControlNet
from ..models.flux_ipadapter import FluxIpAdapter
from ..models.flux_value_control import MultiValueEncoder
from ..models.flux_infiniteyou import InfiniteYouImageProjector
from ..models.tiler import FastTileWorker
from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit
@@ -94,6 +95,7 @@ class FluxImagePipeline(BasePipeline):
self.unit_runner = PipelineUnitRunner()
self.qwenvl = None
self.step1x_connector: Qwen2Connector = None
self.value_controller: MultiValueEncoder = None
self.infinityou_processor: InfinitYou = None
self.image_proj_model: InfiniteYouImageProjector = None
self.in_iteration_models = ("dit", "step1x_connector", "controlnet")
@@ -112,6 +114,7 @@ class FluxImagePipeline(BasePipeline):
FluxImageUnit_TeaCache(),
FluxImageUnit_Flex(),
FluxImageUnit_Step1x(),
FluxImageUnit_ValueControl(),
]
self.model_fn = model_fn_flux_image
@@ -295,7 +298,16 @@ class FluxImagePipeline(BasePipeline):
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_controlnet":
controlnets.append(model)
pipe.controlnet = MultiControlNet(controlnets)
if len(controlnets) > 0:
pipe.controlnet = MultiControlNet(controlnets)
# Value Controller
value_controllers = []
for model_name, model in zip(model_manager.model_name, model_manager.model):
if model_name == "flux_value_controller":
value_controllers.append(model)
if len(value_controllers) > 0:
pipe.value_controller = MultiValueEncoder(value_controllers)
return pipe
@@ -347,6 +359,8 @@ class FluxImagePipeline(BasePipeline):
flex_control_image: Image.Image = None,
flex_control_strength: float = 0.5,
flex_control_stop: float = 0.5,
# Value Controller
value_controller_inputs: list[float] = None,
# Step1x
step1x_reference_image: Image.Image = None,
# TeaCache
@@ -380,6 +394,7 @@ class FluxImagePipeline(BasePipeline):
"eligen_entity_prompts": eligen_entity_prompts, "eligen_entity_masks": eligen_entity_masks, "eligen_enable_on_negative": eligen_enable_on_negative, "eligen_enable_inpaint": eligen_enable_inpaint,
"infinityou_id_image": infinityou_id_image, "infinityou_guidance": infinityou_guidance,
"flex_inpaint_image": flex_inpaint_image, "flex_inpaint_mask": flex_inpaint_mask, "flex_control_image": flex_control_image, "flex_control_strength": flex_control_strength, "flex_control_stop": flex_control_stop,
"value_controller_inputs": value_controller_inputs,
"step1x_reference_image": step1x_reference_image,
"tea_cache_l1_thresh": tea_cache_l1_thresh,
"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
@@ -720,6 +735,27 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
class FluxImageUnit_ValueControl(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
onload_model_names=("value_controller",)
)
def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("value_controller_inputs", None) is None:
return inputs_shared, inputs_posi, inputs_nega
value_controller_inputs = torch.tensor(inputs_shared["value_controller_inputs"]).to(dtype=pipe.torch_dtype, device=pipe.device)
pipe.load_models_to_device(["value_controller_inputs"])
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)
value_emb = value_emb.unsqueeze(0)
value_text_ids = torch.zeros((value_emb.shape[0], value_emb.shape[1], 3), device=value_emb.device, dtype=value_emb.dtype)
inputs_posi["prompt_emb"] = torch.concat([inputs_posi["prompt_emb"], value_emb], dim=1)
inputs_posi["text_ids"] = torch.concat([inputs_posi["text_ids"], value_text_ids], dim=1)
return inputs_shared, inputs_posi, inputs_nega
class InfinitYou:
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
from facexlib.recognition import init_recognition_model