diff --git a/apps/gradio/DiffSynth_Studio.py b/apps/gradio/DiffSynth_Studio.py index 05dfba3..d265492 100644 --- a/apps/gradio/DiffSynth_Studio.py +++ b/apps/gradio/DiffSynth_Studio.py @@ -6,60 +6,72 @@ import numpy as np config = { - "Stable Diffusion": { - "model_folder": "models/stable_diffusion", - "pipeline_class": SDImagePipeline, - "default_parameters": { - "height": 512, - "width": 512, + "model_config": { + "Stable Diffusion": { + "model_folder": "models/stable_diffusion", + "pipeline_class": SDImagePipeline, + "default_parameters": { + "cfg_scale": 7.0, + "height": 512, + "width": 512, + } + }, + "Stable Diffusion XL": { + "model_folder": "models/stable_diffusion_xl", + "pipeline_class": SDXLImagePipeline, + "default_parameters": { + "cfg_scale": 7.0, + } + }, + "Stable Diffusion 3": { + "model_folder": "models/stable_diffusion_3", + "pipeline_class": SD3ImagePipeline, + "default_parameters": { + "cfg_scale": 7.0, + } + }, + "Stable Diffusion XL Turbo": { + "model_folder": "models/stable_diffusion_xl_turbo", + "pipeline_class": SDXLImagePipeline, + "default_parameters": { + "negative_prompt": "", + "cfg_scale": 1.0, + "num_inference_steps": 1, + "height": 512, + "width": 512, + } + }, + "Kolors": { + "model_folder": "models/kolors", + "pipeline_class": SDXLImagePipeline, + "default_parameters": { + "cfg_scale": 7.0, + } + }, + "HunyuanDiT": { + "model_folder": "models/HunyuanDiT", + "pipeline_class": HunyuanDiTImagePipeline, + "default_parameters": { + "cfg_scale": 7.0, + } + }, + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "default_parameters": { + "cfg_scale": 1.0, + } } }, - "Stable Diffusion XL": { - "model_folder": "models/stable_diffusion_xl", - "pipeline_class": SDXLImagePipeline, - "default_parameters": {} - }, - "Stable Diffusion 3": { - "model_folder": "models/stable_diffusion_3", - "pipeline_class": SD3ImagePipeline, - "default_parameters": {} - }, - "Stable Diffusion XL Turbo": { - "model_folder": "models/stable_diffusion_xl_turbo", - "pipeline_class": SDXLImagePipeline, - "default_parameters": { - "negative_prompt": "", - "cfg_scale": 1.0, - "num_inference_steps": 1, - "height": 512, - "width": 512, - } - }, - "Kolors": { - "model_folder": "models/kolors", - "pipeline_class": SDXLImagePipeline, - "default_parameters": {} - }, - "HunyuanDiT": { - "model_folder": "models/HunyuanDiT", - "pipeline_class": HunyuanDiTImagePipeline, - "default_parameters": {} - }, - "FLUX": { - "model_folder": "models/FLUX", - "pipeline_class": FluxImagePipeline, - "default_parameters": { - "cfg_scale": 1.0, - } - } + "max_num_painter_layers": 8, + "max_num_model_cache": 1, } -MAX_NUM_PAINTER_LAYERS = 8 def load_model_list(model_type): if model_type is None: return [] - folder = config[model_type]["model_folder"] + folder = config["model_config"][model_type]["model_folder"] file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] @@ -68,7 +80,11 @@ def load_model_list(model_type): def load_model(model_type, model_path): - model_path = os.path.join(config[model_type]["model_folder"], model_path) + global model_dict + model_key = f"{model_type}:{model_path}" + if model_key in model_dict: + return model_dict[model_key] + model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) model_manager = ModelManager() if model_type == "HunyuanDiT": model_manager.load_models([ @@ -95,13 +111,18 @@ def load_model(model_type, model_path): model_manager.load_models(file_list) else: model_manager.load_model(model_path) - pipe = config[model_type]["pipeline_class"].from_model_manager(model_manager) + pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) + while len(model_dict) + 1 > config["max_num_model_cache"]: + key = next(iter(model_dict.keys())) + model_manager_to_release, _ = model_dict[key] + model_manager_to_release.to("cpu") + del model_dict[key] + torch.cuda.empty_cache() + model_dict[model_key] = model_manager, pipe return model_manager, pipe - -model_manager: ModelManager = None -pipe = None +model_dict = {} with gr.Blocks() as app: gr.Markdown("# DiffSynth-Studio Painter") @@ -109,7 +130,7 @@ with gr.Blocks() as app: with gr.Column(scale=382, min_width=100): with gr.Accordion(label="Model"): - model_type = gr.Dropdown(choices=[i for i in config], label="Model type") + model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type") model_path = gr.Dropdown(choices=[], interactive=True, label="Model path") @gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change) @@ -136,16 +157,12 @@ with gr.Blocks() as app: triggers=model_path.change ) def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width): - global model_manager, pipe - if isinstance(model_manager, ModelManager): - model_manager.to("cpu") - torch.cuda.empty_cache() - model_manager, pipe = load_model(model_type, model_path) - cfg_scale = config[model_type]["default_parameters"].get("cfg_scale", cfg_scale) - embedded_guidance = config[model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) - num_inference_steps = config[model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) - height = config[model_type]["default_parameters"].get("height", height) - width = config[model_type]["default_parameters"].get("width", width) + load_model(model_type, model_path) + cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale) + embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) + num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) + height = config["model_config"][model_type]["default_parameters"].get("height", height) + width = config["model_config"][model_type]["default_parameters"].get("width", width) return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width @@ -155,7 +172,7 @@ with gr.Blocks() as app: local_prompt_list = [] mask_scale_list = [] canvas_list = [] - for painter_layer_id in range(MAX_NUM_PAINTER_LAYERS): + for painter_layer_id in range(config["max_num_painter_layers"]): with gr.Tab(label=f"Layer {painter_layer_id}"): enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}") local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") @@ -186,12 +203,12 @@ with gr.Blocks() as app: painter_background = gr.State(None) input_background = gr.State(None) @gr.on( - inputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, outputs=[output_image], triggers=run_button.click ) - def generate_image(prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()): - global pipe + def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()): + _, pipe = load_model(model_type, model_path) input_params = { "prompt": prompt, "negative_prompt": negative_prompt, @@ -204,10 +221,10 @@ with gr.Blocks() as app: if isinstance(pipe, FluxImagePipeline): input_params["embedded_guidance"] = embedded_guidance enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( - args[0 * MAX_NUM_PAINTER_LAYERS: 1 * MAX_NUM_PAINTER_LAYERS], - args[1 * MAX_NUM_PAINTER_LAYERS: 2 * MAX_NUM_PAINTER_LAYERS], - args[2 * MAX_NUM_PAINTER_LAYERS: 3 * MAX_NUM_PAINTER_LAYERS], - args[3 * MAX_NUM_PAINTER_LAYERS: 4 * MAX_NUM_PAINTER_LAYERS] + args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], + args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], + args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]], + args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]] ) local_prompts, masks, mask_scales = [], [], [] for enable_local_prompt, local_prompt, mask_scale, canvas in zip(