import os
import json
import random

import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import gradio as gr
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models
from modelscope import dataset_snapshot_download
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig

# Tested with:
#   pip install pydantic==2.10.6
#   pip install gradio==5.4.0


# Download the bundled EliGen examples (entity masks + prompts) used by the
# "Load Example" buttons at the bottom of the UI.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/eligen/entity_control/*",
)
example_json = 'data/examples/eligen/entity_control/ui_examples.json'
with open(example_json, 'r') as f:
    examples = json.load(f)['examples']

# Attach the per-entity mask images to each example record.
for idx in range(len(examples)):
    example_id = examples[idx]['example_id']
    entity_prompts = examples[idx]['local_prompt_list']
    examples[idx]['mask_lists'] = [
        Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB')
        for i in range(len(entity_prompts))
    ]


def create_canvas_data(background, masks):
    """Build a gr.ImageEditor value dict (background/layers/composite).

    background: HxWx3 or HxWx4 uint8 array; promoted to RGBA if needed.
    masks: list of HxW (or HxWxC) uint8 arrays; ``None`` entries become
        fully transparent layers.
    """
    if background.shape[-1] == 3:
        # Promote RGB to RGBA with a fully opaque alpha channel.
        background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)])
    layers = []
    for mask in masks:
        if mask is not None:
            mask_single_channel = mask if mask.ndim == 2 else mask[..., 0]
            layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), dtype=np.uint8)
            # The mask drives the alpha channel; RGB stays black (the brush color).
            layer[..., -1] = mask_single_channel
            layers.append(layer)
        else:
            layers.append(np.zeros_like(background))
    # Paint every opaque layer pixel over the background.
    composite = background.copy()
    for layer in layers:
        composite = np.where(layer[..., -1:] > 0, layer, composite)
    return {
        "background": background,
        "layers": layers,
        "composite": composite,
    }


def load_example(load_example_button):
    """Callback for a "Load Example N" button.

    Returns the values, in output order, for:
    [inference steps, global prompt, negative prompt, seed,
     local prompts (padded to max layers), canvases (padded to max layers)].
    """
    # The button label ends with the 1-based example id.
    example_idx = int(load_example_button.split()[-1]) - 1
    example = examples[example_idx]
    result = [
        50,
        example["global_prompt"],
        example["negative_prompt"],
        example["seed"],
        *example["local_prompt_list"],
    ]
    num_entities = len(example["local_prompt_list"])
    result += [""] * (config["max_num_painter_layers"] - num_entities)
    masks = [np.array(mask.convert("L")) for mask in example["mask_lists"]]
    # Pad with blank masks so every painter tab receives a value.
    for _ in range(config["max_num_painter_layers"] - len(masks)):
        blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8)
        masks.append(blank_mask)
    background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255
    result.extend(create_canvas_data(background, [mask]) for mask in masks)
    return result


def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'):
    """Persist the entity masks and prompts of one generation for debugging."""
    save_dir = os.path.join('workdirs/tmp_mask', random_dir)
    print(f'save to {save_dir}')
    os.makedirs(save_dir, exist_ok=True)
    for i, mask in enumerate(masks):
        mask.save(os.path.join(save_dir, f'{i}.png'))
    sample = {
        "global_prompt": global_prompt,
        "mask_prompts": mask_prompts,
        "seed": seed,
    }
    with open(os.path.join(save_dir, "prompts.json"), 'w', encoding='utf-8') as f:
        json.dump(sample, f, ensure_ascii=False, indent=4)


