Merge pull request #303 from modelscope/eligen

Eligen
This commit is contained in:
Zhongjie Duan
2025-01-03 10:47:26 +08:00
committed by GitHub
10 changed files with 862 additions and 109 deletions

View File

@@ -34,7 +34,11 @@ Until now, DiffSynth Studio has supported the following models:
* [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5)
## News
- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/README.md).
* Paper: Comming soon
* Github: https://github.com/modelscope/DiffSynth-Studio
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
* Training dataset: Coming soon
- **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.

315
apps/gradio/eligen_ui.py Normal file
View File

@@ -0,0 +1,315 @@
import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
import os, torch
from PIL import Image
import numpy as np
from PIL import ImageDraw, ImageFont
import random
import json
def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
save_dir = os.path.join('workdirs/tmp_mask', random_dir)
os.makedirs(save_dir, exist_ok=True)
for i, mask in enumerate(masks):
save_path = os.path.join(save_dir, f'{i}.png')
mask.save(save_path)
sample = {
"global_prompt": global_prompt,
"mask_prompts": mask_prompts,
"seed": seed,
}
with open(os.path.join(save_dir, f"prompts.json"), 'w') as f:
json.dump(sample, f, indent=4)
def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
# Create a blank image for overlays
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
colors = [
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
]
# Generate random colors for each mask
if use_random_colors:
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
# Font settings
try:
font = ImageFont.truetype("arial", font_size) # Adjust as needed
except IOError:
font = ImageFont.load_default(font_size)
# Overlay each mask onto the overlay image
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
# Convert mask to RGBA mode
mask_rgba = mask.convert('RGBA')
mask_data = mask_rgba.getdata()
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
mask_rgba.putdata(new_data)
# Draw the mask prompt text on the mask
draw = ImageDraw.Draw(mask_rgba)
mask_bbox = mask.getbbox() # Get the bounding box of the mask
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
# Alpha composite the overlay with this mask
overlay = Image.alpha_composite(overlay, mask_rgba)
# Composite the overlay onto the original image
result = Image.alpha_composite(image.convert('RGBA'), overlay)
return result
config = {
"model_config": {
"FLUX": {
"model_folder": "models/FLUX",
"pipeline_class": FluxImagePipeline,
"default_parameters": {
"cfg_scale": 3.0,
"embedded_guidance": 3.5,
"num_inference_steps": 50,
}
},
},
"max_num_painter_layers": 8,
"max_num_model_cache": 1,
}
def load_model_list(model_type):
if model_type is None:
return []
folder = config["model_config"][model_type]["model_folder"]
file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")]
if model_type in ["HunyuanDiT", "Kolors", "FLUX"]:
file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))]
file_list = sorted(file_list)
return file_list
model_dict = {}
def load_model(model_type, model_path):
global model_dict
model_key = f"{model_type}:{model_path}"
if model_key in model_dict:
return model_dict[model_key]
model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path)
model_manager = ModelManager()
if model_type == "FLUX":
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control",
),
lora_alpha=1,
)
else:
model_manager.load_model(model_path)
pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager)
while len(model_dict) + 1 > config["max_num_model_cache"]:
key = next(iter(model_dict.keys()))
model_manager_to_release, _ = model_dict[key]
model_manager_to_release.to("cpu")
del model_dict[key]
torch.cuda.empty_cache()
model_dict[model_key] = model_manager, pipe
return model_manager, pipe
with gr.Blocks() as app:
gr.Markdown("""
# 实体级控制文生图模型EliGen
**UI说明**
1. **点击Load model读取模型**然后左侧界面为文生图输入参数右侧Painter为局部控制区域绘制区域每个局部控制条件由其Local prompt和绘制的mask组成支持精准控制文生图和Inpainting两种模式。
2. **精准控制生图模式:** 输入Globalprompt激活并绘制一个或多个局部控制条件点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。
3. **Inpainting模式** 你可以上传图像或者将上一步生成的图像设置为Inpaint Input Image采用类似的方式输入局部控制条件进行局部重绘。
4. 尽情创造!
""")
gr.Markdown("""
# Entity-Level Controlled Text-to-Image Model: EliGen
**UI Instructions**
1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes.
2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts.
3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing.
4. Enjoy!
""")
with gr.Row():
random_mask_dir = gr.State('')
with gr.Column(scale=382, min_width=100):
model_type = gr.State('FLUX')
model_path = gr.State('FLUX.1-dev')
with gr.Accordion(label="Model"):
load_model_button = gr.Button(value="Load model")
with gr.Accordion(label="Global prompt"):
prompt = gr.Textbox(label="Prompt", lines=3)
negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", lines=1)
cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale")
with gr.Accordion(label="Inference Options"):
num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps")
height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
return_with_mask = gr.Checkbox(value=True, interactive=True, label="show result with mask painting")
with gr.Column():
use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed")
seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False)
with gr.Accordion(label="Inpaint Input Image (Testing)"):
input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight")
with gr.Column():
reset_input_button = gr.Button(value="Reset Inpaint Input")
send_input_to_painter = gr.Button(value="Set as painter's background")
@gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
def reset_input_image(input_image):
return None
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask],
outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, load_model_button, random_mask_dir],
triggers=load_model_button.click
)
def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask):
load_model(model_type, model_path)
cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale)
embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance)
num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps)
height = config["model_config"][model_type]["default_parameters"].get("height", height)
width = config["model_config"][model_type]["default_parameters"].get("width", width)
return_with_mask = config["model_config"][model_type]["default_parameters"].get("return_with_mask", return_with_mask)
return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, gr.update(value="Loaded FLUX"), gr.State(f'{random.randint(0, 1000000):08d}')
with gr.Column(scale=618, min_width=100):
with gr.Accordion(label="Painter"):
enable_local_prompt_list = []
local_prompt_list = []
mask_scale_list = []
canvas_list = []
for painter_layer_id in range(config["max_num_painter_layers"]):
with gr.Tab(label=f"Layer {painter_layer_id}"):
enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}")
local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}")
canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA",
brush=gr.Brush(default_size=30, default_color="#000000", colors=["#000000"]),
label="Painter", key=f"canvas_{painter_layer_id}")
@gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden")
def resize_canvas(height, width, canvas):
h, w = canvas["background"].shape[:2]
if h != height or width != w:
return np.ones((height, width, 3), dtype=np.uint8) * 255
else:
return canvas
enable_local_prompt_list.append(enable_local_prompt)
local_prompt_list.append(local_prompt)
mask_scale_list.append(mask_scale)
canvas_list.append(canvas)
with gr.Accordion(label="Results"):
run_button = gr.Button(value="Generate", variant="primary")
output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
with gr.Row():
with gr.Column():
output_to_painter_button = gr.Button(value="Set as painter's background")
with gr.Column():
output_to_input_button = gr.Button(value="Set as input image")
real_output = gr.State(None)
mask_out = gr.State(None)
@gr.on(
inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list,
outputs=[output_image, real_output, mask_out],
triggers=run_button.click
)
def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
_, pipe = load_model(model_type, model_path)
input_params = {
"prompt": prompt,
"negative_prompt": negative_prompt,
"cfg_scale": cfg_scale,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
"progress_bar_cmd": progress.tqdm,
}
if isinstance(pipe, FluxImagePipeline):
input_params["embedded_guidance"] = embedded_guidance
if input_image is not None:
input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
input_params["enable_eligen_inpaint"] = True
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = (
args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]],
args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]],
args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]],
args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]]
)
local_prompts, masks, mask_scales = [], [], []
for enable_local_prompt, local_prompt, mask_scale, canvas in zip(
enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list
):
if enable_local_prompt:
local_prompts.append(local_prompt)
masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
mask_scales.append(mask_scale)
entity_masks = None if len(masks) == 0 else masks
entity_prompts = None if len(local_prompts) == 0 else local_prompts
input_params.update({
"eligen_entity_prompts": entity_prompts,
"eligen_entity_masks": entity_masks,
})
torch.manual_seed(seed)
image = pipe(**input_params)
# visualize masks
masks = [mask.resize(image.size) for mask in masks]
image_with_mask = visualize_masks(image, masks, local_prompts)
# save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir.value)
real_output = gr.State(image)
mask_out = gr.State(image_with_mask)
if return_with_mask:
return image_with_mask, real_output, mask_out
return image, real_output, mask_out
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
def send_input_to_painter_background(input_image, *canvas_list):
if input_image is None:
return tuple(canvas_list)
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = input_image.resize((w, h))
return tuple(canvas_list)
@gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
def send_output_to_painter_background(real_output, *canvas_list):
if real_output is None:
return tuple(canvas_list)
for canvas in canvas_list:
h, w = canvas["background"].shape[:2]
canvas["background"] = real_output.value.resize((w, h))
return tuple(canvas_list)
@gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
def show_output(return_with_mask, real_output, mask_out):
if return_with_mask:
return mask_out.value
else:
return real_output.value
@gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
def send_output_to_pipe_input(real_output):
return real_output.value
app.launch()

