mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
391 lines
20 KiB
Python
391 lines
20 KiB
Python
import os
|
|
import torch
|
|
import numpy as np
|
|
from PIL import Image, ImageDraw, ImageFont
|
|
import random
|
|
import json
|
|
import gradio as gr
|
|
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
|
|
from modelscope import dataset_snapshot_download
|
|
|
|
|
|
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*")
|
|
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
|
|
with open(example_json, 'r') as f:
|
|
examples = json.load(f)['examples']
|
|
|
|
for idx in range(len(examples)):
|
|
example_id = examples[idx]['example_id']
|
|
entity_prompts = examples[idx]['local_prompt_list']
|
|
examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
|
|
|
|
def create_canvas_data(background, masks):
|
|
if background.shape[-1] == 3:
|
|
background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
|
|
layers = []
|
|
for mask in masks:
|
|
if mask is not None:
|
|
mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
|
|
layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
|
|
layer[..., -1] = mask_single_channel
|
|
layers.append(layer)
|
|
else:
|
|
layers.append(np.zeros_like(background))
|
|
|
|
composite = background.copy()
|
|
for layer in layers:
|
|
if layer.size > 0:
|
|
composite = np.where(layer[..., -1:] > 0, layer, composite)
|
|
return {
|
|
"background": background,
|
|
"layers": layers,
|
|
"composite": composite,
|
|
}
|
|
|
|
def load_example(load_example_button):
|
|
example_idx = int(load_example_button.split()[-1]) - 1
|
|
example = examples[example_idx]
|
|
result = [
|
|
50,
|
|
example["global_prompt"],
|
|
example["negative_prompt"],
|
|
example["seed"],
|
|
*example["local_prompt_list"],
|
|
]
|
|
num_entities = len(example["local_prompt_list"])
|
|
result += [""] * (config["max_num_painter_layers"] - num_entities)
|
|
masks = []
|
|
for mask in example["mask_lists"]:
|
|
mask_single_channel = np.array(mask.convert("L"))
|
|
masks.append(mask_single_channel)
|
|
for _ in range(config["max_num_painter_layers"] - len(masks)):
|
|
blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
|
|
masks.append(blank_mask)
|
|
background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
|
|
canvas_data_list = []
|
|
for mask in masks:
|
|
canvas_data = create_canvas_data(background, [mask])
|
|
canvas_data_list.append(canvas_data)
|
|
result.extend(canvas_data_list)
|
|
return result
|
|
|
|
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
|
|
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
|
|
print(f'save to {save_dir}')
|
|
os.makedirs(save_dir, exist_ok=True)
|
|
for i, mask in enumerate(masks):
|
|
save_path = os.path.join(save_dir, f'{i}.png')
|
|
mask.save(save_path)
|
|
sample = {
|
|
"global_prompt": global_prompt,
|
|
"mask_prompts": mask_prompts,
|
|
"seed": seed,
|
|
}
|
|
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
|
|
json.dump(sample, f, indent=4)
|
|
|
|
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
|
|
# Create a blank image for overlays
|
|
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
|
|
colors = [
|
|
(165, 238, 173, 80),
|
|
(76, 102, 221, 80),
|
|
(221, 160, 77, 80),
|
|
(204, 93, 71, 80),
|
|
(145, 187, 149, 80),
|
|
(134, 141, 172, 80),
|
|
(157, 137, 109, 80),
|
|
(153, 104, 95, 80),
|
|
(165, 238, 173, 80),
|
|
(76, 102, 221, 80),
|
|
(221, 160, 77, 80),
|
|
(204, 93, 71, 80),
|
|
(145, 187, 149, 80),
|
|
(134, 141, 172, 80),
|
|
(157, 137, 109, 80),
|
|
(153, 104, 95, 80),
|
|
]
|
|
# Generate random colors for each mask
|
|
if use_random_colors:
|
|
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
|
|
# Font settings
|
|
try:
|
|
font = ImageFont.truetype("arial", font_size) # Adjust as needed
|
|
except IOError:
|
|
font = ImageFont.load_default(font_size)
|
|
# Overlay each mask onto the overlay image
|
|
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
|
|
if mask is None:
|
|
continue
|
|
# Convert mask to RGBA mode
|
|
mask_rgba = mask.convert('RGBA')
|
|
mask_data = mask_rgba.getdata()
|
|
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
|
|
mask_rgba.putdata(new_data)
|
|
# Draw the mask prompt text on the mask
|
|
draw = ImageDraw.Draw(mask_rgba)
|
|
mask_bbox = mask.getbbox() # Get the bounding box of the mask
|
|
if mask_bbox is None:
|
|
continue
|
|
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
|
|
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
|
|
# Alpha composite the overlay with this mask
|
|
overlay = Image.alpha_composite(overlay, mask_rgba)
|
|
# Composite the overlay onto the original image
|
|
result = Image.alpha_composite(image.convert('RGBA'), overlay)
|
|
return result
|
|
|
|
config = {
|
|
"model_config": {
|
|
"FLUX": {
|
|
"model_folder": "models/FLUX",
|
|
"pipeline_class": FluxImagePipeline,
|
|
"default_parameters": {
|
|
"cfg_scale": 3.0,
|
|
"embedded_guidance": 3.5,
|
|
"num_inference_steps": 30,
|
|
}
|
|
},
|
|
},
|
|
"max_num_painter_layers": 8,
|
|
"max_num_model_cache": 1,
|
|
}
|
|
|
|
model_dict = {}
|
|
|
|
def load_model(model_type='FLUX', model_path='FLUX.1-dev'):
|
|
global model_dict
|
|
model_key = f"{model_type}:{model_path}"
|
|
if model_key in model_dict:
|
|
return model_dict[model_key]
|
|
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
|
|
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
|
|
model_manager.load_lora(
|
|
download_customized_models(
|
|
model_id="DiffSynth-Studio/Eligen",
|
|
origin_file_path="model_bf16.safetensors",
|
|
local_dir="models/lora/entity_control",
|
|
),
|
|
lora_alpha=1,
|
|
)
|
|
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
|
|
model_dict[model_key] = model_manager, pipe
|
|
return model_manager, pipe
|
|
|
|
|
|
with gr.Blocks() as app:
|
|
gr.Markdown(
|
|
"""## EliGen: Entity-Level Controllable Text-to-Image Model
|
|
1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
|
|
2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
|
|
3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
|
|
4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
|
|
"""
|
|
)
|
|
|
|
loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
|
|
main_interface = gr.Column(visible=False)
|
|
|
|
def initialize_model():
|
|
try:
|
|
load_model()
|
|
return {
|
|
loading_status: gr.update(value="Model loaded successfully!", visible=False),
|
|
main_interface: gr.update(visible=True),
|
|
}
|
|
except Exception as e:
|
|
print(f'Failed to load model with error: {e}')
|
|
return {
|
|
loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
|
|
main_interface: gr.update(visible=True),
|
|
}
|
|
|
|
app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])
|
|
|
|
with main_interface:
|
|
with gr.Row():
|
|
local_prompt_list = []
|
|
canvas_list = []
|
|
random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
|
|
with gr.Column(scale=382, min_width=100):
|
|
model_type = gr.State('FLUX')
|
|
model_path = gr.State('FLUX.1-dev')
|
|
with gr.Accordion(label="Global prompt"):
|
|
prompt = gr.Textbox(label="Global Prompt", lines=3)
|
|
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3)
|
|
with gr.Accordion(label="Inference Options", open=True):
|
|
seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
|
|
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
|
|
cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
|
|
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
|
|
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
|
|
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
|
|
with gr.Accordion(label="Inpaint Input Image", open=False):
|
|
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
|
|
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)
|
|
|
|
with gr.Column():
|
|
reset_input_button = gr.Button(value="Reset Inpaint Input")
|
|
send_input_to_painter = gr.Button(value="Set as painter's background")
|
|
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
|
|
def reset_input_image(input_image):
|
|
return None
|
|
|
|
with gr.Column(scale=618, min_width=100):
|
|
with gr.Accordion(label="Entity Painter"):
|
|
for painter_layer_id in range(config["max_num_painter_layers"]):
|
|
with gr.Tab(label=f"Entity {painter_layer_id}"):
|
|
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
|
|
canvas = gr.ImageEditor(
|
|
canvas_size=(512, 512),
|
|
sources=None,
|
|
layers=False,
|
|
interactive=True,
|
|
image_mode="RGBA",
|
|
brush=gr.Brush(
|
|
default_size=50,
|
|
default_color="#000000",
|
|
colors=["#000000"],
|
|
),
|
|
label="Entity Mask Painter",
|
|
key=f"canvas_{painter_layer_id}",
|
|
width=width,
|
|
height=height,
|
|
)
|
|
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
|
|
def resize_canvas(height, width, canvas):
|
|
h, w = canvas["background"].shape[:2]
|
|
if h != height or width != w:
|
|
return np.ones((height, width, 3), dtype=np.uint8) * 255
|
|
else:
|
|
return canvas
|
|
local_prompt_list.append(local_prompt)
|
|
canvas_list.append(canvas)
|
|
with gr.Accordion(label="Results"):
|
|
run_button = gr.Button(value="Generate", variant="primary")
|
|
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
|
|
with gr.Row():
|
|
with gr.Column():
|
|
output_to_painter_button = gr.Button(value="Set as painter's background")
|
|
with gr.Column():
|
|
return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
|
|
output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
|
|
real_output = gr.State(None)
|
|
mask_out = gr.State(None)
|
|
|
|
@gr.on(
|
|
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
|
|
outputs=[output_image, real_output, mask_out],
|
|
triggers=run_button.click
|
|
)
|
|
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
|
|
_, pipe = load_model(model_type, model_path)
|
|
input_params = {
|
|
"prompt": prompt,
|
|
"negative_prompt": negative_prompt,
|
|
"cfg_scale": cfg_scale,
|
|
"num_inference_steps": num_inference_steps,
|
|
"height": height,
|
|
"width": width,
|
|
"progress_bar_cmd": progress.tqdm,
|
|
}
|
|
if isinstance(pipe, FluxImagePipeline):
|
|
input_params["embedded_guidance"] = embedded_guidance
|
|
if input_image is not None:
|
|
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
|
|
input_params["enable_eligen_inpaint"] = True
|
|
|
|
local_prompt_list, canvas_list = (
|
|
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
|
|
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
|
|
)
|
|
local_prompts, masks = [], []
|
|
for local_prompt, canvas in zip(local_prompt_list, canvas_list):
|
|
if isinstance(local_prompt, str) and len(local_prompt) > 0:
|
|
local_prompts.append(local_prompt)
|
|
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
|
|
entity_masks = None if len(masks) == 0 else masks
|
|
entity_prompts = None if len(local_prompts) == 0 else local_prompts
|
|
input_params.update({
|
|
"eligen_entity_prompts": entity_prompts,
|
|
"eligen_entity_masks": entity_masks,
|
|
})
|
|
torch.manual_seed(seed)
|
|
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
|
|
image = pipe(**input_params)
|
|
masks = [mask.resize(image.size) for mask in masks]
|
|
image_with_mask = visualize_masks(image, masks, local_prompts)
|
|
|
|
real_output = gr.State(image)
|
|
mask_out = gr.State(image_with_mask)
|
|
|
|
if return_with_mask:
|
|
return image_with_mask, real_output, mask_out
|
|
return image, real_output, mask_out
|
|
|
|
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
|
|
def send_input_to_painter_background(input_image, *canvas_list):
|
|
if input_image is None:
|
|
return tuple(canvas_list)
|
|
for canvas in canvas_list:
|
|
h, w = canvas["background"].shape[:2]
|
|
canvas["background"] = input_image.resize((w, h))
|
|
return tuple(canvas_list)
|
|
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
|
|
def send_output_to_painter_background(real_output, *canvas_list):
|
|
if real_output is None:
|
|
return tuple(canvas_list)
|
|
for canvas in canvas_list:
|
|
h, w = canvas["background"].shape[:2]
|
|
canvas["background"] = real_output.value.resize((w, h))
|
|
return tuple(canvas_list)
|
|
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
|
|
def show_output(return_with_mask, real_output, mask_out):
|
|
if return_with_mask:
|
|
return mask_out.value
|
|
else:
|
|
return real_output.value
|
|
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
|
|
def send_output_to_pipe_input(real_output):
|
|
return real_output.value
|
|
|
|
with gr.Column():
|
|
gr.Markdown("## Examples")
|
|
for i in range(0, len(examples), 2):
|
|
with gr.Row():
|
|
if i < len(examples):
|
|
example = examples[i]
|
|
with gr.Column():
|
|
example_image = gr.Image(
|
|
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
|
label=example["description"],
|
|
interactive=False,
|
|
width=1024,
|
|
height=512
|
|
)
|
|
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
|
load_example_button.click(
|
|
load_example,
|
|
inputs=[load_example_button],
|
|
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
|
)
|
|
|
|
if i + 1 < len(examples):
|
|
example = examples[i + 1]
|
|
with gr.Column():
|
|
example_image = gr.Image(
|
|
value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
|
|
label=example["description"],
|
|
interactive=False,
|
|
width=1024,
|
|
height=512
|
|
)
|
|
load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
|
|
load_example_button.click(
|
|
load_example,
|
|
inputs=[load_example_button],
|
|
outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list
|
|
)
|
|
app.config["show_progress"] = "hidden"
|
|
app.launch()
|