def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False):
    """Overlay semi-transparent colored masks, labeled with their prompts, on image.

    White pixels in each mask become a translucent color patch; the entity
    prompt is drawn near the mask's bounding box. Returns an RGBA image.
    """
    overlay = Image.new('RGBA', image.size, (0, 0, 0, 0))
    colors = [
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
        (165, 238, 173, 80),
        (76, 102, 221, 80),
        (221, 160, 77, 80),
        (204, 93, 71, 80),
        (145, 187, 149, 80),
        (134, 141, 172, 80),
        (157, 137, 109, 80),
        (153, 104, 95, 80),
    ]
    if use_random_colors:
        colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))]
    # Fall back to PIL's built-in font when the bundled TTF is unavailable,
    # instead of crashing the whole callback with OSError.
    try:
        font = ImageFont.truetype("dinglieciweifont20250217-2.ttf", font_size)
    except OSError:
        font = ImageFont.load_default()
    for mask, mask_prompt, color in zip(masks, mask_prompts, colors):
        if mask is None:
            continue
        mask_rgba = mask.convert('RGBA')
        mask_data = mask_rgba.getdata()
        # White pixels become the translucent overlay color; everything else transparent.
        new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data]
        mask_rgba.putdata(new_data)
        draw = ImageDraw.Draw(mask_rgba)
        mask_bbox = mask.getbbox()
        if mask_bbox is None:
            # Empty mask: nothing to label.
            continue
        text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10)
        draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font)
        overlay = Image.alpha_composite(overlay, mask_rgba)
    return Image.alpha_composite(image.convert('RGBA'), overlay)


config = {
    "max_num_painter_layers": 8,
    "max_num_model_cache": 1,
}

# Cache of loaded pipelines keyed by model type.
model_dict = {}


def load_model(model_type='qwen-image'):
    """Load (and cache) the Qwen-Image pipeline with the EliGen LoRA applied."""
    global model_dict
    if model_type in model_dict:
        return model_dict[model_type]
    pipe = QwenImagePipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
            ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
            ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
        ],
        tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
    )
    pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen_lora/step-20000.safetensors")
    model_dict[model_type] = pipe
    return pipe


# Warm the cache at startup; initialize_model() below then returns instantly.
load_model('qwen-image')

