mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 18:58:11 +00:00
@@ -6,10 +6,12 @@ import numpy as np
|
|||||||
|
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
|
"model_config": {
|
||||||
"Stable Diffusion": {
|
"Stable Diffusion": {
|
||||||
"model_folder": "models/stable_diffusion",
|
"model_folder": "models/stable_diffusion",
|
||||||
"pipeline_class": SDImagePipeline,
|
"pipeline_class": SDImagePipeline,
|
||||||
"default_parameters": {
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
"height": 512,
|
"height": 512,
|
||||||
"width": 512,
|
"width": 512,
|
||||||
}
|
}
|
||||||
@@ -17,12 +19,16 @@ config = {
|
|||||||
"Stable Diffusion XL": {
|
"Stable Diffusion XL": {
|
||||||
"model_folder": "models/stable_diffusion_xl",
|
"model_folder": "models/stable_diffusion_xl",
|
||||||
"pipeline_class": SDXLImagePipeline,
|
"pipeline_class": SDXLImagePipeline,
|
||||||
"default_parameters": {}
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"Stable Diffusion 3": {
|
"Stable Diffusion 3": {
|
||||||
"model_folder": "models/stable_diffusion_3",
|
"model_folder": "models/stable_diffusion_3",
|
||||||
"pipeline_class": SD3ImagePipeline,
|
"pipeline_class": SD3ImagePipeline,
|
||||||
"default_parameters": {}
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"Stable Diffusion XL Turbo": {
|
"Stable Diffusion XL Turbo": {
|
||||||
"model_folder": "models/stable_diffusion_xl_turbo",
|
"model_folder": "models/stable_diffusion_xl_turbo",
|
||||||
@@ -38,12 +44,16 @@ config = {
|
|||||||
"Kolors": {
|
"Kolors": {
|
||||||
"model_folder": "models/kolors",
|
"model_folder": "models/kolors",
|
||||||
"pipeline_class": SDXLImagePipeline,
|
"pipeline_class": SDXLImagePipeline,
|
||||||
"default_parameters": {}
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"HunyuanDiT": {
|
"HunyuanDiT": {
|
||||||
"model_folder": "models/HunyuanDiT",
|
"model_folder": "models/HunyuanDiT",
|
||||||
"pipeline_class": HunyuanDiTImagePipeline,
|
"pipeline_class": HunyuanDiTImagePipeline,
|
||||||
"default_parameters": {}
|
"default_parameters": {
|
||||||
|
"cfg_scale": 7.0,
|
||||||
|
}
|
||||||
},
|
},
|
||||||
"FLUX": {
|
"FLUX": {
|
||||||
"model_folder": "models/FLUX",
|
"model_folder": "models/FLUX",
|
||||||
@@ -52,14 +62,16 @@ config = {
|
|||||||
"cfg_scale": 1.0,
|
"cfg_scale": 1.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
"max_num_painter_layers": 8,
|
||||||
|
"max_num_model_cache": 1,
|
||||||
}
|
}
|
||||||
MAX_NUM_PAINTER_LAYERS = 8
|
|
||||||
|
|
||||||
|
|
||||||
def load_model_list(model_type):
|
def load_model_list(model_type):
|
||||||
if model_type is None:
|
if model_type is None:
|
||||||
return []
|
return []
|
||||||
folder = config[model_type]["model_folder"]
|
folder = config["model_config"][model_type]["model_folder"]
|
||||||
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
|
||||||
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
|
||||||
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
|
||||||
@@ -68,7 +80,11 @@ def load_model_list(model_type):
|
|||||||
|
|
||||||
|
|
||||||
def load_model(model_type, model_path):
|
def load_model(model_type, model_path):
|
||||||
model_path = os.path.join(config[model_type]["model_folder"], model_path)
|
global model_dict
|
||||||
|
model_key = f"{model_type}:{model_path}"
|
||||||
|
if model_key in model_dict:
|
||||||
|
return model_dict[model_key]
|
||||||
|
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
if model_type == "HunyuanDiT":
|
if model_type == "HunyuanDiT":
|
||||||
model_manager.load_models([
|
model_manager.load_models([
|
||||||
@@ -95,13 +111,18 @@ def load_model(model_type, model_path):
|
|||||||
model_manager.load_models(file_list)
|
model_manager.load_models(file_list)
|
||||||
else:
|
else:
|
||||||
model_manager.load_model(model_path)
|
model_manager.load_model(model_path)
|
||||||
pipe = config[model_type]["pipeline_class"].from_model_manager(model_manager)
|
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||||
|
while len(model_dict) + 1 > config["max_num_model_cache"]:
|
||||||
|
key = next(iter(model_dict.keys()))
|
||||||
|
model_manager_to_release, _ = model_dict[key]
|
||||||
|
model_manager_to_release.to("cpu")
|
||||||
|
del model_dict[key]
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
model_dict[model_key] = model_manager, pipe
|
||||||
return model_manager, pipe
|
return model_manager, pipe
|
||||||
|
|
||||||
|
|
||||||
|
model_dict = {}
|
||||||
model_manager: ModelManager = None
|
|
||||||
pipe = None
|
|
||||||
|
|
||||||
with gr.Blocks() as app:
|
with gr.Blocks() as app:
|
||||||
gr.Markdown("# DiffSynth-Studio Painter")
|
gr.Markdown("# DiffSynth-Studio Painter")
|
||||||
@@ -109,7 +130,7 @@ with gr.Blocks() as app:
|
|||||||
with gr.Column(scale=382, min_width=100):
|
with gr.Column(scale=382, min_width=100):
|
||||||
|
|
||||||
with gr.Accordion(label="Model"):
|
with gr.Accordion(label="Model"):
|
||||||
model_type = gr.Dropdown(choices=[i for i in config], label="Model type")
|
model_type = gr.Dropdown(choices=[i for i in config["model_config"]], label="Model type")
|
||||||
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
model_path = gr.Dropdown(choices=[], interactive=True, label="Model path")
|
||||||
|
|
||||||
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
@gr.on(inputs=model_type, outputs=model_path, triggers=model_type.change)
|
||||||
@@ -136,16 +157,12 @@ with gr.Blocks() as app:
|
|||||||
triggers=model_path.change
|
triggers=model_path.change
|
||||||
)
|
)
|
||||||
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
|
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width):
|
||||||
global model_manager, pipe
|
load_model(model_type, model_path)
|
||||||
if isinstance(model_manager, ModelManager):
|
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
|
||||||
model_manager.to("cpu")
|
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
|
||||||
torch.cuda.empty_cache()
|
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
|
||||||
model_manager, pipe = load_model(model_type, model_path)
|
height = config["model_config"][model_type]["default_parameters"].get("height", height)
|
||||||
cfg_scale = config[model_type]["default_parameters"].get("cfg_scale", cfg_scale)
|
width = config["model_config"][model_type]["default_parameters"].get("width", width)
|
||||||
embedded_guidance = config[model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
|
|
||||||
num_inference_steps = config[model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
|
|
||||||
height = config[model_type]["default_parameters"].get("height", height)
|
|
||||||
width = config[model_type]["default_parameters"].get("width", width)
|
|
||||||
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width
|
||||||
|
|
||||||
|
|
||||||
@@ -155,7 +172,7 @@ with gr.Blocks() as app:
|
|||||||
local_prompt_list = []
|
local_prompt_list = []
|
||||||
mask_scale_list = []
|
mask_scale_list = []
|
||||||
canvas_list = []
|
canvas_list = []
|
||||||
for painter_layer_id in range(MAX_NUM_PAINTER_LAYERS):
|
for painter_layer_id in range(config["max_num_painter_layers"]):
|
||||||
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
with gr.Tab(label=f"Layer {painter_layer_id}"):
|
||||||
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
|
||||||
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
||||||
@@ -186,12 +203,12 @@ with gr.Blocks() as app:
|
|||||||
painter_background = gr.State(None)
|
painter_background = gr.State(None)
|
||||||
input_background = gr.State(None)
|
input_background = gr.State(None)
|
||||||
@gr.on(
|
@gr.on(
|
||||||
inputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
|
||||||
outputs=[output_image],
|
outputs=[output_image],
|
||||||
triggers=run_button.click
|
triggers=run_button.click
|
||||||
)
|
)
|
||||||
def generate_image(prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
|
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, seed, *args, progress=gr.Progress()):
|
||||||
global pipe
|
_, pipe = load_model(model_type, model_path)
|
||||||
input_params = {
|
input_params = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"negative_prompt": negative_prompt,
|
"negative_prompt": negative_prompt,
|
||||||
@@ -204,10 +221,10 @@ with gr.Blocks() as app:
|
|||||||
if isinstance(pipe, FluxImagePipeline):
|
if isinstance(pipe, FluxImagePipeline):
|
||||||
input_params["embedded_guidance"] = embedded_guidance
|
input_params["embedded_guidance"] = embedded_guidance
|
||||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
||||||
args[0 * MAX_NUM_PAINTER_LAYERS: 1 * MAX_NUM_PAINTER_LAYERS],
|
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||||
args[1 * MAX_NUM_PAINTER_LAYERS: 2 * MAX_NUM_PAINTER_LAYERS],
|
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
||||||
args[2 * MAX_NUM_PAINTER_LAYERS: 3 * MAX_NUM_PAINTER_LAYERS],
|
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
|
||||||
args[3 * MAX_NUM_PAINTER_LAYERS: 4 * MAX_NUM_PAINTER_LAYERS]
|
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
|
||||||
)
|
)
|
||||||
local_prompts, masks, mask_scales = [], [], []
|
local_prompts, masks, mask_scales = [], [], []
|
||||||
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
|
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
|
||||||
|
|||||||
Reference in New Issue
Block a user