View File

@@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module):
n_axes = ids.shape[-1]
emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3)
return emb.unsqueeze(1)
class FluxJointAttention(torch.nn.Module):
@@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states_a.shape[0]
# Part A
@@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module):
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:]
@@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module):
else:
hidden_states_b = self.b_to_out(hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxJointTransformerBlock(torch.nn.Module):
@@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module):
)
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
# Attention
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list)
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
# Part A
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
@@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module):
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
return hidden_states_a, hidden_states_b
class FluxSingleAttention(torch.nn.Module):
@@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
return hidden_states
class AdaLayerNormSingle(torch.nn.Module):
@@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module):
shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa
class FluxSingleTransformerBlock(torch.nn.Module):
@@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module):
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None):
def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
batch_size = hidden_states.shape[0]
qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
@@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
q, k = self.apply_rope(q, k, image_rotary_emb)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
hidden_states = hidden_states.to(q.dtype)
if ipadapter_kwargs_list is not None:
@@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module):
return hidden_states
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list)
attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list)
mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh")
hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2)
hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a)
hidden_states_a = residual + hidden_states_a
return hidden_states_a, hidden_states_b
class AdaLayerNormContinuous(torch.nn.Module):
@@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module):
def unpatchify(self, hidden_states, height, width):
hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
return hidden_states
def prepare_image_ids(self, latents):
batch_size, _, height, width = latents.shape
@@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module):
latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype)
return latent_image_ids
def tiled_forward(
self,
@@ -338,11 +338,75 @@ class FluxDiT(torch.nn.Module):
return hidden_states
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
N = len(entity_masks)
batch_size = entity_masks[0].shape[0]
total_seq_len = N * prompt_seq_len + image_seq_len
patched_masks = [self.patchify(entity_masks[i]) for i in range(N)]
attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device)
image_start = N * prompt_seq_len
image_end = N * prompt_seq_len + image_seq_len
# prompt-image mask
for i in range(N):
prompt_start = i * prompt_seq_len
prompt_end = (i + 1) * prompt_seq_len
image_mask = torch.sum(patched_masks[i], dim=-1) > 0
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1)
# prompt update with image
attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask
# image update with prompt
attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2)
# prompt-prompt mask
for i in range(N):
for j in range(N):
if i != j:
prompt_start_i = i * prompt_seq_len
prompt_end_i = (i + 1) * prompt_seq_len
prompt_start_j = j * prompt_seq_len
prompt_end_j = (j + 1) * prompt_seq_len
attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False
attention_mask = attention_mask.float()
attention_mask[attention_mask == 0] = float('-inf')
attention_mask[attention_mask == 1] = 0
return attention_mask
def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids):
repeat_dim = hidden_states.shape[1]
max_masks = 0
attention_mask = None
prompt_embs = [prompt_emb]
if entity_masks is not None:
# entity_masks
batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1]
entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1)
entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)]
# global mask
global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype)
entity_masks = entity_masks + [global_mask] # append global to last
# attention mask
attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1])
attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype)
attention_mask = attention_mask.unsqueeze(1)
# embds: n_masks * b * seq * d
local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)]
prompt_embs = local_embs + prompt_embs # append global to last
prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs]
prompt_emb = torch.cat(prompt_embs, dim=1)
# positional embedding
text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
return prompt_emb, image_rotary_emb, attention_mask
def forward(
self,
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64,
tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None,
use_gradient_checkpointing=False,
**kwargs
):
@@ -353,46 +417,51 @@ class FluxDiT(torch.nn.Module):
tile_size=tile_size, tile_stride=tile_stride,
**kwargs
)
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.final_norm_out(hidden_states, conditioning)
@@ -400,7 +469,7 @@ class FluxDiT(torch.nn.Module):
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
def quantize(self):
def cast_to(weight, dtype=None, device=None, copy=False):
@@ -440,16 +509,16 @@ class FluxDiT(torch.nn.Module):
class Linear(torch.nn.Linear):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self,input,**kwargs):
weight,bias= cast_bias_weight(self,input)
return torch.nn.functional.linear(input,weight,bias)
class RMSNorm(torch.nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self,hidden_states,**kwargs):
weight= cast_weight(self.module,hidden_states)
input_dtype = hidden_states.dtype
@@ -457,7 +526,7 @@ class FluxDiT(torch.nn.Module):
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
hidden_states = hidden_states.to(input_dtype) * weight
return hidden_states
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
@@ -483,7 +552,6 @@ class FluxDiT(torch.nn.Module):
@staticmethod
def state_dict_converter():
return FluxDiTStateDictConverter()
class FluxDiTStateDictConverter:
@@ -587,7 +655,7 @@ class FluxDiTStateDictConverter:
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
return state_dict_
def from_civitai(self, state_dict):
rename_dict = {
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",

View File

@@ -366,17 +366,21 @@ class ModelManager:
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
if isinstance(file_path, list):
for file_path_ in file_path:
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
else:
print(f"Loading LoRA models from file: {file_path}")
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
for lora in get_lora_loaders():
match_results = lora.match(model, state_dict)
if match_results is not None:
print(f" Adding LoRA to {model_name} ({model_path}).")
lora_prefix, model_resource = match_results
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
break
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):

