From 1b3c204d209f7db1d0f181c992e41f0efa8f3302 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 20 Jun 2025 14:49:09 +0800 Subject: [PATCH 1/2] flux_ipadapter_refactor --- diffsynth/pipelines/flux_image_new.py | 28 +++++++++++++++++++++++++- examples/flux/flux_ipadapter.py | 29 +++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) create mode 100644 examples/flux/flux_ipadapter.py diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index bab84b8..3a1d382 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -43,6 +43,7 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_InputImageEmbedder(), FluxImageUnit_ImageIDs(), FluxImageUnit_EmbeddedGuidanceEmbedder(), + FluxImageUnit_IPAdapter(), ] self.model_fn = model_fn_flux_image @@ -98,7 +99,9 @@ class FluxImagePipeline(BasePipeline): pipe.vae_decoder = model_manager.fetch_model("flux_vae_decoder") pipe.vae_encoder = model_manager.fetch_model("flux_vae_encoder") pipe.prompter.fetch_models(pipe.text_encoder_1, pipe.text_encoder_2) - + pipe.ipadapter = model_manager.fetch_model("flux_ipadapter") + pipe.ipadapter_image_encoder = model_manager.fetch_model("siglip_vision_model") + return pipe @@ -294,6 +297,29 @@ class FluxImageUnit_EmbeddedGuidanceEmbedder(PipelineUnit): return {"guidance": guidance} +class FluxImageUnit_IPAdapter(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("ipadapter_image_encoder", "ipadapter") + ) + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + ipadapter_images, ipadapter_scale = inputs_shared.get("ipadapter_images", None), inputs_shared.get("ipadapter_scale", 1.0) + if ipadapter_images is None: + return inputs_shared, inputs_posi, inputs_nega + + pipe.load_models_to_device(self.onload_model_names) + images = [image.convert("RGB").resize((384, 384), resample=3) for image in ipadapter_images] + images = [pipe.preprocess_image(image).to(device=pipe.device, dtype=pipe.torch_dtype) for image in images] + ipadapter_images = torch.cat(images, dim=0) + ipadapter_image_encoding = pipe.ipadapter_image_encoder(ipadapter_images).pooler_output + + inputs_posi.update({"ipadapter_kwargs_list": pipe.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)}) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update({"ipadapter_kwargs_list": pipe.ipadapter(torch.zeros_like(ipadapter_image_encoding))}) + return inputs_shared, inputs_posi, inputs_nega + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh): diff --git a/examples/flux/flux_ipadapter.py b/examples/flux/flux_ipadapter.py new file mode 100644 index 0000000..6214e4d --- /dev/null +++ b/examples/flux/flux_ipadapter.py @@ -0,0 +1,29 @@ +import torch +from PIL import Image +from diffsynth import save_video, VideoData, download_models +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + +#TODO: repalce the local path with model_id +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin"), + ModelConfig(path="models/IpAdapter/InstantX/FLUX.1-dev-IP-Adapter/image_encoder") + ], +) + +seed = 42 +origin_prompt = "a rabbit in a garden, colorful flowers" +image = pipe(prompt=origin_prompt, height=1280, width=960, seed=seed) +image.save("style image.jpg") + +torch.manual_seed(seed) +image = pipe(prompt="A piggy", height=1280, width=960, seed=seed, + ipadapter_images=[image], ipadapter_scale=0.7) +image.save("A piggy.jpg") From 6d5f8b74239de653ccda3bf9e48a05b809db2a1d Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 20 Jun 2025 16:53:41 +0800 Subject: [PATCH 2/2] flux_eligen_refactor --- diffsynth/lora/flux_lora.py | 13 +++ diffsynth/pipelines/flux_image_new.py | 57 +++++++++- examples/flux/flux_entity_control.py | 147 ++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 diffsynth/lora/flux_lora.py create mode 100644 examples/flux/flux_entity_control.py diff --git a/diffsynth/lora/flux_lora.py b/diffsynth/lora/flux_lora.py new file mode 100644 index 0000000..899160f --- /dev/null +++ b/diffsynth/lora/flux_lora.py @@ -0,0 +1,13 @@ +import torch +from diffsynth.lora import GeneralLoRALoader +from diffsynth.models.lora import FluxLoRAFromCivitai + + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + self.loader = FluxLoRAFromCivitai() + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + lora_prefix, model_resource = self.loader.match(model, state_dict_lora) + self.loader.load(model, state_dict_lora, lora_prefix, alpha=alpha, model_resource=model_resource) \ No newline at end of file diff --git a/diffsynth/pipelines/flux_image_new.py b/diffsynth/pipelines/flux_image_new.py index 3a1d382..b27a7d3 100644 --- a/diffsynth/pipelines/flux_image_new.py +++ b/diffsynth/pipelines/flux_image_new.py @@ -14,9 +14,10 @@ from typing_extensions import Literal from ..schedulers import FlowMatchScheduler from ..prompters import FluxPrompter -from ..models import ModelManager, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder +from ..models import ModelManager, load_state_dict, SD3TextEncoder1, FluxTextEncoder2, FluxDiT, FluxVAEEncoder, FluxVAEDecoder from ..models.tiler import FastTileWorker from .wan_video_new import BasePipeline, ModelConfig, PipelineUnitRunner, PipelineUnit +from ..lora.flux_lora import FluxLoRALoader @@ -44,12 +45,15 @@ class FluxImagePipeline(BasePipeline): FluxImageUnit_ImageIDs(), FluxImageUnit_EmbeddedGuidanceEmbedder(), FluxImageUnit_IPAdapter(), + FluxImageUnit_EntityControl(), ] self.model_fn = model_fn_flux_image def load_lora(self, module, path, alpha=1): - pass + loader = FluxLoRALoader(torch_dtype=self.torch_dtype, device=self.device) + lora = load_state_dict(path, torch_dtype=self.torch_dtype, device=self.device) + loader.load(module, lora, alpha=alpha) def training_loss(self, **inputs): @@ -321,6 +325,55 @@ class FluxImageUnit_IPAdapter(PipelineUnit): return inputs_shared, inputs_posi, inputs_nega +class FluxImageUnit_EntityControl(PipelineUnit): + def __init__(self): + super().__init__( + take_over=True, + onload_model_names=("text_encoder_1", "text_encoder_2") + ) + + def preprocess_masks(self, pipe, masks, height, width, dim): + out_masks = [] + for mask in masks: + mask = pipe.preprocess_image(mask.resize((width, height), resample=Image.NEAREST)).mean(dim=1, keepdim=True) > 0 + mask = mask.repeat(1, dim, 1, 1).to(device=pipe.device, dtype=pipe.torch_dtype) + out_masks.append(mask) + return out_masks + + def prepare_entity_inputs(self, pipe, entity_prompts, entity_masks, width, height, t5_sequence_length=512): + entity_masks = self.preprocess_masks(pipe, entity_masks, height//8, width//8, 1) + entity_masks = torch.cat(entity_masks, dim=0).unsqueeze(0) # b, n_mask, c, h, w + + prompt_emb, _, _ = pipe.prompter.encode_prompt( + entity_prompts, device=pipe.device, t5_sequence_length=t5_sequence_length + ) + return prompt_emb.unsqueeze(0), entity_masks + + def prepare_eligen(self, pipe, prompt_emb_nega, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length, enable_eligen_on_negative, cfg_scale): + entity_prompt_emb_posi, entity_masks_posi = self.prepare_entity_inputs(pipe, eligen_entity_prompts, eligen_entity_masks, width, height, t5_sequence_length) + if enable_eligen_on_negative and cfg_scale != 1.0: + entity_prompt_emb_nega = prompt_emb_nega['prompt_emb'].unsqueeze(1).repeat(1, entity_masks_posi.shape[1], 1, 1) + entity_masks_nega = entity_masks_posi + else: + entity_prompt_emb_nega, entity_masks_nega = None, None + eligen_kwargs_posi = {"entity_prompt_emb": entity_prompt_emb_posi, "entity_masks": entity_masks_posi} + eligen_kwargs_nega = {"entity_prompt_emb": entity_prompt_emb_nega, "entity_masks": entity_masks_nega} + return eligen_kwargs_posi, eligen_kwargs_nega + + def process(self, pipe: FluxImagePipeline, inputs_shared, inputs_posi, inputs_nega): + eligen_entity_prompts, eligen_entity_masks = inputs_shared.get("eligen_entity_prompts", None), inputs_shared.get("eligen_entity_masks", None) + if eligen_entity_prompts is None or eligen_entity_masks is None: + return inputs_shared, inputs_posi, inputs_nega + pipe.load_models_to_device(self.onload_model_names) + eligen_kwargs_posi, eligen_kwargs_nega = self.prepare_eligen(pipe, inputs_nega, + eligen_entity_prompts, eligen_entity_masks, inputs_shared["width"], inputs_shared["height"], + inputs_shared["t5_sequence_length"], inputs_shared["eligen_enable_on_negative"], inputs_shared["cfg_scale"]) + inputs_posi.update(eligen_kwargs_posi) + if inputs_shared.get("cfg_scale", 1.0) != 1.0: + inputs_nega.update(eligen_kwargs_nega) + return inputs_shared, inputs_posi, inputs_nega + + class TeaCache: def __init__(self, num_inference_steps, rel_l1_thresh): self.num_inference_steps = num_inference_steps diff --git a/examples/flux/flux_entity_control.py b/examples/flux/flux_entity_control.py new file mode 100644 index 0000000..380cd07 --- /dev/null +++ b/examples/flux/flux_entity_control.py @@ -0,0 +1,147 @@ +import random +import torch +from PIL import Image, ImageDraw, ImageFont +from diffsynth import download_customized_models +from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig +from modelscope import dataset_snapshot_download + + +def visualize_masks(image, masks, mask_prompts, output_path, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + + # Alpha composite the overlay with this mask + overlay = Image.alpha_composite(overlay, mask_rgba) + + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + + # Save or display the resulting image + result.save(output_path) + + return result + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"eligen_example_{example_id}_mask_{seed}.png") + + +pipe = FluxImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/"), + ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"), + ], +) + +download_from_modelscope = True +if download_from_modelscope: + model_id = "DiffSynth-Studio/Eligen" + downloading_priority = ["ModelScope"] +else: + model_id = "modelscope/EliGen" + downloading_priority = ["HuggingFace"] +EliGen_path = download_customized_models( + model_id=model_id, + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control", + downloading_priority=downloading_priority)[0] +pipe.load_lora(pipe.dit, EliGen_path, alpha=1) + +# example 1 +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n" +entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"] +example(pipe, [0], 1, global_prompt, entity_prompts) + +# example 2 +global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render." +entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"] +example(pipe, [0], 2, global_prompt, entity_prompts) + +# example 3 +global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning," +entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"] +example(pipe, [27], 3, global_prompt, entity_prompts) + +# example 4 +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +example(pipe, [21], 4, global_prompt, entity_prompts) + +# example 5 +global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere." +entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"] +example(pipe, [0], 5, global_prompt, entity_prompts) + +# example 6 +global_prompt = "Snow White and the 6 Dwarfs." +entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"] +example(pipe, [8], 6, global_prompt, entity_prompts) + +# example 7, same prompt with different seeds +seeds = range(5, 9) +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts)