From b6620f3dde192f994b094c71d61119f8f6e9addd Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Tue, 31 Dec 2024 14:04:28 +0800 Subject: [PATCH] update_example entity control --- apps/gradio/entity_level_control.py | 320 ++++++++++++++++++ examples/EntityControl/README.md | 1 + examples/EntityControl/entity_control.py | 57 ++++ examples/EntityControl/entity_control_flux.py | 54 --- .../EntityControl/entity_control_ipadapter.py | 51 +++ examples/EntityControl/entity_inpaint.py | 58 ++++ examples/EntityControl/entity_inpaint_flux.py | 59 ---- 7 files changed, 487 insertions(+), 113 deletions(-) create mode 100644 apps/gradio/entity_level_control.py create mode 100644 examples/EntityControl/README.md create mode 100644 examples/EntityControl/entity_control.py delete mode 100644 examples/EntityControl/entity_control_flux.py create mode 100644 examples/EntityControl/entity_control_ipadapter.py create mode 100644 examples/EntityControl/entity_inpaint.py delete mode 100644 examples/EntityControl/entity_inpaint_flux.py diff --git a/apps/gradio/entity_level_control.py b/apps/gradio/entity_level_control.py new file mode 100644 index 0000000..d914cd0 --- /dev/null +++ b/apps/gradio/entity_level_control.py @@ -0,0 +1,320 @@ +import gradio as gr +from diffsynth import ModelManager, FluxImagePipeline +import os, torch +from PIL import Image +import numpy as np +from PIL import ImageDraw, ImageFont +import random +import json + +lora_checkpoint_path = 'models/lora/entity_control/model_bf16.safetensors' +save_masks_dir = 'workdirs/tmp_mask' + +def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): + save_dir = os.path.join(save_masks_dir, random_dir) + os.makedirs(save_dir, exist_ok=True) + for i, mask in enumerate(masks): + save_path = os.path.join(save_dir, f'{i}.png') + mask.save(save_path) + sample = { + "global_prompt": global_prompt, + "mask_prompts": mask_prompts, + "seed": seed, + } + with open(os.path.join(save_dir, f"prompts.json"), 'w') as f: + json.dump(sample, f, indent=4) + +def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + return result + +config = { + "model_config": { + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "default_parameters": { + "cfg_scale": 3.0, + "embedded_guidance": 3.5, + "num_inference_steps": 30, + } + }, + }, + "max_num_painter_layers": 8, + "max_num_model_cache": 1, +} + + +def load_model_list(model_type): + if model_type is None: + return [] + folder = config["model_config"][model_type]["model_folder"] + file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] + if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: + file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] + file_list = sorted(file_list) + return file_list + + +model_dict = {} + +def load_model(model_type, model_path): + global model_dict + model_key = f"{model_type}:{model_path}" + if model_key in model_dict: + return model_dict[model_key] + model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) + model_manager = ModelManager() + if model_type == "FLUX": + model_manager.torch_dtype = torch.bfloat16 + file_list = [ + os.path.join(model_path, "text_encoder/model.safetensors"), + os.path.join(model_path, "text_encoder_2"), + ] + for file_name in os.listdir(model_path): + if file_name.endswith(".safetensors"): + file_list.append(os.path.join(model_path, file_name)) + model_manager.load_models(file_list) + model_manager.load_lora(lora_checkpoint_path, lora_alpha=1.) + + else: + model_manager.load_model(model_path) + pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) + while len(model_dict) + 1 > config["max_num_model_cache"]: + key = next(iter(model_dict.keys())) + model_manager_to_release, _ = model_dict[key] + model_manager_to_release.to("cpu") + del model_dict[key] + torch.cuda.empty_cache() + model_dict[model_key] = model_manager, pipe + return model_manager, pipe + + +with gr.Blocks() as app: + gr.Markdown(""" + # 实体级控制文生图模型EliGen + **UI说明** + 1. 点击Load model读取模型,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 + 2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。 + 3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。 + 4. 尽情创造! + """) + gr.Markdown(""" + # Entity-Level Controlled Text-to-Image Model: EliGen + **UI Instructions** + 1. Click "Load model" to load the model. The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. + 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. + 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. + 4. Enjoy! + """) + with gr.Row(): + random_mask_dir = gr.State('') + with gr.Column(scale=382, min_width=100): + model_type = gr.State('FLUX') + model_path = gr.State('FLUX.1-dev') + with gr.Accordion(label="Model"): + load_model_button = gr.Button(value="Load model") + + with gr.Accordion(label="Global prompt"): + prompt = gr.Textbox(label="Prompt", lines=3) + negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", lines=1) + cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale") + embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale") + + with gr.Accordion(label="Inference Options"): + num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps") + height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height") + width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width") + return_with_mask = gr.Checkbox(value=True, interactive=True, label="show result with mask painting") + with gr.Column(): + use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed") + seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False) + with gr.Accordion(label="Inpaint Input Image (Testing)"): + input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil") + background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight") + + with gr.Column(): + reset_input_button = gr.Button(value="Reset Inpaint Input") + send_input_to_painter = gr.Button(value="Set as painter's background") + @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click) + def reset_input_image(input_image): + return None + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask], + outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, load_model_button, random_mask_dir], + triggers=load_model_button.click + ) + def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask): + load_model(model_type, model_path) + cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale) + embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) + num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) + height = config["model_config"][model_type]["default_parameters"].get("height", height) + width = config["model_config"][model_type]["default_parameters"].get("width", width) + return_with_mask = config["model_config"][model_type]["default_parameters"].get("return_with_mask", return_with_mask) + return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, gr.update(value="Loaded FLUX"), gr.State(f'{random.randint(0, 1000000):08d}') + + + with gr.Column(scale=618, min_width=100): + with gr.Accordion(label="Painter"): + enable_local_prompt_list = [] + local_prompt_list = [] + mask_scale_list = [] + canvas_list = [] + for painter_layer_id in range(config["max_num_painter_layers"]): + with gr.Tab(label=f"Layer {painter_layer_id}"): + enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}") + local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") + mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}") + canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA", + brush=gr.Brush(default_size=30, default_color="#000000", colors=["#000000"]), + label="Painter", key=f"canvas_{painter_layer_id}") + @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden") + def resize_canvas(height, width, canvas): + h, w = canvas["background"].shape[:2] + if h != height or width != w: + return np.ones((height, width, 3), dtype=np.uint8) * 255 + else: + return canvas + + enable_local_prompt_list.append(enable_local_prompt) + local_prompt_list.append(local_prompt) + mask_scale_list.append(mask_scale) + canvas_list.append(canvas) + with gr.Accordion(label="Results"): + run_button = gr.Button(value="Generate", variant="primary") + output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil") + with gr.Row(): + with gr.Column(): + output_to_painter_button = gr.Button(value="Set as painter's background") + with gr.Column(): + output_to_input_button = gr.Button(value="Set as input image") + real_output = gr.State(None) + mask_out = gr.State(None) + + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, + outputs=[output_image, real_output, mask_out], + triggers=run_button.click + ) + def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): + print("save to:", random_mask_dir.value) + _, pipe = load_model(model_type, model_path) + input_params = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "cfg_scale": cfg_scale, + "num_inference_steps": num_inference_steps, + "height": height, + "width": width, + "progress_bar_cmd": progress.tqdm, + } + if isinstance(pipe, FluxImagePipeline): + input_params["embedded_guidance"] = embedded_guidance + if input_image is not None: + input_params["inpaint_input"] = input_image.resize((width, height)).convert("RGB") + + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( + args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], + args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], + args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]], + args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]] + ) + local_prompts, masks, mask_scales = [], [], [] + for enable_local_prompt, local_prompt, mask_scale, canvas in zip( + enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list + ): + if enable_local_prompt: + local_prompts.append(local_prompt) + masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB")) + mask_scales.append(mask_scale) + entity_masks = None if len(masks) == 0 else masks + entity_prompts = None if len(local_prompts) == 0 else local_prompts + input_params.update({ + "entity_prompts": entity_prompts, + "entity_masks": entity_masks, + }) + torch.manual_seed(seed) + image = pipe(**input_params) + # visualize masks + masks = [mask.resize(image.size) for mask in masks] + image_with_mask = visualize_masks(image, masks, local_prompts) + # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir.value) + + real_output = gr.State(image) + mask_out = gr.State(image_with_mask) + + if return_with_mask: + return image_with_mask, real_output, mask_out + return image, real_output, mask_out + + @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click) + def send_input_to_painter_background(input_image, *canvas_list): + if input_image is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = input_image.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click) + def send_output_to_painter_background(real_output, *canvas_list): + if real_output is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = real_output.value.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden") + def show_output(return_with_mask, real_output, mask_out): + if return_with_mask: + return mask_out.value + else: + return real_output.value + @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click) + def send_output_to_pipe_input(real_output): + return real_output.value + +app.launch() diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md new file mode 100644 index 0000000..784c590 --- /dev/null +++ b/examples/EntityControl/README.md @@ -0,0 +1 @@ +# EliGen: Entity-Level Controlled Image Generation diff --git a/examples/EntityControl/entity_control.py b/examples/EntityControl/entity_control.py new file mode 100644 index 0000000..2735682 --- /dev/null +++ b/examples/EntityControl/entity_control.py @@ -0,0 +1,57 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import requests +from io import BytesIO + +# download and load model +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 4 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/02905f6e-40c2-4482-9abe-b1ce50ccabbf', + 'https://github.com/user-attachments/assets/a4cf4361-abf7-4556-ba94-74683eda4cb7', + 'https://github.com/user-attachments/assets/b6595ff4-7269-4d8f-acf0-5df40bd6c59f', + 'https://github.com/user-attachments/assets/941d39a7-3aa1-437f-8b2a-4adb15d2fb3e', + 'https://github.com/user-attachments/assets/400c4086-5398-4291-b1b5-22d8483c08d9', + 'https://github.com/user-attachments/assets/ce324c77-fa1d-4aad-a5cb-698f0d5eca70', + 'https://github.com/user-attachments/assets/4e62325f-a60c-44f7-b53b-6da0869bb9db' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +# generate image +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + height=image_shape, + width=image_shape, + entity_prompts=entity_prompts, + entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, +) +image.save(f"entity_control.png") +visualize_masks(image, masks, entity_prompts, f"entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_control_flux.py b/examples/EntityControl/entity_control_flux.py deleted file mode 100644 index 5e74e0b..0000000 --- a/examples/EntityControl/entity_control_flux.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline -from examples.EntityControl.utils import visualize_masks -import os -import json -from PIL import Image - -# lora_path = download_customized_models( -# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", -# origin_file_path="merged_lora.safetensors", -# local_dir="models/lora" -# )[0] - -lora_path = '/root/model_bf16.safetensors' -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", - "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", - "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -model_manager.load_lora(lora_path, lora_alpha=1.) - -pipe = FluxImagePipeline.from_model_manager(model_manager) - -mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' -image_shape = 1024 -guidance = 3.5 -cfg = 3.0 -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," -names = ['row_2_1'] -seeds = [0] -# use this to apply regional attention in negative prompt prediction for better results with more time -use_seperated_negtive_prompt = False -for name, seed in zip(names, seeds): - out_dir = f'workdirs/entity_control/{name}' - os.makedirs(out_dir, exist_ok=True) - cur_dir = os.path.join(mask_dir, name) - metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) - for seed in range(3, 10): - prompt = metas['global_prompt'] - mask_prompts = metas['mask_prompts'] - masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] - torch.manual_seed(seed) - image = pipe( - prompt=prompt, - cfg_scale=cfg, - negative_prompt=negative_prompt, - num_inference_steps=50, embedded_guidance=guidance, height=image_shape, width=image_shape, - entity_prompts=mask_prompts, entity_masks=masks, - use_seperated_negtive_prompt=use_seperated_negtive_prompt - ) - use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' - visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png")) diff --git a/examples/EntityControl/entity_control_ipadapter.py b/examples/EntityControl/entity_control_ipadapter.py new file mode 100644 index 0000000..76ae869 --- /dev/null +++ b/examples/EntityControl/entity_control_ipadapter.py @@ -0,0 +1,51 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import requests +from io import BytesIO + +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev", "InstantX/FLUX.1-dev-IP-Adapter"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 4 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/e6745b3f-ab2b-4612-9bb5-b7235474a9a4', + 'https://github.com/user-attachments/assets/5ddf9a89-32fa-4540-89ad-e956130942b3', + 'https://github.com/user-attachments/assets/9d8a0bb0-6817-497e-af85-44f2512afe79' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ['A girl', 'hat', 'sunset'] +global_prompt = "A girl wearing a hat, looking at the sunset" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw" + +response = requests.get('https://github.com/user-attachments/assets/019bbfaa-04b3-4de6-badb-32b67c29a1bc') +reference_img = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) + +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, embedded_guidance=3.5, height=image_shape, width=image_shape, + entity_prompts=entity_prompts, entity_masks=masks, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, + ipadapter_images=[reference_img], ipadapter_scale=0.7 +) +image.save(f"styled_entity_control.png") +visualize_masks(image, masks, entity_prompts, f"styled_entity_control_with_mask.png") diff --git a/examples/EntityControl/entity_inpaint.py b/examples/EntityControl/entity_inpaint.py new file mode 100644 index 0000000..eae4641 --- /dev/null +++ b/examples/EntityControl/entity_inpaint.py @@ -0,0 +1,58 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline +from examples.EntityControl.utils import visualize_masks +import os +import json +from PIL import Image +import requests +from io import BytesIO + +# download and load model +lora_path = download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" +)[0] +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora(lora_path, lora_alpha=1.) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# prepare inputs +image_shape = 1024 +seed = 0 +# set True to apply regional attention in negative prompt prediction for better results with more time +use_seperated_negtive_prompt = False +mask_urls = [ + 'https://github.com/user-attachments/assets/0cf78663-5314-4280-a065-31ded7a24a46', + 'https://github.com/user-attachments/assets/bd3938b8-72a8-4d56-814f-f6445971b91d' +] +# prepare entity masks, entity prompts, global prompt and negative prompt +masks = [] +for url in mask_urls: + response = requests.get(url) + mask = Image.open(BytesIO(response.content)).resize((image_shape, image_shape), resample=Image.NEAREST) + masks.append(mask) +entity_prompts = ["A person wear red shirt", "Airplane"] +global_prompt = "A person walking on the path in front of a house; An airplane in the sky" +negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur" + +response = requests.get('https://github.com/user-attachments/assets/fa4d6ba5-08fd-4fc7-adbb-19898d839364') +inpaint_input = Image.open(BytesIO(response.content)).convert('RGB').resize((image_shape, image_shape)) + +# generate image +torch.manual_seed(seed) +image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + height=image_shape, + width=image_shape, + entity_prompts=entity_prompts, + entity_masks=masks, + inpaint_input=inpaint_input, + use_seperated_negtive_prompt=use_seperated_negtive_prompt, +) +image.save(f"entity_inpaint.png") +visualize_masks(image, masks, entity_prompts, f"entity_inpaint_with_mask.png") \ No newline at end of file diff --git a/examples/EntityControl/entity_inpaint_flux.py b/examples/EntityControl/entity_inpaint_flux.py deleted file mode 100644 index 5027563..0000000 --- a/examples/EntityControl/entity_inpaint_flux.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models, FluxImageLoraPipeline -from examples.EntityControl.utils import visualize_masks -import os -import json -from PIL import Image - -# lora_path = download_customized_models( -# model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", -# origin_file_path="merged_lora.safetensors", -# local_dir="models/lora" -# )[0] - -lora_path = '/root/model_bf16.safetensors' -model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") -model_manager.load_models([ - "t2i_models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", - "t2i_models/FLUX/FLUX.1-dev/text_encoder_2", - "t2i_models/FLUX/FLUX.1-dev/ae.safetensors", - "t2i_models/FLUX/FLUX.1-dev/flux1-dev.safetensors" -]) -model_manager.load_lora(lora_path, lora_alpha=1.) - -pipe = FluxImagePipeline.from_model_manager(model_manager) - -mask_dir = '/mnt/nas1/zhanghong/DiffSynth-Studio/workdirs/tmp_mask' -image_shape = 1024 -guidance = 3.5 -cfg = 3.0 -negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," -names = ['inpaint2'] -seeds = [0] -use_seperated_negtive_prompt = False -for name, seed in zip(names, seeds): - out_dir = f'workdirs/paper_app/inpaint/elc/{name}' - os.makedirs(out_dir, exist_ok=True) - cur_dir = os.path.join(mask_dir, name) - metas = json.load(open(os.path.join(mask_dir, name, 'prompts.json'))) - inpaint_input = Image.open(os.path.join(cur_dir, 'input.png')).convert('RGB') - prompt = metas['global_prompt'] - prompt = 'A person with a dog walking on the cloud. A rocket in the sky' - mask_prompts = metas['mask_prompts'] - masks = [Image.open(os.path.join(mask_dir, name, f"{mask_idx}.png")).resize((image_shape, image_shape), resample=Image.NEAREST) for mask_idx in range(len(mask_prompts))] - torch.manual_seed(seed) - image = pipe( - prompt=prompt, - cfg_scale=cfg, - negative_prompt=negative_prompt, - num_inference_steps=50, - embedded_guidance=guidance, - height=image_shape, - width=image_shape, - entity_prompts=mask_prompts, - entity_masks=masks, - inpaint_input=inpaint_input, - use_seperated_negtive_prompt=use_seperated_negtive_prompt, - ) - use_sep = f'_sepneg' if use_seperated_negtive_prompt else '' - visualize_masks(image, masks, mask_prompts, os.path.join(out_dir, f"{name}_{seed}{use_sep}.png"))