mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
update eligen ui
This commit is contained in:
@@ -1,5 +1,5 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
from diffsynth import ModelManager, FluxImagePipeline
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
||||||
import os, torch
|
import os, torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -7,11 +7,8 @@ from PIL import ImageDraw, ImageFont
|
|||||||
import random
|
import random
|
||||||
import json
|
import json
|
||||||
|
|
||||||
lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors'
|
|
||||||
save_masks_dir = 'workdirs/tmp_mask'
|
|
||||||
|
|
||||||
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
||||||
save_dir = os.path.join(save_masks_dir, random_dir)
|
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
for i, mask in enumerate(masks):
|
for i, mask in enumerate(masks):
|
||||||
save_path = os.path.join(save_dir, f'{i}.png')
|
save_path = os.path.join(save_dir, f'{i}.png')
|
||||||
@@ -79,7 +76,7 @@ config = {
|
|||||||
"default_parameters": {
|
"default_parameters": {
|
||||||
"cfg_scale": 3.0,
|
"cfg_scale": 3.0,
|
||||||
"embedded_guidance": 3.5,
|
"embedded_guidance": 3.5,
|
||||||
"num_inference_steps": 30,
|
"num_inference_steps": 50,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -109,17 +106,15 @@ def load_model(model_type, model_path):
|
|||||||
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
||||||
model_manager = ModelManager()
|
model_manager = ModelManager()
|
||||||
if model_type == "FLUX":
|
if model_type == "FLUX":
|
||||||
model_manager.torch_dtype = torch.bfloat16
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
||||||
file_list = [
|
model_manager.load_lora(
|
||||||
os.path.join(model_path, "text_encoder/model.safetensors"),
|
download_customized_models(
|
||||||
os.path.join(model_path, "text_encoder_2"),
|
model_id="DiffSynth-Studio/Eligen",
|
||||||
]
|
origin_file_path="model_bf16.safetensors",
|
||||||
for file_name in os.listdir(model_path):
|
local_dir="models/lora/entity_control",
|
||||||
if file_name.endswith(".safetensors"):
|
),
|
||||||
file_list.append(os.path.join(model_path, file_name))
|
lora_alpha=1,
|
||||||
model_manager.load_models(file_list)
|
)
|
||||||
model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
model_manager.load_model(model_path)
|
model_manager.load_model(model_path)
|
||||||
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
||||||
@@ -137,7 +132,7 @@ with gr.Blocks() as app:
|
|||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
# 实体级控制文生图模型EliGen
|
# 实体级控制文生图模型EliGen
|
||||||
**UI说明**
|
**UI说明**
|
||||||
1. 点击Load model读取模型,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。
|
1. **点击Load model读取模型**,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。
|
||||||
2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。
|
2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。
|
||||||
3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。
|
3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。
|
||||||
4. 尽情创造!
|
4. 尽情创造!
|
||||||
@@ -145,7 +140,7 @@ with gr.Blocks() as app:
|
|||||||
gr.Markdown("""
|
gr.Markdown("""
|
||||||
# Entity-Level Controlled Text-to-Image Model: EliGen
|
# Entity-Level Controlled Text-to-Image Model: EliGen
|
||||||
**UI Instructions**
|
**UI Instructions**
|
||||||
1. Click "Load model" to load the model. The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes.
|
1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes.
|
||||||
2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts.
|
2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts.
|
||||||
3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing.
|
3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing.
|
||||||
4. Enjoy!
|
4. Enjoy!
|
||||||
@@ -241,7 +236,6 @@ with gr.Blocks() as app:
|
|||||||
triggers=run_button.click
|
triggers=run_button.click
|
||||||
)
|
)
|
||||||
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
||||||
print("save to:", random_mask_dir.value)
|
|
||||||
_, pipe = load_model(model_type, model_path)
|
_, pipe = load_model(model_type, model_path)
|
||||||
input_params = {
|
input_params = {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
@@ -255,7 +249,8 @@ with gr.Blocks() as app:
|
|||||||
if isinstance(pipe, FluxImagePipeline):
|
if isinstance(pipe, FluxImagePipeline):
|
||||||
input_params["embedded_guidance"] = embedded_guidance
|
input_params["embedded_guidance"] = embedded_guidance
|
||||||
if input_image is not None:
|
if input_image is not None:
|
||||||
input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB")
|
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
||||||
|
input_params["enable_eligen_inpaint"] = True
|
||||||
|
|
||||||
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
|
||||||
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
||||||
@@ -274,8 +269,8 @@ with gr.Blocks() as app:
|
|||||||
entity_masks = None if len(masks) == 0 else masks
|
entity_masks = None if len(masks) == 0 else masks
|
||||||
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
||||||
input_params.update({
|
input_params.update({
|
||||||
"entity_prompts": entity_prompts,
|
"eligen_entity_prompts": entity_prompts,
|
||||||
"entity_masks": entity_masks,
|
"eligen_entity_masks": entity_masks,
|
||||||
})
|
})
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
image = pipe(**input_params)
|
image = pipe(**input_params)
|
||||||
Reference in New Issue
Block a user