This commit is contained in:
Artiprocher
2025-11-21 20:26:49 +08:00
parent 0b7dd55ff3
commit 0336551544
13 changed files with 46 additions and 20 deletions

View File

@@ -202,7 +202,7 @@ class DiffusionTrainingModule(torch.nn.Module):
if name not in controlnet_inputs:
controlnet_inputs[name] = {}
controlnet_inputs[name][extra_input.replace(prefix, "")] = data[extra_input]
break
break
else:
inputs_shared[extra_input] = data[extra_input]
for name, params in controlnet_inputs.items():

View File

@@ -145,7 +145,8 @@ class FluxImagePipeline(BasePipeline):
value_controllers = model_pool.fetch_model("flux_value_controller")
if value_controllers is not None:
pipe.value_controller = MultiValueEncoder(value_controllers)
pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled
if hasattr(pipe.value_controller.encoders[0], "vram_management_enabled"):
pipe.value_controller.vram_management_enabled = pipe.value_controller.encoders[0].vram_management_enabled
controlnets = model_pool.fetch_model("flux_controlnet")
if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets)
pipe.ipadapter = model_pool.fetch_model("flux_ipadapter")
@@ -295,7 +296,7 @@ class FluxImagePipeline(BasePipeline):
class FluxImageUnit_ShapeChecker(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width"))
super().__init__(input_params=("height", "width"), output_params=("height", "width"))
def process(self, pipe: FluxImagePipeline, height, width):
height, width = pipe.check_resize_height_width(height, width)
@@ -305,7 +306,7 @@ class FluxImageUnit_ShapeChecker(PipelineUnit):
class FluxImageUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(input_params=("height", "width", "seed", "rand_device"))
super().__init__(input_params=("height", "width", "seed", "rand_device"), output_params=("noise",))
def process(self, pipe: FluxImagePipeline, height, width, seed, rand_device):
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device)
@@ -317,6 +318,7 @@ class FluxImageUnit_InputImageEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
output_params=("latents", "input_latents"),
onload_model_names=("vae_encoder",)
)
@@ -341,6 +343,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
input_params_posi={"prompt": "prompt", "positive": "positive"},
input_params_nega={"prompt": "negative_prompt", "positive": "positive"},
input_params=("t5_sequence_length",),
output_params=("prompt_emb", "pooled_prompt_emb", "text_ids"),
onload_model_names=("text_encoder_1", "text_encoder_2")
)
@@ -396,7 +399,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
class FluxImageUnit_ImageIDs(PipelineUnit):
def __init__(self):
super().__init__(input_params=("latents",))
super().__init__(input_params=("latents",), output_params=("image_ids",))
def process(self, pipe: FluxImagePipeline, latents):
latent_image_ids = pipe.dit.prepare_image_ids(latents)
@@ -406,7 +409,7 @@ class FluxImageUnit_ImageIDs(PipelineUnit):
class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
def __init__(self):
super().__init__(input_params=("embedded_guidance", "latents"))
super().__init__(input_params=("embedded_guidance", "latents"), output_params=("guidance",))
def process(self, pipe: FluxImagePipeline, embedded_guidance, latents):
guidance = torch.Tensor([embedded_guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
@@ -416,7 +419,11 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit):
class FluxImageUnit_Kontext(PipelineUnit):
def __init__(self):
super().__init__(input_params=("kontext_images", "tiled", "tile_size", "tile_stride"))
super().__init__(
input_params=("kontext_images", "tiled", "tile_size", "tile_stride"),
output_params=("kontext_latents", "kontext_image_ids"),
onload_model_names=("vae_encoder",)
)
def process(self, pipe: FluxImagePipeline, kontext_images, tiled, tile_size, tile_stride):
if kontext_images is None:
@@ -444,6 +451,7 @@ class FluxImageUnit_ControlNet(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("controlnet_inputs", "tiled", "tile_size", "tile_stride"),
output_params=("controlnet_conditionings",),
onload_model_names=("vae_encoder",)
)
@@ -486,6 +494,8 @@ class FluxImageUnit_IPAdapter(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("ipadapter_images", "ipadapter_scale"),
output_params=("ipadapter_kwargs_list",),
onload_model_names=("ipadapter_image_encoder", "ipadapter")
)
@@ -513,6 +523,8 @@ class FluxImageUnit_EntityControl(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("eligen_entity_prompts", "eligen_entity_masks", "eligen_enable_on_negative", "width", "height", "t5_sequence_length", "cfg_scale"),
output_params=("entity_prompt_emb", "entity_masks"),
onload_model_names=("text_encoder_1", "text_encoder_2")
)
@@ -603,6 +615,8 @@ class FluxImageUnit_NexusGen(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("nexus_gen_reference_image", "prompt", "latents"),
output_params=("prompt_emb", "text_ids"),
onload_model_names=("nexus_gen", "nexus_gen_generation_adapter", "nexus_gen_editing_adapter"),
)
@@ -655,7 +669,12 @@ class FluxImageUnit_NexusGen(PipelineUnit):
class FluxImageUnit_Step1x(PipelineUnit):
def __init__(self):
super().__init__(take_over=True,onload_model_names=("qwenvl","vae_encoder"))
super().__init__(
take_over=True,
input_params=("step1x_reference_image", "prompt", "negative_prompt"),
output_params=("step1x_llm_embedding", "step1x_mask", "step1x_reference_latents"),
onload_model_names=("qwenvl","vae_encoder")
)
def process(self, pipe: FluxImagePipeline, inputs_shared: dict, inputs_posi: dict, inputs_nega: dict):
image = inputs_shared.get("step1x_reference_image",None)
@@ -678,7 +697,7 @@ class FluxImageUnit_Step1x(PipelineUnit):
class FluxImageUnit_TeaCache(PipelineUnit):
def __init__(self):
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"))
super().__init__(input_params=("num_inference_steps","tea_cache_l1_thresh"), output_params=("tea_cache",))
def process(self, pipe: FluxImagePipeline, num_inference_steps, tea_cache_l1_thresh):
if tea_cache_l1_thresh is None:
@@ -690,6 +709,7 @@ class FluxImageUnit_Flex(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("latents", "flex_inpaint_image", "flex_inpaint_mask", "flex_control_image", "flex_control_strength", "flex_control_stop", "tiled", "tile_size", "tile_stride"),
output_params=("flex_condition", "flex_uncondition", "flex_control_stop_timestep"),
onload_model_names=("vae_encoder",)
)
@@ -728,6 +748,7 @@ class FluxImageUnit_InfiniteYou(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("infinityou_id_image", "infinityou_guidance"),
output_params=("id_emb", "infinityou_guidance"),
onload_model_names=("infinityou_processor",)
)
@@ -747,6 +768,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
input_params_posi={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params_nega={"prompt_emb": "prompt_emb", "text_ids": "text_ids"},
input_params=("value_controller_inputs",),
output_params=("prompt_emb", "text_ids"),
onload_model_names=("value_controller",)
)
@@ -825,6 +847,8 @@ class FluxImageUnit_LoRAEncode(PipelineUnit):
def __init__(self):
super().__init__(
take_over=True,
input_params=("lora_encoder_inputs", "lora_encoder_scale"),
output_params=("prompt_emb", "text_ids"),
onload_model_names=("lora_encoder",)
)