with gr.Blocks() as app:
    gr.Markdown(
        """## EliGen: Entity-Level Controllable Text-to-Image Model
        1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river."
        2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results.
        3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images.
        4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.**
        """
    )

    loading_status = gr.Textbox(label="Loading Model...", value="Loading model... Please wait...", visible=True)
    main_interface = gr.Column(visible=False)

    def initialize_model():
        """Ensure the model is loaded when the page opens, then reveal the UI."""
        try:
            load_model('qwen-image')
            return {
                loading_status: gr.update(value="Model loaded successfully!", visible=False),
                main_interface: gr.update(visible=True),
            }
        except Exception as e:
            print(f'Failed to load model with error: {e}')
            return {
                loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True),
                main_interface: gr.update(visible=True),
            }

    app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface])

    with main_interface:
        with gr.Row():
            local_prompt_list = []
            canvas_list = []
            # Per-session scratch directory name for save_mask_prompts().
            random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}')
            with gr.Column(scale=382, min_width=100):
                model_type = gr.State('qwen-image')
                with gr.Accordion(label="Global prompt"):
                    prompt = gr.Textbox(label="Global Prompt", lines=3)
                    negative_prompt = gr.Textbox(label="Negative prompt", value="", lines=3)
                with gr.Accordion(label="Inference Options", open=True):
                    seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True)
                    num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps")
                    cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=4.0, step=0.1, interactive=True, label="Classifier-free guidance scale")
                    embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale")
                    height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height")
                    width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width")
                with gr.Accordion(label="Inpaint Input Image", open=False):
                    input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil")
                    background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False)

                    with gr.Column():
                        reset_input_button = gr.Button(value="Reset Inpaint Input")
                        send_input_to_painter = gr.Button(value="Set as painter's background")

                    @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click)
                    def reset_input_image(input_image):
                        # Clear the inpaint input image.
                        return None

            with gr.Column(scale=618, min_width=100):
                with gr.Accordion(label="Entity Painter"):
                    for painter_layer_id in range(config["max_num_painter_layers"]):
                        with gr.Tab(label=f"Entity {painter_layer_id}"):
                            local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}")
                            canvas = gr.ImageEditor(
                                canvas_size=(1024, 1024),
                                sources=None,
                                layers=False,
                                interactive=True,
                                image_mode="RGBA",
                                brush=gr.Brush(
                                    default_size=50,
                                    default_color="#000000",
                                    colors=["#000000"],
                                ),
                                label="Entity Mask Painter",
                                key=f"canvas_{painter_layer_id}",
                                width=width,
                                height=height,
                            )

                            @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden")
                            def resize_canvas(height, width, canvas):
                                # Reset to a blank white board when the target size changes
                                # or the canvas has no background yet.
                                if canvas is None or canvas["background"] is None:
                                    return np.ones((height, width, 3), dtype=np.uint8) * 255
                                h, w = canvas["background"].shape[:2]
                                if h != height or w != width:
                                    return np.ones((height, width, 3), dtype=np.uint8) * 255
                                return canvas

                            local_prompt_list.append(local_prompt)
                            canvas_list.append(canvas)

                with gr.Accordion(label="Results"):
                    run_button = gr.Button(value="Generate", variant="primary")
                    output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil")
                    with gr.Row():
                        with gr.Column():
                            output_to_painter_button = gr.Button(value="Set as painter's background")
                        with gr.Column():
                            return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting")
                            output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False)
                    # Raw generated image and its mask-overlay counterpart;
                    # gr.State outputs store the returned values themselves.
                    real_output = gr.State(None)
                    mask_out = gr.State(None)

                    @gr.on(
                        inputs=[model_type, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list,
                        outputs=[output_image, real_output, mask_out],
                        triggers=run_button.click
                    )
                    def generate_image(model_type, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()):
                        """Run the pipeline with the current prompts/masks.

                        Returns (displayed image, raw image, mask-overlay image);
                        the last two feed the real_output/mask_out states.
                        """
                        pipe = load_model(model_type)
                        input_params = {
                            "prompt": prompt,
                            "negative_prompt": negative_prompt,
                            "cfg_scale": cfg_scale,
                            "num_inference_steps": num_inference_steps,
                            "height": height,
                            "width": width,
                            "progress_bar_cmd": progress.tqdm,
                        }
                        if isinstance(pipe, FluxImagePipeline):
                            # Embedded guidance is a FLUX-only parameter.
                            input_params["embedded_guidance"] = embedded_guidance
                        if input_image is not None:
                            input_params["input_image"] = input_image.resize((width, height)).convert("RGB")
                            input_params["enable_eligen_inpaint"] = True

                        n_layers = config["max_num_painter_layers"]
                        local_prompt_list, canvas_list = args[:n_layers], args[n_layers:2 * n_layers]
                        local_prompts, masks = [], []
                        for local_prompt, canvas in zip(local_prompt_list, canvas_list):
                            # Only keep entities that have both a prompt and a painted canvas;
                            # an untouched tab delivers canvas=None and must not crash us.
                            if isinstance(local_prompt, str) and len(local_prompt) > 0 and canvas is not None:
                                local_prompts.append(local_prompt)
                                masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB"))
                        entity_prompts = local_prompts if local_prompts else None
                        entity_masks = masks if masks and entity_prompts is not None else None
                        input_params.update({
                            "eligen_entity_prompts": entity_prompts,
                            "eligen_entity_masks": entity_masks,
                        })
                        # gr.Number may deliver a float; manual_seed requires an int.
                        torch.manual_seed(int(seed))
                        save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir)
                        image = pipe(**input_params)
                        masks = [mask.resize(image.size) for mask in masks]
                        image_with_mask = visualize_masks(image, masks, local_prompts)
                        shown = image_with_mask if return_with_mask else image
                        return shown, image, image_with_mask

            @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click)
            def send_input_to_painter_background(input_image, *canvas_list):
                # Copy the inpaint input image onto every painter canvas background.
                if input_image is None:
                    return tuple(canvas_list)
                for canvas in canvas_list:
                    h, w = canvas["background"].shape[:2]
                    canvas["background"] = input_image.resize((w, h))
                return tuple(canvas_list)

            @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click)
            def send_output_to_painter_background(real_output, *canvas_list):
                # Copy the last generated image onto every painter canvas background.
                if real_output is None:
                    return tuple(canvas_list)
                for canvas in canvas_list:
                    h, w = canvas["background"].shape[:2]
                    canvas["background"] = real_output.resize((w, h))
                return tuple(canvas_list)

            @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden")
            def show_output(return_with_mask, real_output, mask_out):
                # Toggle between the raw result and the mask-overlay visualization.
                return mask_out if return_with_mask else real_output

            @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click)
            def send_output_to_pipe_input(real_output):
                return real_output

        with gr.Column():
            gr.Markdown("## Examples")

            def add_example_column(example):
                # One example card: preview image + "Load Example N" button.
                with gr.Column():
                    gr.Image(
                        value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png",
                        label=example["description"],
                        interactive=False,
                        width=1024,
                        height=512,
                    )
                    load_example_button = gr.Button(value=f"Load Example {example['example_id']}")
                    load_example_button.click(
                        load_example,
                        inputs=[load_example_button],
                        outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list,
                    )

            # Lay the example cards out two per row.
            for i in range(0, len(examples), 2):
                with gr.Row():
                    add_example_column(examples[i])
                    if i + 1 < len(examples):
                        add_example_column(examples[i + 1])

