update eligen ui

This commit is contained in:
mi804
2025-01-03 10:37:34 +08:00
parent 6f743fc4b6
commit 8cf3422688

View File

@@ -1,5 +1,5 @@
import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
import os, torch
from PIL import Image
import numpy as np
@@ -7,11 +7,8 @@ from PIL import ImageDraw, ImageFont
import random
import json
lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors'
save_masks_dir = 'workdirs/tmp_mask'
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
save_dir = os.path.join(save_masks_dir, random_dir)
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
os.makedirs(save_dir, exist_ok=True)
for i, mask in enumerate(masks):
save_path = os.path.join(save_dir, f'{i}.png')
@@ -79,7 +76,7 @@ config = {
"default_parameters": {
"cfg_scale": 3.0,
"embedded_guidance": 3.5,
"num_inference_steps": 30,
"num_inference_steps": 50,
}
},
},
@@ -109,17 +106,15 @@ def load_model(model_type, model_path):
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16
file_list = [
os.path.join(model_path, "text_encoder/model.safetensors"),
os.path.join(model_path, "text_encoder_2"),
]
for file_name in os.listdir(model_path):
if file_name.endswith(".safetensors"):
file_list.append(os.path.join(model_path, file_name))
model_manager.load_models(file_list)
model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.)
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control",
),
lora_alpha=1,
)
else:
model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
@@ -137,7 +132,7 @@ with gr.Blocks() as app:
gr.Markdown("""
# 实体级控制文生图模型EliGen
**UI说明**
1. 点击Load model读取模型然后左侧界面为文生图输入参数右侧Painter为局部控制区域绘制区域每个局部控制条件由其Local prompt和绘制的mask组成支持精准控制文生图和Inpainting两种模式
1. **点击Load model读取模型**然后左侧界面为文生图输入参数右侧Painter为局部控制区域绘制区域每个局部控制条件由其Local prompt和绘制的mask组成支持精准控制文生图和Inpainting两种模式
2. **精准控制生图模式** 输入Globalprompt激活并绘制一个或多个局部控制条件点击Generate生成图像; Global Prompt推荐包含每个Local Prompt
3. **Inpainting模式** 你可以上传图像或者将上一步生成的图像设置为Inpaint Input Image采用类似的方式输入局部控制条件进行局部重绘
4. 尽情创造
@@ -145,7 +140,7 @@ with gr.Blocks() as app:
gr.Markdown("""
# Entity-Level Controlled Text-to-Image Model: EliGen
**UI Instructions**
1. Click "Load model" to load the model. The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes.
1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes.
2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts.
3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing.
4. Enjoy!
@@ -241,7 +236,6 @@ with gr.Blocks() as app:
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
print("save to:", random_mask_dir.value)
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
@@ -255,7 +249,8 @@ with gr.Blocks() as app:
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
if input_image is not None:
input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB")
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
input_params["enable_eligen_inpaint"] = True
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
@@ -274,8 +269,8 @@ with gr.Blocks() as app:
entity_masks = None if len(masks) == 0 else masks
entity_prompts = None if len(local_prompts) == 0 else local_prompts
input_params.update({
"entity_prompts": entity_prompts,
"entity_masks": entity_masks,
"eligen_entity_prompts": entity_prompts,
"eligen_entity_masks": entity_masks,
})
torch.manual_seed(seed)
image = pipe(**input_params)