support value controller training

This commit is contained in:
Artiprocher
2025-07-21 19:16:30 +08:00
parent 22705a44b4
commit e3c5d2540b
7 changed files with 79 additions and 7 deletions

View File

@@ -466,7 +466,7 @@ class FluxImagePipeline(BasePipeline):
flex_control_strength: float = 0.5,
flex_control_stop: float = 0.5,
# Value Controller
value_controller_inputs: list[float] = None,
value_controller_inputs: Union[list[float], float] = None,
# Step1x
step1x_reference_image: Image.Image = None,
# LoRA Encoder
@@ -871,6 +871,8 @@ class FluxImageUnit_ValueControl(PipelineUnit):
def process(self, pipe: FluxImagePipeline, prompt_emb, text_ids, value_controller_inputs):
if value_controller_inputs is None:
return {}
if not isinstance(value_controller_inputs, list):
value_controller_inputs = [value_controller_inputs]
value_controller_inputs = torch.tensor(value_controller_inputs).to(dtype=pipe.torch_dtype, device=pipe.device)
pipe.load_models_to_device(["value_controller"])
value_emb = pipe.value_controller(value_controller_inputs, pipe.torch_dtype)