View File

@@ -10,6 +10,7 @@ import numpy as np
from PIL import Image
from ..models.tiler import FastTileWorker
from transformers import SiglipVisionModel
from copy import deepcopy
class FluxImagePipeline(BasePipeline):
@@ -59,6 +60,7 @@ class FluxImagePipeline(BasePipeline):
self.ipadapter = model_manager.fetch_model("flux_ipadapter")
self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model")
@staticmethod
def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None):
pipe = FluxImagePipeline(
@@ -133,37 +135,156 @@ class FluxImagePipeline(BasePipeline):
# store it
controlnet_frames.append(image)
return controlnet_frames
def prepare_ipadapter_inputs(self, images, height=384, width=384):
images = [image.convert("RGB").resize((width, height), resample=3) for image in images]
images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images]
return torch.cat(images, dim=0)
def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.):
# inpaint noise
inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id]
# merge noise
weight = torch.ones_like(inpaint_noise)
inpaint_noise[fg_mask] = pred_noise[fg_mask]
inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight
weight[bg_mask] += background_weight
inpaint_noise /= weight
return inpaint_noise
def preprocess_masks(self, masks, height, width, dim):
out_masks = []
for mask in masks:
mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0
mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype)
out_masks.append(mask)
return out_masks
def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False):
fg_mask, bg_mask = None, None
if enable_eligen_inpaint:
masks_ = deepcopy(entity_masks)
fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_])
fg_masks = (fg_masks > 0).float()
fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0
bg_mask = ~fg_mask
entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1)
entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w
entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0)
return entity_prompts, entity_masks, fg_mask, bg_mask
def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride):
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
input_latents = None
return latents, input_latents
def prepare_ipadapter(self, ipadapter_images, ipadapter_scale):
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega
def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative):
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {}
return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs
def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale):
if eligen_entity_masks is not None:
entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint)
if enable_eligen_on_negative and cfg_scale != 1.0:
entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, eligen_entity_masks.shape[1], 1, 1)
entity_masks_nega = eligen_entity_masks
else:
entity_prompt_emb_nega, entity_masks_nega = None, None
else:
entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None
fg_mask, bg_mask = None, None
eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi}
eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega}
return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask
def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale):
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals
@torch.no_grad()
def __call__(
self,
# Prompt
prompt,
local_prompts=None,
masks=None,
mask_scales=None,
negative_prompt="",
cfg_scale=1.0,
embedded_guidance=3.5,
t5_sequence_length=512,
# Image
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
denoising_strength=1.0,
height=1024,
width=1024,
seed=None,
# Steps
num_inference_steps=30,
t5_sequence_length=512,
# local prompts
local_prompts=(),
masks=(),
mask_scales=(),
# ControlNet
controlnet_image=None,
controlnet_inpaint_mask=None,
enable_controlnet_on_negative=False,
# IP-Adapter
ipadapter_images=None,
ipadapter_scale=1.0,
# EliGen
eligen_entity_prompts=None,
eligen_entity_masks=None,
enable_eligen_on_negative=False,
enable_eligen_inpaint=False,
# Tile
tiled=False,
tile_size=128,
tile_stride=64,
seed=None,
# Progress bar
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
@@ -176,72 +297,50 @@ class FluxImagePipeline(BasePipeline):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
self.load_models_to_device(['vae_encoder'])
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.encode_image(image, **tiler_kwargs)
noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride)
# Extend prompt
self.load_models_to_device(['text_encoder_1', 'text_encoder_2'])
prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales)
# Encode prompts
prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length)
prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts]
# Prompt
prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale)
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# IP-Adapter
if ipadapter_images is not None:
self.load_models_to_device(['ipadapter_image_encoder'])
ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images)
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output
self.load_models_to_device(['ipadapter'])
ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}
ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))}
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}}
# Entity control
eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale)
# Prepare ControlNets
if controlnet_image is not None:
self.load_models_to_device(['vae_encoder'])
controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)}
if len(masks) > 0 and controlnet_inpaint_mask is not None:
print("The controlnet_inpaint_mask will be overridden by masks.")
local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks]
else:
local_controlnet_kwargs = None
else:
controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks)
# IP-Adapter
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale)
# ControlNets
controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative)
# Denoise
self.load_models_to_device(['dit', 'controlnet'])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(self.device)
# Classifier-free guidance
# Positive side
inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi,
**prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi,
)
noise_pred_posi = self.control_noise_via_local_prompts(
prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback,
special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs
special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs
)
# Inpaint
if enable_eligen_inpaint:
noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id)
# Classifier-free guidance
if cfg_scale != 1.0:
negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {}
# Negative side
noise_pred_nega = lets_dance_flux(
dit=self.dit, controlnet=self.controlnet,
hidden_states=latents, timestep=timestep,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega,
**prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
@@ -256,7 +355,7 @@ class FluxImagePipeline(BasePipeline):
# Decode image
self.load_models_to_device(['vae_decoder'])
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.decode_image(latents, **tiler_kwargs)
# Offload all models
self.load_models_to_device([])
@@ -278,6 +377,8 @@ def lets_dance_flux(
tiled=False,
tile_size=128,
tile_stride=64,
entity_prompt_emb=None,
entity_masks=None,
ipadapter_kwargs_list={},
**kwargs
):
@@ -333,13 +434,18 @@ def lets_dance_flux(
if dit.guidance_embedder is not None:
guidance = guidance * 1000
conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
height, width = hidden_states.shape[-2:]
hidden_states = dit.patchify(hidden_states)
hidden_states = dit.x_embedder(hidden_states)
if entity_prompt_emb is not None and entity_masks is not None:
prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
else:
prompt_emb = dit.context_embedder(prompt_emb)
image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
attention_mask = None
# Joint Blocks
for block_id, block in enumerate(dit.blocks):
hidden_states, prompt_emb = block(
@@ -347,7 +453,9 @@ def lets_dance_flux(
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None))
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states = hidden_states + controlnet_res_stack[block_id]
@@ -361,8 +469,9 @@ def lets_dance_flux(
prompt_emb,
conditioning,
image_rotary_emb,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(
block_id + num_joint_blocks, None))
attention_mask,
ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None)
)
# ControlNet
if controlnet is not None and controlnet_frames is not None:
hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id]

