model cache

This commit is contained in:
Artiprocher
2024-08-26 13:57:03 +08:00
parent 23f9675218
commit 547aca3db2

View File

@@ -6,60 +6,72 @@ import numpy as np
config = {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"height": 512,
"width": 512,
"model_config": {
"Stable Diffusion": {
"model_folder": "models/stable_diffusion",
"pipeline_class": SDImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
"height": 512,
"width": 512,
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {
"cfg_scale": 7.0,
}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
},
"Stable Diffusion XL": {
"model_folder": "models/stable_diffusion_xl",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {}
},
"Stable Diffusion 3": {
"model_folder": "models/stable_diffusion_3",
"pipeline_class": SD3ImagePipeline,
"default_parameters": {}
},
"Stable Diffusion XL Turbo": {
"model_folder": "models/stable_diffusion_xl_turbo",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {
"negative_prompt": "",
"cfg_scale": 1.0,
"num_inference_steps": 1,
"height": 512,
"width": 512,
}
},
"Kolors": {
"model_folder": "models/kolors",
"pipeline_class": SDXLImagePipeline,
"default_parameters": {}
},
"HunyuanDiT": {
"model_folder": "models/HunyuanDiT",
"pipeline_class": HunyuanDiTImagePipeline,
"default_parameters": {}
},
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 1.0,
}
}
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
MAX_NUM_PAINTER_LAYERS = 8
def load_model_list(model_type):
if model_type is None:
return []
folder = config[model_type]["model_folder"]
folder = config["model_config"][model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
@@ -68,7 +80,11 @@ def load_model_list(model_type):
def load_model(model_type, model_path):
model_path = os.path.join(config[model_type]["model_folder"], model_path)
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "HunyuanDiT":
model_manager.load_models([
@@ -95,13 +111,18 @@ def load_model(model_type, model_path):
model_manager.load_models(file_list)
else:
model_manager.load_model(model_path)
pipe = config[model_type]["pipeline_class"].from_model_manager(model_manager)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
while len(model_dict) + 1 > config["max_num_model_cache"]:
key = next(iter(model_dict.keys()))
model_manager_to_release, _ = model_dict[key]
model_manager_to_release.to("cpu")
del model_dict[key]
torch.cuda.empty_cache()
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
model_manager: ModelManager = None
pipe = None
model_dict = {}
with gr.Blocks() as app:
gr.Markdown("# DiffSynth-Studio Painter")
@@ -109,7 +130,7 @@ with gr.Blocks() as app:
with gr.Column(scale=382, min_width=100):
with gr.Accordion(label="Model"):
model_type = gr.Dropdown(choices=[i for i in config], label="Model type")
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
@@ -136,16 +157,12 @@ with gr.Blocks() as app:
triggers=model_path.change
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
global model_manager, pipe
if isinstance(model_manager, ModelManager):
model_manager.to("cpu")
torch.cuda.empty_cache()
model_manager, pipe = load_model(model_type, model_path)
cfg_scale = config[model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config[model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config[model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config[model_type]["default_parameters"].get("height", height)
width = config[model_type]["default_parameters"].get("width", width)
load_model(model_type, model_path)
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config["model_config"][model_type]["default_parameters"].get("height", height)
width = config["model_config"][model_type]["default_parameters"].get("width", width)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
@@ -155,7 +172,7 @@ with gr.Blocks() as app:
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(MAX_NUM_PAINTER_LAYERS):
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
@@ -186,12 +203,12 @@ with gr.Blocks() as app:
painter_background = gr.State(None)
input_background = gr.State(None)
@gr.on(
inputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image],
triggers=run_button.click
)
def generate_image(prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
global pipe
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
@@ -204,10 +221,10 @@ with gr.Blocks() as app:
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * MAX_NUM_PAINTER_LAYERS: 1 * MAX_NUM_PAINTER_LAYERS],
args[1 * MAX_NUM_PAINTER_LAYERS: 2 * MAX_NUM_PAINTER_LAYERS],
args[2 * MAX_NUM_PAINTER_LAYERS: 3 * MAX_NUM_PAINTER_LAYERS],
args[3 * MAX_NUM_PAINTER_LAYERS: 4 * MAX_NUM_PAINTER_LAYERS]
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(