diff --git a/README.md b/README.md index 3053a92..9ed8664 100644 --- a/README.md +++ b/README.md @@ -34,7 +34,11 @@ Until now, DiffSynth Studio has supported the following models: * [Stable Diffusion](https://huggingface.co/runwayml/stable-diffusion-v1-5) ## News - +- **December 31, 2024** We propose EliGen, a novel framework for precise entity-level controlled text-to-image generation, complemented by an inpainting fusion pipeline to extend its capabilities to image inpainting tasks. EliGen seamlessly integrates with existing community models, such as IP-Adapter, enhancing its versatility. For more details, see [./examples/EntityControl](./examples/EntityControl/README.md). + * Paper: Coming soon + * Github: https://github.com/modelscope/DiffSynth-Studio + * Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) + * Training dataset: Coming soon - **December 19, 2024** We implement advanced VRAM management for HunyuanVideo, making it possible to generate videos at a resolution of 129x720x1280 using 24GB of VRAM, or at 129x512x384 resolution with just 6GB of VRAM. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details. 
diff --git a/apps/gradio/eligen_ui.py b/apps/gradio/eligen_ui.py new file mode 100644 index 0000000..4287ad0 --- /dev/null +++ b/apps/gradio/eligen_ui.py @@ -0,0 +1,315 @@ +import gradio as gr +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +import os, torch +from PIL import Image +import numpy as np +from PIL import ImageDraw, ImageFont +import random +import json + +def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): + save_dir = os.path.join('workdirs/tmp_mask', random_dir) + os.makedirs(save_dir, exist_ok=True) + for i, mask in enumerate(masks): + save_path = os.path.join(save_dir, f'{i}.png') + mask.save(save_path) + sample = { + "global_prompt": global_prompt, + "mask_prompts": mask_prompts, + "seed": seed, + } + with open(os.path.join(save_dir, f"prompts.json"), 'w') as f: + json.dump(sample, f, indent=4) + +def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = 
mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + return result + +config = { + "model_config": { + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "default_parameters": { + "cfg_scale": 3.0, + "embedded_guidance": 3.5, + "num_inference_steps": 50, + } + }, + }, + "max_num_painter_layers": 8, + "max_num_model_cache": 1, +} + + +def load_model_list(model_type): + if model_type is None: + return [] + folder = config["model_config"][model_type]["model_folder"] + file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] + if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: + file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] + file_list = sorted(file_list) + return file_list + + +model_dict = {} + +def load_model(model_type, model_path): + global model_dict + model_key = f"{model_type}:{model_path}" + if model_key in model_dict: + return model_dict[model_key] + model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) + model_manager = ModelManager() + if model_type == "FLUX": + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) + model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + 
local_dir="models/lora/entity_control", + ), + lora_alpha=1, + ) + else: + model_manager.load_model(model_path) + pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) + while len(model_dict) + 1 > config["max_num_model_cache"]: + key = next(iter(model_dict.keys())) + model_manager_to_release, _ = model_dict[key] + model_manager_to_release.to("cpu") + del model_dict[key] + torch.cuda.empty_cache() + model_dict[model_key] = model_manager, pipe + return model_manager, pipe + + +with gr.Blocks() as app: + gr.Markdown(""" + # 实体级控制文生图模型EliGen + **UI说明** + 1. **点击Load model读取模型**,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 + 2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。 + 3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。 + 4. 尽情创造! + """) + gr.Markdown(""" + # Entity-Level Controlled Text-to-Image Model: EliGen + **UI Instructions** + 1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. + 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. + 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. + 4. Enjoy! 
+ """) + with gr.Row(): + random_mask_dir = gr.State('') + with gr.Column(scale=382, min_width=100): + model_type = gr.State('FLUX') + model_path = gr.State('FLUX.1-dev') + with gr.Accordion(label="Model"): + load_model_button = gr.Button(value="Load model") + + with gr.Accordion(label="Global prompt"): + prompt = gr.Textbox(label="Prompt", lines=3) + negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", lines=1) + cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale") + embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale") + + with gr.Accordion(label="Inference Options"): + num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps") + height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height") + width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width") + return_with_mask = gr.Checkbox(value=True, interactive=True, label="show result with mask painting") + with gr.Column(): + use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed") + seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False) + with gr.Accordion(label="Inpaint Input Image (Testing)"): + input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil") + background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight") + + with gr.Column(): + reset_input_button = gr.Button(value="Reset Inpaint Input") + send_input_to_painter = gr.Button(value="Set as painter's background") + @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click) + def 
reset_input_image(input_image): + return None + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask], + outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, load_model_button, random_mask_dir], + triggers=load_model_button.click + ) + def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask): + load_model(model_type, model_path) + cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale) + embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) + num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) + height = config["model_config"][model_type]["default_parameters"].get("height", height) + width = config["model_config"][model_type]["default_parameters"].get("width", width) + return_with_mask = config["model_config"][model_type]["default_parameters"].get("return_with_mask", return_with_mask) + return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, gr.update(value="Loaded FLUX"), gr.State(f'{random.randint(0, 1000000):08d}') + + + with gr.Column(scale=618, min_width=100): + with gr.Accordion(label="Painter"): + enable_local_prompt_list = [] + local_prompt_list = [] + mask_scale_list = [] + canvas_list = [] + for painter_layer_id in range(config["max_num_painter_layers"]): + with gr.Tab(label=f"Layer {painter_layer_id}"): + enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}") + local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") + mask_scale = gr.Slider(minimum=0.0, 
maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}") + canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA", + brush=gr.Brush(default_size=30, default_color="#000000", colors=["#000000"]), + label="Painter", key=f"canvas_{painter_layer_id}") + @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden") + def resize_canvas(height, width, canvas): + h, w = canvas["background"].shape[:2] + if h != height or width != w: + return np.ones((height, width, 3), dtype=np.uint8) * 255 + else: + return canvas + + enable_local_prompt_list.append(enable_local_prompt) + local_prompt_list.append(local_prompt) + mask_scale_list.append(mask_scale) + canvas_list.append(canvas) + with gr.Accordion(label="Results"): + run_button = gr.Button(value="Generate", variant="primary") + output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil") + with gr.Row(): + with gr.Column(): + output_to_painter_button = gr.Button(value="Set as painter's background") + with gr.Column(): + output_to_input_button = gr.Button(value="Set as input image") + real_output = gr.State(None) + mask_out = gr.State(None) + + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, + outputs=[output_image, real_output, mask_out], + triggers=run_button.click + ) + def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): + _, pipe = load_model(model_type, model_path) + 
input_params = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "cfg_scale": cfg_scale, + "num_inference_steps": num_inference_steps, + "height": height, + "width": width, + "progress_bar_cmd": progress.tqdm, + } + if isinstance(pipe, FluxImagePipeline): + input_params["embedded_guidance"] = embedded_guidance + if input_image is not None: + input_params["input_image"] = input_image.resize((width, height)).convert("RGB") + input_params["enable_eligen_inpaint"] = True + + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( + args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], + args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], + args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]], + args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]] + ) + local_prompts, masks, mask_scales = [], [], [] + for enable_local_prompt, local_prompt, mask_scale, canvas in zip( + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list + ): + if enable_local_prompt: + local_prompts.append(local_prompt) + masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB")) + mask_scales.append(mask_scale) + entity_masks = None if len(masks) == 0 else masks + entity_prompts = None if len(local_prompts) == 0 else local_prompts + input_params.update({ + "eligen_entity_prompts": entity_prompts, + "eligen_entity_masks": entity_masks, + }) + torch.manual_seed(seed) + image = pipe(**input_params) + # visualize masks + masks = [mask.resize(image.size) for mask in masks] + image_with_mask = visualize_masks(image, masks, local_prompts) + # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir.value) + + real_output = gr.State(image) + mask_out = gr.State(image_with_mask) + + if return_with_mask: + return image_with_mask, real_output, mask_out + return image, real_output, mask_out + + 
@gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click) + def send_input_to_painter_background(input_image, *canvas_list): + if input_image is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = input_image.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click) + def send_output_to_painter_background(real_output, *canvas_list): + if real_output is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = real_output.value.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden") + def show_output(return_with_mask, real_output, mask_out): + if return_with_mask: + return mask_out.value + else: + return real_output.value + @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click) + def send_output_to_pipe_input(real_output): + return real_output.value + +app.launch() diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6011495..7a01478 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -40,7 +40,7 @@ class RoPEEmbedding(torch.nn.Module): n_axes = ids.shape[-1] emb = torch.cat([self.rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], dim=-3) return emb.unsqueeze(1) - + class FluxJointAttention(torch.nn.Module): @@ -70,7 +70,7 @@ class FluxJointAttention(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - def forward(self, hidden_states_a, hidden_states_b, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, 
hidden_states_a, hidden_states_b, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states_a.shape[0] # Part A @@ -91,7 +91,7 @@ class FluxJointAttention(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) hidden_states_b, hidden_states_a = hidden_states[:, :hidden_states_b.shape[1]], hidden_states[:, hidden_states_b.shape[1]:] @@ -103,7 +103,7 @@ class FluxJointAttention(torch.nn.Module): else: hidden_states_b = self.b_to_out(hidden_states_b) return hidden_states_a, hidden_states_b - + class FluxJointTransformerBlock(torch.nn.Module): @@ -129,12 +129,12 @@ class FluxJointTransformerBlock(torch.nn.Module): ) - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb) norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb) # Attention - attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, ipadapter_kwargs_list) + attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b, image_rotary_emb, attn_mask, ipadapter_kwargs_list) # Part A hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a @@ -147,7 +147,7 @@ class FluxJointTransformerBlock(torch.nn.Module): hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b) return hidden_states_a, hidden_states_b - + class 
FluxSingleAttention(torch.nn.Module): @@ -184,7 +184,7 @@ class FluxSingleAttention(torch.nn.Module): hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) return hidden_states - + class AdaLayerNormSingle(torch.nn.Module): @@ -200,7 +200,7 @@ class AdaLayerNormSingle(torch.nn.Module): shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1) x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] return x, gate_msa - + class FluxSingleTransformerBlock(torch.nn.Module): @@ -225,8 +225,8 @@ class FluxSingleTransformerBlock(torch.nn.Module): xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk) - - def process_attention(self, hidden_states, image_rotary_emb, ipadapter_kwargs_list=None): + + def process_attention(self, hidden_states, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): batch_size = hidden_states.shape[0] qkv = hidden_states.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2) @@ -235,7 +235,7 @@ class FluxSingleTransformerBlock(torch.nn.Module): q, k = self.apply_rope(q, k, image_rotary_emb) - hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v) + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) hidden_states = hidden_states.to(q.dtype) if ipadapter_kwargs_list is not None: @@ -243,21 +243,21 @@ class FluxSingleTransformerBlock(torch.nn.Module): return hidden_states - def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, ipadapter_kwargs_list=None): + def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb, attn_mask=None, ipadapter_kwargs_list=None): residual = hidden_states_a norm_hidden_states, gate = 
self.norm(hidden_states_a, emb=temb) hidden_states_a = self.to_qkv_mlp(norm_hidden_states) attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] - attn_output = self.process_attention(attn_output, image_rotary_emb, ipadapter_kwargs_list) + attn_output = self.process_attention(attn_output, image_rotary_emb, attn_mask, ipadapter_kwargs_list) mlp_hidden_states = torch.nn.functional.gelu(mlp_hidden_states, approximate="tanh") hidden_states_a = torch.cat([attn_output, mlp_hidden_states], dim=2) hidden_states_a = gate.unsqueeze(1) * self.proj_out(hidden_states_a) hidden_states_a = residual + hidden_states_a - + return hidden_states_a, hidden_states_b - + class AdaLayerNormContinuous(torch.nn.Module): @@ -300,7 +300,7 @@ class FluxDiT(torch.nn.Module): def unpatchify(self, hidden_states, height, width): hidden_states = rearrange(hidden_states, "B (H W) (C P Q) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2) return hidden_states - + def prepare_image_ids(self, latents): batch_size, _, height, width = latents.shape @@ -317,7 +317,7 @@ class FluxDiT(torch.nn.Module): latent_image_ids = latent_image_ids.to(device=latents.device, dtype=latents.dtype) return latent_image_ids - + def tiled_forward( self, @@ -338,11 +338,75 @@ class FluxDiT(torch.nn.Module): return hidden_states + def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len): + N = len(entity_masks) + batch_size = entity_masks[0].shape[0] + total_seq_len = N * prompt_seq_len + image_seq_len + patched_masks = [self.patchify(entity_masks[i]) for i in range(N)] + attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), dtype=torch.bool).to(device=entity_masks[0].device) + + image_start = N * prompt_seq_len + image_end = N * prompt_seq_len + image_seq_len + # prompt-image mask + for i in range(N): + prompt_start = i * prompt_seq_len + prompt_end = (i + 1) * prompt_seq_len + image_mask = torch.sum(patched_masks[i], dim=-1) > 0 + 
image_mask = image_mask.unsqueeze(1).repeat(1, prompt_seq_len, 1) + # prompt update with image + attention_mask[:, prompt_start:prompt_end, image_start:image_end] = image_mask + # image update with prompt + attention_mask[:, image_start:image_end, prompt_start:prompt_end] = image_mask.transpose(1, 2) + # prompt-prompt mask + for i in range(N): + for j in range(N): + if i != j: + prompt_start_i = i * prompt_seq_len + prompt_end_i = (i + 1) * prompt_seq_len + prompt_start_j = j * prompt_seq_len + prompt_end_j = (j + 1) * prompt_seq_len + attention_mask[:, prompt_start_i:prompt_end_i, prompt_start_j:prompt_end_j] = False + + attention_mask = attention_mask.float() + attention_mask[attention_mask == 0] = float('-inf') + attention_mask[attention_mask == 1] = 0 + return attention_mask + + + def process_entity_masks(self, hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids): + repeat_dim = hidden_states.shape[1] + max_masks = 0 + attention_mask = None + prompt_embs = [prompt_emb] + if entity_masks is not None: + # entity_masks + batch_size, max_masks = entity_masks.shape[0], entity_masks.shape[1] + entity_masks = entity_masks.repeat(1, 1, repeat_dim, 1, 1) + entity_masks = [entity_masks[:, i, None].squeeze(1) for i in range(max_masks)] + # global mask + global_mask = torch.ones_like(entity_masks[0]).to(device=hidden_states.device, dtype=hidden_states.dtype) + entity_masks = entity_masks + [global_mask] # append global to last + # attention mask + attention_mask = self.construct_mask(entity_masks, prompt_emb.shape[1], hidden_states.shape[1]) + attention_mask = attention_mask.to(device=hidden_states.device, dtype=hidden_states.dtype) + attention_mask = attention_mask.unsqueeze(1) + # embds: n_masks * b * seq * d + local_embs = [entity_prompt_emb[:, i, None].squeeze(1) for i in range(max_masks)] + prompt_embs = local_embs + prompt_embs # append global to last + prompt_embs = [self.context_embedder(prompt_emb) for prompt_emb in prompt_embs] + 
prompt_emb = torch.cat(prompt_embs, dim=1) + + # positional embedding + text_ids = torch.cat([text_ids] * (max_masks + 1), dim=1) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + return prompt_emb, image_rotary_emb, attention_mask + + def forward( self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, - tiled=False, tile_size=128, tile_stride=64, + tiled=False, tile_size=128, tile_stride=64, entity_prompt_emb=None, entity_masks=None, use_gradient_checkpointing=False, **kwargs ): @@ -353,46 +417,51 @@ class FluxDiT(torch.nn.Module): tile_size=tile_size, tile_stride=tile_stride, **kwargs ) - + if image_ids is None: image_ids = self.prepare_image_ids(hidden_states) - + conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb) if self.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = self.context_embedder(prompt_emb) - image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) height, width = hidden_states.shape[-2:] hidden_states = self.patchify(hidden_states) hidden_states = self.x_embedder(hidden_states) - + + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + else: + prompt_emb = self.context_embedder(prompt_emb) + image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - + for block in self.blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, 
image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) for block in self.single_blocks: if self.training and use_gradient_checkpointing: hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint( create_custom_forward(block), - hidden_states, prompt_emb, conditioning, image_rotary_emb, + hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask, use_reentrant=False, ) else: - hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) + hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask) hidden_states = hidden_states[:, prompt_emb.shape[1]:] hidden_states = self.final_norm_out(hidden_states, conditioning) @@ -400,7 +469,7 @@ class FluxDiT(torch.nn.Module): hidden_states = self.unpatchify(hidden_states, height, width) return hidden_states - + def quantize(self): def cast_to(weight, dtype=None, device=None, copy=False): @@ -440,16 +509,16 @@ class FluxDiT(torch.nn.Module): class Linear(torch.nn.Linear): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - + def forward(self,input,**kwargs): weight,bias= cast_bias_weight(self,input) return torch.nn.functional.linear(input,weight,bias) - + class RMSNorm(torch.nn.Module): def __init__(self, module): super().__init__() self.module = module - + def forward(self,hidden_states,**kwargs): weight= cast_weight(self.module,hidden_states) input_dtype = hidden_states.dtype @@ -457,7 +526,7 @@ class FluxDiT(torch.nn.Module): hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) hidden_states = hidden_states.to(input_dtype) * weight return hidden_states - 
+ def replace_layer(model): for name, module in model.named_children(): if isinstance(module, torch.nn.Linear): @@ -483,7 +552,6 @@ class FluxDiT(torch.nn.Module): @staticmethod def state_dict_converter(): return FluxDiTStateDictConverter() - class FluxDiTStateDictConverter: @@ -587,7 +655,7 @@ class FluxDiTStateDictConverter: state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) return state_dict_ - + def from_civitai(self, state_dict): rename_dict = { "time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias", diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py index dcee6d3..96ee86a 100644 --- a/diffsynth/models/model_manager.py +++ b/diffsynth/models/model_manager.py @@ -366,17 +366,21 @@ class ModelManager: def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0): - print(f"Loading LoRA models from file: {file_path}") - if len(state_dict) == 0: - state_dict = load_state_dict(file_path) - for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): - for lora in get_lora_loaders(): - match_results = lora.match(model, state_dict) - if match_results is not None: - print(f" Adding LoRA to {model_name} ({model_path}).") - lora_prefix, model_resource = match_results - lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) - break + if isinstance(file_path, list): + for file_path_ in file_path: + self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha) + else: + print(f"Loading LoRA models from file: {file_path}") + if len(state_dict) == 0: + state_dict = load_state_dict(file_path) + for model_name, model, model_path in zip(self.model_name, self.model, self.model_path): + for lora in get_lora_loaders(): + match_results = lora.match(model, state_dict) + if match_results is not None: + print(f" Adding LoRA to {model_name} ({model_path}).") + 
lora_prefix, model_resource = match_results + lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource) + break def load_model(self, file_path, model_names=None, device=None, torch_dtype=None): diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 42d142c..b6fac68 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -10,6 +10,7 @@ import numpy as np from PIL import Image from ..models.tiler import FastTileWorker from transformers import SiglipVisionModel +from copy import deepcopy class FluxImagePipeline(BasePipeline): @@ -59,6 +60,7 @@ class FluxImagePipeline(BasePipeline): self.ipadapter = model_manager.fetch_model("flux_ipadapter") self.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") + @staticmethod def from_model_manager(model_manager: ModelManager, controlnet_config_units: List[ControlNetConfigUnit]=[], prompt_refiner_classes=[], prompt_extender_classes=[], device=None): pipe = FluxImagePipeline( @@ -133,37 +135,156 @@ class FluxImagePipeline(BasePipeline): # store it controlnet_frames.append(image) return controlnet_frames - + + def prepare_ipadapter_inputs(self, images, height=384, width=384): images = [image.convert("RGB").resize((width, height), resample=3) for image in images] images = [self.preprocess_image(image).to(device=self.device, dtype=self.torch_dtype) for image in images] return torch.cat(images, dim=0) + + def inpaint_fusion(self, latents, inpaint_latents, pred_noise, fg_mask, bg_mask, progress_id, background_weight=0.): + # inpaint noise + inpaint_noise = (latents - inpaint_latents) / self.scheduler.sigmas[progress_id] + # merge noise + weight = torch.ones_like(inpaint_noise) + inpaint_noise[fg_mask] = pred_noise[fg_mask] + inpaint_noise[bg_mask] += pred_noise[bg_mask] * background_weight + weight[bg_mask] += background_weight + inpaint_noise /= weight + return inpaint_noise + + + def preprocess_masks(self, masks, 
height, width, dim): + out_masks = [] + for mask in masks: + mask = self.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=self.device, dtype=self.torch_dtype) + out_masks.append(mask) + return out_masks + + + def prepare_entity_inputs(self, entity_prompts, entity_masks, width, height, t5_sequence_length=512, enable_eligen_inpaint=False): + fg_mask, bg_mask = None, None + if enable_eligen_inpaint: + masks_ = deepcopy(entity_masks) + fg_masks = torch.cat([self.preprocess_image(mask.resize((width//8, height//8))).mean(dim=1, keepdim=True) for mask in masks_]) + fg_masks = (fg_masks > 0).float() + fg_mask = fg_masks.sum(dim=0, keepdim=True).repeat(1, 16, 1, 1) > 0 + bg_mask = ~fg_mask + entity_masks = self.preprocess_masks(entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + entity_prompts = self.encode_prompt(entity_prompts, t5_sequence_length=t5_sequence_length)['prompt_emb'].unsqueeze(0) + return entity_prompts, entity_masks, fg_mask, bg_mask + + + def prepare_latents(self, input_image, height, width, seed, tiled, tile_size, tile_stride): + if input_image is not None: + self.load_models_to_device(['vae_encoder']) + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + input_latents = self.encode_image(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(input_latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + input_latents = None + return latents, input_latents + + + def prepare_ipadapter(self, ipadapter_images, ipadapter_scale): + if ipadapter_images is not None: + 
self.load_models_to_device(['ipadapter_image_encoder']) + ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output + self.load_models_to_device(['ipadapter']) + ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} + ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + return ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega + + + def prepare_controlnet(self, controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative): + if controlnet_image is not None: + self.load_models_to_device(['vae_encoder']) + controlnet_kwargs_posi = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} + if len(masks) > 0 and controlnet_inpaint_mask is not None: + print("The controlnet_inpaint_mask will be overridden by masks.") + local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] + else: + local_controlnet_kwargs = None + else: + controlnet_kwargs_posi, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) + controlnet_kwargs_nega = controlnet_kwargs_posi if enable_controlnet_on_negative else {} + return controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs + + + def prepare_eligen(self, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale): + if eligen_entity_masks is not None: + entity_prompt_emb_posi, entity_masks_posi, fg_mask, bg_mask = self.prepare_entity_inputs(eligen_entity_prompts, eligen_entity_masks, width, 
height, t5_sequence_length, enable_eligen_inpaint) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, eligen_entity_masks.shape[1], 1, 1) + entity_masks_nega = eligen_entity_masks + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + else: + entity_prompt_emb_posi, entity_masks_posi, entity_prompt_emb_nega, entity_masks_nega = None, None, None, None + fg_mask, bg_mask = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask + + + def prepare_prompts(self, prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale): + # Extend prompt + self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) + prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) + + # Encode prompts + prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) + prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) if cfg_scale != 1.0 else None + prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + return prompt_emb_posi, prompt_emb_nega, prompt_emb_locals + + @torch.no_grad() def __call__( self, + # Prompt prompt, - local_prompts=None, - masks=None, - mask_scales=None, negative_prompt="", cfg_scale=1.0, embedded_guidance=3.5, + t5_sequence_length=512, + # Image input_image=None, - ipadapter_images=None, - ipadapter_scale=1.0, - controlnet_image=None, - controlnet_inpaint_mask=None, - enable_controlnet_on_negative=False, denoising_strength=1.0, height=1024, width=1024, + seed=None, + # Steps num_inference_steps=30, - 
t5_sequence_length=512, + # local prompts + local_prompts=(), + masks=(), + mask_scales=(), + # ControlNet + controlnet_image=None, + controlnet_inpaint_mask=None, + enable_controlnet_on_negative=False, + # IP-Adapter + ipadapter_images=None, + ipadapter_scale=1.0, + # EliGen + eligen_entity_prompts=None, + eligen_entity_masks=None, + enable_eligen_on_negative=False, + enable_eligen_inpaint=False, + # Tile tiled=False, tile_size=128, tile_stride=64, - seed=None, + # Progress bar progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -176,72 +297,50 @@ class FluxImagePipeline(BasePipeline): self.scheduler.set_timesteps(num_inference_steps, denoising_strength) # Prepare latent tensors - if input_image is not None: - self.load_models_to_device(['vae_encoder']) - image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) - latents = self.encode_image(image, **tiler_kwargs) - noise = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) - latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) - else: - latents = self.generate_noise((1, 16, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype) + latents, input_latents = self.prepare_latents(input_image, height, width, seed, tiled, tile_size, tile_stride) - # Extend prompt - self.load_models_to_device(['text_encoder_1', 'text_encoder_2']) - prompt, local_prompts, masks, mask_scales = self.extend_prompt(prompt, local_prompts, masks, mask_scales) - - # Encode prompts - prompt_emb_posi = self.encode_prompt(prompt, t5_sequence_length=t5_sequence_length) - if cfg_scale != 1.0: - prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False, t5_sequence_length=t5_sequence_length) - prompt_emb_locals = [self.encode_prompt(prompt_local, t5_sequence_length=t5_sequence_length) for prompt_local in local_prompts] + # Prompt + prompt_emb_posi, prompt_emb_nega, prompt_emb_locals = 
self.prepare_prompts(prompt, local_prompts, masks, mask_scales, t5_sequence_length, negative_prompt, cfg_scale) # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) - # IP-Adapter - if ipadapter_images is not None: - self.load_models_to_device(['ipadapter_image_encoder']) - ipadapter_images = self.prepare_ipadapter_inputs(ipadapter_images) - ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images).pooler_output - self.load_models_to_device(['ipadapter']) - ipadapter_kwargs_list_posi = {"ipadapter_kwargs_list": self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)} - ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": self.ipadapter(torch.zeros_like(ipadapter_image_encoding))} - else: - ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {"ipadapter_kwargs_list": {}}, {"ipadapter_kwargs_list": {}} + # Entity control + eligen_kwargs_posi, eligen_kwargs_nega, fg_mask, bg_mask = self.prepare_eligen(prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_inpaint, enable_eligen_on_negative, cfg_scale) - # Prepare ControlNets - if controlnet_image is not None: - self.load_models_to_device(['vae_encoder']) - controlnet_kwargs = {"controlnet_frames": self.prepare_controlnet_input(controlnet_image, controlnet_inpaint_mask, tiler_kwargs)} - if len(masks) > 0 and controlnet_inpaint_mask is not None: - print("The controlnet_inpaint_mask will be overridden by masks.") - local_controlnet_kwargs = [{"controlnet_frames": self.prepare_controlnet_input(controlnet_image, mask, tiler_kwargs)} for mask in masks] - else: - local_controlnet_kwargs = None - else: - controlnet_kwargs, local_controlnet_kwargs = {"controlnet_frames": None}, [{}] * len(masks) + # IP-Adapter + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = self.prepare_ipadapter(ipadapter_images, ipadapter_scale) + + # ControlNets + controlnet_kwargs_posi, controlnet_kwargs_nega, 
local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) # Denoise self.load_models_to_device(['dit', 'controlnet']) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): timestep = timestep.unsqueeze(0).to(self.device) - # Classifier-free guidance + # Positive side inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, - special_kwargs=controlnet_kwargs, special_local_kwargs_list=local_controlnet_kwargs + special_kwargs=controlnet_kwargs_posi, special_local_kwargs_list=local_controlnet_kwargs ) + + # Inpaint + if enable_eligen_inpaint: + noise_pred_posi = self.inpaint_fusion(latents, input_latents, noise_pred_posi, fg_mask, bg_mask, progress_id) + + # Classifier-free guidance if cfg_scale != 1.0: - negative_controlnet_kwargs = controlnet_kwargs if enable_controlnet_on_negative else {} + # Negative side noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **negative_controlnet_kwargs, **ipadapter_kwargs_list_nega, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -256,7 +355,7 @@ class FluxImagePipeline(BasePipeline): # Decode image self.load_models_to_device(['vae_decoder']) - image = 
self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.decode_image(latents, **tiler_kwargs) # Offload all models self.load_models_to_device([]) @@ -278,6 +377,8 @@ def lets_dance_flux( tiled=False, tile_size=128, tile_stride=64, + entity_prompt_emb=None, + entity_masks=None, ipadapter_kwargs_list={}, **kwargs ): @@ -333,13 +434,18 @@ def lets_dance_flux( if dit.guidance_embedder is not None: guidance = guidance * 1000 conditioning = conditioning + dit.guidance_embedder(guidance, hidden_states.dtype) - prompt_emb = dit.context_embedder(prompt_emb) - image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) height, width = hidden_states.shape[-2:] hidden_states = dit.patchify(hidden_states) hidden_states = dit.x_embedder(hidden_states) - + + if entity_prompt_emb is not None and entity_masks is not None: + prompt_emb, image_rotary_emb, attention_mask = dit.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids) + else: + prompt_emb = dit.context_embedder(prompt_emb) + image_rotary_emb = dit.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) + attention_mask = None + # Joint Blocks for block_id, block in enumerate(dit.blocks): hidden_states, prompt_emb = block( @@ -347,7 +453,9 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, - ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None)) + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id, None) + ) # ControlNet if controlnet is not None and controlnet_frames is not None: hidden_states = hidden_states + controlnet_res_stack[block_id] @@ -361,8 +469,9 @@ def lets_dance_flux( prompt_emb, conditioning, image_rotary_emb, - ipadapter_kwargs_list=ipadapter_kwargs_list.get( - block_id + num_joint_blocks, None)) + attention_mask, + ipadapter_kwargs_list=ipadapter_kwargs_list.get(block_id + num_joint_blocks, None) + ) # ControlNet if controlnet is not None and 
controlnet_frames is not None: hidden_states[:, prompt_emb.shape[1]:] = hidden_states[:, prompt_emb.shape[1]:] + controlnet_single_res_stack[block_id] diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md new file mode 100644 index 0000000..94c4cb5 --- /dev/null +++ b/examples/EntityControl/README.md @@ -0,0 +1,60 @@ +# EliGen: Entity-Level Controlled Image Generation + +## Introduction + +We propose EliGen, a novel approach that leverages fine-grained entity-level information to enable precise and controllable text-to-image generation. EliGen excels in tasks such as entity-level controlled image generation and image inpainting, while its applicability is not limited to these areas. Additionally, it can be seamlessly integrated with existing community models, such as the IP-Adapter. + +* Paper: Coming soon +* Github: https://github.com/modelscope/DiffSynth-Studio +* Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) +* Training dataset: Coming soon + +## Methodology + +![regional-attention](https://github.com/user-attachments/assets/9a147201-15ab-421f-a6c5-701075754478) + +We introduce a regional attention mechanism within the DiT framework to effectively process the conditions of each entity. This mechanism enables the local prompt associated with each entity to semantically influence specific regions through regional attention. To further enhance the layout control capabilities of EliGen, we meticulously curate an entity-annotated dataset and fine-tune the model using the LoRA framework. + +1. **Regional Attention**: Regional attention is shown in the above figure, which can be easily applied to other text-to-image models. Its core principle involves transforming the positional information of each entity into an attention mask, ensuring that the mechanism only affects the designated regions. + +2.
**Dataset with Entity Annotation**: To curate a dedicated entity control dataset, we start by randomly selecting captions from DiffusionDB and generating the corresponding source image using Flux. Next, we employ Qwen2-VL 72B, recognized for its advanced grounding capabilities among MLLMs, to randomly identify entities within the image. These entities are annotated with local prompts and bounding boxes for precise localization, forming the foundation of our dataset for further training. + +3. **Training**: We apply LoRA and DeepSpeed to fine-tune regional attention with the curated dataset, enabling EliGen to perform effective entity-level control. + +## Usage +1. **Entity-Level Controlled Image Generation** +See [./entity_control.py](./entity_control.py) for usage. +2. **Image Inpainting** + To apply EliGen to the image inpainting task, we propose an inpainting fusion pipeline to preserve the non-painting areas while enabling precise, entity-level modifications over inpainting regions. + See [./entity_inpaint.py](./entity_inpaint.py) for usage. +3. **Styled Entity Control** + EliGen can be seamlessly integrated with existing community models. We have provided an example of how to integrate it with the IP-Adapter. See [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for usage. +4. **Play with EliGen using UI** + Download the checkpoint of EliGen from [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) to `models/lora/entity_control` and run the following command to try the interactive UI: + ```bash + python apps/gradio/entity_level_control.py + ``` +## Examples +### Entity-Level Controlled Image Generation + +1. The effect of generating images with continuously changing entity positions. + +https://github.com/user-attachments/assets/4fc76df1-b26a-46e8-a950-865cdf02a38d + +2. The image generation effect of complex Entity combinations, demonstrating the strong generalization of EliGen.
+ +|![image_1_base](https://github.com/user-attachments/assets/b8564b28-19b5-424f-bf3c-6476f2923ff9)|![image_1_base](https://github.com/user-attachments/assets/20793715-42d3-46f7-8d62-0cb4cacef38d)| +|-|-| +|![image_1_base](https://github.com/user-attachments/assets/70ef12fe-d300-4b52-9d11-eabc9b5464a8)|![image_1_enhance](https://github.com/user-attachments/assets/7645ce0e-4aa7-4b1e-b7a7-bccfd9796461)| +|![image_2_base](https://github.com/user-attachments/assets/2f1e44e1-8f1f-4c6e-ab7a-1b6861a33a69)|![image_2_enhance](https://github.com/user-attachments/assets/faf78498-57ba-41bd-b516-570c86984515)| +|![image_3_base](https://github.com/user-attachments/assets/206d1cef-2e96-4469-aed5-cdeb06ab9e99)|![image_3_enhance](https://github.com/user-attachments/assets/75d784d6-d5a1-474f-a5d5-ef8074135f35)| +### Image Inpainting +|Inpainting Input|Inpainting Output| +|-|-| +|![image_1_base](https://github.com/user-attachments/assets/5f74c710-bf30-4db1-ae40-a1e1995ccef6)|![image_1_enhance](https://github.com/user-attachments/assets/1cd71177-e956-46d3-86ce-06f774c96efd)| +|![image_2_base](https://github.com/user-attachments/assets/5ef499f3-3d8a-49cc-8ceb-86af7f5cb9f8)|![image_2_enhance](https://github.com/user-attachments/assets/fb967035-7b28-466c-a753-c00135559121)| +### Styled Entity Control +|Style Reference|Entity Control Variance 1|Entity Control Variance 2|Entity Control Variance 3| +|-|-|-|-| +|![image_1_base](https://github.com/user-attachments/assets/5e2dd3ab-37d3-4f58-8e02-ee2f9b238604)|![image_1_enhance](https://github.com/user-attachments/assets/0f6711a2-572a-41b3-938a-95deff6d732d)|![image_1_enhance](https://github.com/user-attachments/assets/ce2e66e5-1fdf-44e8-bca7-555d805a50b1)|![image_1_enhance](https://github.com/user-attachments/assets/ad2da233-2f7c-4065-ab57-b2d84dc2c0e2)| 
+|![image_2_base](https://github.com/user-attachments/assets/77cf7ceb-48e3-442d-8ffc-5fa4a10fe81a)|![image_2_enhance](https://github.com/user-attachments/assets/59a4f3c2-e59d-40c7-886c-0768f14fcc89)|![image_2_enhance](https://github.com/user-attachments/assets/a9187fb0-489a-49c9-a52f-56b1bd96faf7)|![image_2_enhance](https://github.com/user-attachments/assets/a62caee4-3863-4b56-96ff-e0785c6d93bb)| \ No newline at end of file diff --git a/examples/EntityControl/entity_control.py b/examples/EntityControl/entity_control.py new file mode 100644 index 0000000..f505b94 --- /dev/null +++ b/examples/EntityControl/entity_control.py @@ -0,0 +1,43 @@ +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import torch + + +# download and load model +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/mask*") +masks = [Image.open(f"./data/examples/eligen/mask{i}.png") for i in range(1, 8)] + +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +# generate image +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + 
num_inference_steps=50, + embedded_guidance=3.5, + seed=4, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, +) +image.save(f"entity_control.png") +visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_control_ipadapter.py b/examples/EntityControl/entity_control_ipadapter.py new file mode 100644 index 0000000..c604bad --- /dev/null +++ b/examples/EntityControl/entity_control_ipadapter.py @@ -0,0 +1,46 @@ +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import torch + + +# download and load model +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/ipadapter*") +masks = [Image.open(f"./data/examples/eligen/ipadapter_mask_{i}.png") for i in range(1, 4)] + +entity_prompts = ['A girl', 'hat', 'sunset'] +global_prompt = "A girl wearing a hat, looking at the sunset" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" +reference_img = Image.open("./data/examples/eligen/ipadapter_image.png") + +# generate image +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + seed=4, + height=1024, + width=1024, + 
eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, + ipadapter_images=[reference_img], + ipadapter_scale=0.7 +) +image.save(f"styled_entity_control.png") +visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_inpaint.py b/examples/EntityControl/entity_inpaint.py new file mode 100644 index 0000000..d62da4f --- /dev/null +++ b/examples/EntityControl/entity_inpaint.py @@ -0,0 +1,45 @@ +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import torch + +# download and load model +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# download and load mask images +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern="data/examples/eligen/inpaint*") +masks = [Image.open(f"./data/examples/eligen/inpaint_mask_{i}.png") for i in range(1, 3)] +input_image = Image.open("./data/examples/eligen/inpaint_image.jpg") + +entity_prompts = ["A person wear red shirt", "Airplane"] +global_prompt = "A person walking on the path in front of a house; An airplane in the sky" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" + +# generate image +image = pipe( + prompt=global_prompt, + input_image=input_image, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + seed=0, + height=1024, + width=1024, + 
eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + enable_eligen_on_negative=False, + enable_eligen_inpaint=True, +) +image.save(f"entity_inpaint.png") +visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") diff --git a/examples/EntityControl/utils.py b/examples/EntityControl/utils.py new file mode 100644 index 0000000..573c8a6 --- /dev/null +++ b/examples/EntityControl/utils.py @@ -0,0 +1,59 @@ +from PIL import Image, ImageDraw, ImageFont +import random + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + 
draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result