update eligen ui

This commit is contained in:
mi804
2025-01-03 10:37:34 +08:00
parent 6f743fc4b6
commit 8cf3422688

View File

@@ -1,5 +1,5 @@
import gradio as gr import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
import os, torch import os, torch
from PIL import Image from PIL import Image
import numpy as np import numpy as np
@@ -7,11 +7,8 @@ from PIL import ImageDraw, ImageFont
import random import random
import json import json
lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors'
save_masks_dir = 'workdirs/tmp_mask'
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
save_dir = os.path.join(save_masks_dir, random_dir) save_dir = os.path.join('workdirs/tmp_mask', random_dir)
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
for i, mask in enumerate(masks): for i, mask in enumerate(masks):
save_path = os.path.join(save_dir, f'{i}.png') save_path = os.path.join(save_dir, f'{i}.png')
@@ -79,7 +76,7 @@ config = {
"default_parameters": { "default_parameters": {
"cfg_scale": 3.0, "cfg_scale": 3.0,
"embedded_guidance": 3.5, "embedded_guidance": 3.5,
"num_inference_steps": 30, "num_inference_steps": 50,
} }
}, },
}, },
@@ -109,17 +106,15 @@ def load_model(model_type, model_path):
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager() model_manager = ModelManager()
if model_type == "FLUX": if model_type == "FLUX":
model_manager.torch_dtype = torch.bfloat16 model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
file_list = [ model_manager.load_lora(
os.path.join(model_path, "text_encoder/model.safetensors"), download_customized_models(
os.path.join(model_path, "text_encoder_2"), model_id="DiffSynth-Studio/Eligen",
] origin_file_path="model_bf16.safetensors",
for file_name in os.listdir(model_path): local_dir="models/lora/entity_control",
if file_name.endswith(".safetensors"): ),
file_list.append(os.path.join(model_path, file_name)) lora_alpha=1,
model_manager.load_models(file_list) )
model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.)
else: else:
model_manager.load_model(model_path) model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
@@ -137,7 +132,7 @@ with gr.Blocks() as app:
gr.Markdown(""" gr.Markdown("""
# 实体级控制文生图模型EliGen # 实体级控制文生图模型EliGen
**UI说明** **UI说明**
1. 点击Load model读取模型然后左侧界面为文生图输入参数右侧Painter为局部控制区域绘制区域每个局部控制条件由其Local prompt和绘制的mask组成支持精准控制文生图和Inpainting两种模式 1. **点击Load model读取模型**然后左侧界面为文生图输入参数右侧Painter为局部控制区域绘制区域每个局部控制条件由其Local prompt和绘制的mask组成支持精准控制文生图和Inpainting两种模式
2. **精准控制生图模式** 输入Globalprompt激活并绘制一个或多个局部控制条件点击Generate生成图像; Global Prompt推荐包含每个Local Prompt 2. **精准控制生图模式** 输入Globalprompt激活并绘制一个或多个局部控制条件点击Generate生成图像; Global Prompt推荐包含每个Local Prompt
3. **Inpainting模式** 你可以上传图像或者将上一步生成的图像设置为Inpaint Input Image采用类似的方式输入局部控制条件进行局部重绘 3. **Inpainting模式** 你可以上传图像或者将上一步生成的图像设置为Inpaint Input Image采用类似的方式输入局部控制条件进行局部重绘
4. 尽情创造 4. 尽情创造
@@ -145,7 +140,7 @@ with gr.Blocks() as app:
gr.Markdown(""" gr.Markdown("""
# Entity-Level Controlled Text-to-Image Model: EliGen # Entity-Level Controlled Text-to-Image Model: EliGen
**UI Instructions** **UI Instructions**
1. Click "Load model" to load the model. The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. 1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes.
2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts.
3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing.
4. Enjoy! 4. Enjoy!
@@ -241,7 +236,6 @@ with gr.Blocks() as app:
triggers=run_button.click triggers=run_button.click
) )
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
print("save to:", random_mask_dir.value)
_, pipe = load_model(model_type, model_path) _, pipe = load_model(model_type, model_path)
input_params = { input_params = {
"prompt": prompt, "prompt": prompt,
@@ -255,7 +249,8 @@ with gr.Blocks() as app:
if isinstance(pipe, FluxImagePipeline): if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance input_params["embedded_guidance"] = embedded_guidance
if input_image is not None: if input_image is not None:
input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB") input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
input_params["enable_eligen_inpaint"] = True
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
@@ -274,8 +269,8 @@ with gr.Blocks() as app:
entity_masks = None if len(masks) == 0 else masks entity_masks = None if len(masks) == 0 else masks
entity_prompts = None if len(local_prompts) == 0 else local_prompts entity_prompts = None if len(local_prompts) == 0 else local_prompts
input_params.update({ input_params.update({
"entity_prompts": entity_prompts, "eligen_entity_prompts": entity_prompts,
"entity_masks": entity_masks, "eligen_entity_masks": entity_masks,
}) })
torch.manual_seed(seed) torch.manual_seed(seed)
image = pipe(**input_params) image = pipe(**input_params)