refine code

This commit is contained in:
Artiprocher
2025-07-29 18:47:16 +08:00
parent 7df48fc2b5
commit 9c51623fc2
14 changed files with 124 additions and 18 deletions

View File

@@ -375,6 +375,7 @@ class FluxImagePipeline(BasePipeline):
torch_dtype: torch.dtype = torch.bfloat16,
device: Union[str, torch.device] = "cuda",
model_configs: list[ModelConfig] = [],
nexus_gen_processor_config: ModelConfig = None,
):
# Download and load models
model_manager = ModelManager()
@@ -406,6 +407,9 @@ class FluxImagePipeline(BasePipeline):
pipe.nexus_gen = model_manager.fetch_model("nexus_gen_llm")
pipe.nexus_gen_generation_adapter = model_manager.fetch_model("nexus_gen_generation_adapter")
pipe.nexus_gen_editing_adapter = model_manager.fetch_model("nexus_gen_editing_adapter")
if nexus_gen_processor_config is not None and pipe.nexus_gen is not None:
nexus_gen_processor_config.download_if_necessary()
pipe.nexus_gen.load_processor(nexus_gen_processor_config.path)
# ControlNet
controlnets = []