app.config["show_progress"] = "hidden"
app.launch(share=False)
def example(pipe, seeds, example_id, global_prompt, entity_prompts, negative_prompt=""):
    """Render one bundled EliGen entity-control example.

    Downloads the reference entity masks for ``example_id``, then generates
    the global prompt once per seed, saving the raw image and a mask-overlay
    visualization to the working directory.

    Args:
        pipe: a QwenImagePipeline supporting eligen_entity_prompts/masks.
        seeds: iterable of integer seeds; one image is produced per seed.
        example_id: id of the bundled example whose masks are fetched.
        global_prompt: overall scene prompt.
        entity_prompts: one local prompt per mask file (0.png, 1.png, ...).
        negative_prompt: optional negative prompt (default: empty).
    """
    dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png")
    masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))]
    for seed in seeds:
        # Generate one image per seed with entity-level control.
        image = pipe(
            prompt=global_prompt,
            cfg_scale=4.0,
            negative_prompt=negative_prompt,
            num_inference_steps=30,
            seed=seed,
            height=1024,
            width=1024,
            eligen_entity_prompts=entity_prompts,
            eligen_entity_masks=masks,
        )
        image.save(f"eligen_example_{example_id}_{seed}.png")
        visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png")


# NOTE(review): the middle of this call is elided by the diff hunk; the
# model_configs list is reconstructed from the identical loader in
# apps/gradio/qwen_image_eligen.py — confirm against the full file.
pipe = QwenImagePipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
    ],
    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-EliGen_lora/step-20000.safetensors")

# example 1
global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n"
entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"]
example(pipe, [0], 1, global_prompt, entity_prompts)

# example 2
global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render."
entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"]
example(pipe, [0], 2, global_prompt, entity_prompts)

# example 3
global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning,"
entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"]
example(pipe, [27], 3, global_prompt, entity_prompts)

# example 4
global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'"
entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"]
example(pipe, [21], 4, global_prompt, entity_prompts)

# example 5
global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere."
entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"]
example(pipe, [0], 5, global_prompt, entity_prompts)

# example 6
global_prompt = "Snow White and the 6 Dwarfs."
entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"]
example(pipe, [8], 6, global_prompt, entity_prompts)

# example 7, same prompt with different seeds
seeds = range(5, 9)
global_prompt = "A beautiful asia woman wearing white dress, holding a mirror, with a forest background;"
entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"]
example(pipe, seeds, 7, global_prompt, entity_prompts)