View File

@@ -0,0 +1,60 @@
# EliGen: Entity-Level Controlled Image Generation
## Introduction
We propose EliGen, a novel approach that leverages fine-grained entity-level information to enable precise and controllable text-to-image generation. EliGen excels in tasks such as entity-level controlled image generation and image inpainting, while its applicability is not limited to these areas. Additionally, it can be seamlessly integrated with existing community models, such as the IP-Adpater.
* Paper: Comming soon
* Github: https://github.com/modelscope/DiffSynth-Studio
* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)
* Training dataset: Coming soon
## Methodology
![regional-attention](https://github.com/user-attachments/assets/9a147201-15ab-421f-a6c5-701075754478)
We introduce a regional attention mechanism within the DiT framework to effectively process the conditions of each entity. This mechanism enables the local prompt associated with each entity to semantically influence specific regions through regional attention. To further enhance the layout control capabilities of EliGen, we meticulously curate an entity-annotated dataset and fine-tune the model using the LoRA framework.
1. **Regional Attention**: Regional attention is shown in above figure, which can be easily applied to other text-to-image models. Its core principle involves transforming the positional information of each entity into an attention mask, ensuring that the mechanism only affects the designated regions.
2. **Dataset with Entity Annotation**: To curate a dedicated entity control dataset, we start by randomly selecting captions from DiffusionDB and generating the corresponding source image using Flux. Next, we employ Qwen2-VL 72B, recognized for its advanced grounding capabilities among MLLMs, to randomly identify entities within the image. These entities are annotated with local prompts and bounding boxes for precise localization, forming the foundation of our dataset for further training.
3. **Training**: We apply LoRA and deepspeed to finetune regional attention with curated dataset, enabling our EliGen performing effective entity-level control.
## Usage
1. **Entity-Level Controlled Image Generation**
See [./entity_control.py](./entity_control.py) for usage.
2. **Image Inpainting**
To apply EliGen to image inpainting task, we propose a inpainting fusion pipeline to preserve the non-painting areas while enabling precise, entity-level modifications over inpaining regions.
See [./entity_inpaint.py](./entity_inpaint.py) for usage.
3. **Styled Entity Control**
EliGen can be seamlessly integrated with existing community models. We have provided an example of how to integrate it with the IP-Adpater. See [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for usage.
4. **Play with EliGen using UI**
Download the checkpoint of EliGen from [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) to `models/lora/entity_control` and run the following command to try interactive UI:
```bash
python apps/gradio/entity_level_control.py
```
## Examples
### Entity-Level Controlled Image Generation
1. The effect of generating images with continuously changing entity positions.
https://github.com/user-attachments/assets/4fc76df1-b26a-46e8-a950-865cdf02a38d
2. The image generation effect of complex Entity combinations, demonstrating the strong generalization of EliGen.
|![image_1_base](https://github.com/user-attachments/assets/b8564b28-19b5-424f-bf3c-6476f2923ff9)|![image_1_base](https://github.com/user-attachments/assets/20793715-42d3-46f7-8d62-0cb4cacef38d)|
|-|-|
|![image_1_base](https://github.com/user-attachments/assets/70ef12fe-d300-4b52-9d11-eabc9b5464a8)|![image_1_enhance](https://github.com/user-attachments/assets/7645ce0e-4aa7-4b1e-b7a7-bccfd9796461)|
|![image_2_base](https://github.com/user-attachments/assets/2f1e44e1-8f1f-4c6e-ab7a-1b6861a33a69)|![image_2_enhance](https://github.com/user-attachments/assets/faf78498-57ba-41bd-b516-570c86984515)|
|![image_3_base](https://github.com/user-attachments/assets/206d1cef-2e96-4469-aed5-cdeb06ab9e99)|![image_3_enhance](https://github.com/user-attachments/assets/75d784d6-d5a1-474f-a5d5-ef8074135f35)|
### Image Inpainting
|Inpainting Input|Inpainting Output|
|-|-|
|![image_1_base](https://github.com/user-attachments/assets/5f74c710-bf30-4db1-ae40-a1e1995ccef6)|![image_1_enhance](https://github.com/user-attachments/assets/1cd71177-e956-46d3-86ce-06f774c96efd)|
|![image_2_base](https://github.com/user-attachments/assets/5ef499f3-3d8a-49cc-8ceb-86af7f5cb9f8)|![image_2_enhance](https://github.com/user-attachments/assets/fb967035-7b28-466c-a753-c00135559121)|
### Styled Entity Control
|Style Reference|Entity Control Variance 1|Entity Control Variance 2|Entity Control Variance 3|
|-|-|-|-|
|![image_1_base](https://github.com/user-attachments/assets/5e2dd3ab-37d3-4f58-8e02-ee2f9b238604)|![image_1_enhance](https://github.com/user-attachments/assets/0f6711a2-572a-41b3-938a-95deff6d732d)|![image_1_enhance](https://github.com/user-attachments/assets/ce2e66e5-1fdf-44e8-bca7-555d805a50b1)|![image_1_enhance](https://github.com/user-attachments/assets/ad2da233-2f7c-4065-ab57-b2d84dc2c0e2)|
|![image_2_base](https://github.com/user-attachments/assets/77cf7ceb-48e3-442d-8ffc-5fa4a10fe81a)|![image_2_enhance](https://github.com/user-attachments/assets/59a4f3c2-e59d-40c7-886c-0768f14fcc89)|![image_2_enhance](https://github.com/user-attachments/assets/a9187fb0-489a-49c9-a52f-56b1bd96faf7)|![image_2_enhance](https://github.com/user-attachments/assets/a62caee4-3863-4b56-96ff-e0785c6d93bb)|

View File

@@ -0,0 +1,43 @@
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks
from PIL import Image
import torch
# download and load model
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager)
# download and load mask images
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/mask*")
masks = [Image.open(f"./data/examples/eligen/mask{i}.png") for i in range(1, 8)]
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw"
# generate image
image = pipe(
prompt=global_prompt,
cfg_scale=3.0,
negative_prompt=negative_prompt,
num_inference_steps=50,
embedded_guidance=3.5,
seed=4,
height=1024,
width=1024,
eligen_entity_prompts=entity_prompts,
eligen_entity_masks=masks,
enable_eligen_on_negative=False,
)
image.save(f"entity_control.png")
visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png")

View File

@@ -0,0 +1,46 @@
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks
from PIL import Image
import torch
# download and load model
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager)
# download and load mask images
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/ipadapter*")
masks = [Image.open(f"./data/examples/eligen/ipadapter_mask_{i}.png") for i in range(1, 4)]
entity_prompts = ['A girl', 'hat', 'sunset']
global_prompt = "A girl wearing a hat, looking at the sunset"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw"
reference_img = Image.open("./data/examples/eligen/ipadapter_image.png")
# generate image
image = pipe(
prompt=global_prompt,
cfg_scale=3.0,
negative_prompt=negative_prompt,
num_inference_steps=50,
embedded_guidance=3.5,
seed=4,
height=1024,
width=1024,
eligen_entity_prompts=entity_prompts,
eligen_entity_masks=masks,
enable_eligen_on_negative=False,
ipadapter_images=[reference_img],
ipadapter_scale=0.7
)
image.save(f"styled_entity_control.png")
visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png")

View File

@@ -0,0 +1,45 @@
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from examples.EntityControl.utils import visualize_masks
from PIL import Image
import torch
# download and load model
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
download_customized_models(
model_id="DiffSynth-Studio/Eligen",
origin_file_path="model_bf16.safetensors",
local_dir="models/lora/entity_control"
),
lora_alpha=1
)
pipe = FluxImagePipeline.from_model_manager(model_manager)
# download and load mask images
dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/inpaint*")
masks = [Image.open(f"./data/examples/eligen/inpaint_mask_{i}.png") for i in range(1, 3)]
input_image = Image.open("./data/examples/eligen/inpaint_image.jpg")
entity_prompts = ["A person wear red shirt", "Airplane"]
global_prompt = "A person walking on the path in front of a house; An airplane in the sky"
negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur"
# generate image
image = pipe(
prompt=global_prompt,
input_image=input_image,
cfg_scale=3.0,
negative_prompt=negative_prompt,
num_inference_steps=50,
embedded_guidance=3.5,
seed=0,
height=1024,
width=1024,
eligen_entity_prompts=entity_prompts,
eligen_entity_masks=masks,
enable_eligen_on_negative=False,
enable_eligen_inpaint=True,
)
image.save(f"entity_inpaint.png")
visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png")

View File

@@ -0,0 +1,59 @@
from PIL import Image, ImageDraw, ImageFont
import random
def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False):
# Create a blank image for overlays
overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
colors = [
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
(165, 238, 173, 80),
(76, 102, 221, 80),
(221, 160, 77, 80),
(204, 93, 71, 80),
(145, 187, 149, 80),
(134, 141, 172, 80),
(157, 137, 109, 80),
(153, 104, 95, 80),
]
# Generate random colors for each mask
if use_random_colors:
colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
# Font settings
try:
font = ImageFont.truetype("arial", font_size) # Adjust as needed
except IOError:
font = ImageFont.load_default(font_size)
# Overlay each mask onto the overlay image
for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
# Convert mask to RGBA mode
mask_rgba = mask.convert('RGBA')
mask_data = mask_rgba.getdata()
new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
mask_rgba.putdata(new_data)
# Draw the mask prompt text on the mask
draw = ImageDraw.Draw(mask_rgba)
mask_bbox = mask.getbbox() # Get the bounding box of the mask
text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
# Alpha composite the overlay with this mask
overlay = Image.alpha_composite(overlay, mask_rgba)
# Composite the overlay onto the original image
result = Image.alpha_composite(image.convert('RGBA'), overlay)
# Save or display the resulting image
result.save(output_path)
return result