From e576d71908dac3d509cdc646fba3c7380ebb6a45 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 5 Mar 2025 11:20:10 +0800 Subject: [PATCH] support dreambooth lora --- diffsynth/models/flux_dit.py | 23 ++--- examples/EntityControl/README.md | 5 ++ .../EntityControl/styled_entity_control.py | 90 +++++++++++++++++++ 3 files changed, 108 insertions(+), 10 deletions(-) create mode 100644 examples/EntityControl/styled_entity_control.py diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 7a01478..6d3100d 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -628,19 +628,22 @@ class FluxDiTStateDictConverter: else: pass for name in list(state_dict_.keys()): - if ".proj_in_besides_attn." in name: - name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + if "single_blocks." in name and ".a_to_q." in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + mlp = torch.zeros(4 * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) param = torch.concat([ - state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], - state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], - state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")], - state_dict_[name], + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") state_dict_[name_] = param - state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q.")) - state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k.")) - state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v.")) - state_dict_.pop(name) for name in list(state_dict_.keys()): for component in ["a", "b"]: if f".{component}_to_q." in name: diff --git a/examples/EntityControl/README.md b/examples/EntityControl/README.md index 2e0530f..79e5d56 100644 --- a/examples/EntityControl/README.md +++ b/examples/EntityControl/README.md @@ -77,6 +77,11 @@ Demonstration of the styled entity control results with EliGen and IP-Adapter, s |-|-|-|-| |![image_1_base](https://github.com/user-attachments/assets/5e2dd3ab-37d3-4f58-8e02-ee2f9b238604)|![result1](https://github.com/user-attachments/assets/0f6711a2-572a-41b3-938a-95deff6d732d)|![result2](https://github.com/user-attachments/assets/ce2e66e5-1fdf-44e8-bca7-555d805a50b1)|![result3](https://github.com/user-attachments/assets/ad2da233-2f7c-4065-ab57-b2d84dc2c0e2)| +We also provide a demo of the styled entity control results with EliGen and specific styled lora, see [./styled_entity_control.py](./styled_entity_control.py) for details. Here is the visualization of EliGen with [Lego dreambooth lora](https://huggingface.co/merve/flux-lego-lora-dreambooth). +|![image_1_base](https://github.com/user-attachments/assets/35fb60f5-48ef-4f22-95d8-f9e732a5f63f)|![result1](https://github.com/user-attachments/assets/441d700f-f0b1-40e0-8848-4db23520972c)|![result2](https://github.com/user-attachments/assets/c8fd4498-3c55-48ab-9abf-3a092a90c878)|![result3](https://github.com/user-attachments/assets/181ba2bb-62cf-41a8-9e3a-20ed8a7a672f)| +|-|-|-|-| +|![image_1_base](https://github.com/user-attachments/assets/70a3f578-8c7e-4b40-954d-8fc94d4f3ae9)|![result1](https://github.com/user-attachments/assets/65670717-6136-4594-84e5-2307fc20753d)|![result2](https://github.com/user-attachments/assets/5ec7a5bd-f2c9-4b2e-8a4e-d2655ec8036c)|![result3](https://github.com/user-attachments/assets/56f00192-9553-45a6-a971-511b9f5b1480)| + ### Entity Transfer Demonstration of the entity transfer results with EliGen and In-Context LoRA, see [./entity_transfer.py](./entity_transfer.py) for generation prompts. diff --git a/examples/EntityControl/styled_entity_control.py b/examples/EntityControl/styled_entity_control.py new file mode 100644 index 0000000..e478d04 --- /dev/null +++ b/examples/EntityControl/styled_entity_control.py @@ -0,0 +1,90 @@ +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download +from examples.EntityControl.utils import visualize_masks +from PIL import Image +import torch + +def example(pipe, seeds, example_id, global_prompt, entity_prompts): + dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/example_{example_id}/*.png") + masks = [Image.open(f"./data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + negative_prompt = "worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw," + for seed in seeds: + # generate image + image = pipe( + prompt=global_prompt, + cfg_scale=3.0, + negative_prompt=negative_prompt, + num_inference_steps=50, + embedded_guidance=3.5, + seed=seed, + height=1024, + width=1024, + eligen_entity_prompts=entity_prompts, + eligen_entity_masks=masks, + ) + image.save(f"styled_eligen_example_{example_id}_{seed}.png") + visualize_masks(image, masks, entity_prompts, f"styled_entity_control_example_{example_id}_mask_{seed}.png") + +# download and load model +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) +model_manager.load_lora( + download_customized_models( + model_id="FluxLora/merve-flux-lego-lora-dreambooth", + origin_file_path="pytorch_lora_weights.safetensors", + local_dir="models/lora/merve-flux-lego-lora-dreambooth" + ), + lora_alpha=1 +) +model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control" + ), + lora_alpha=1 +) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +# example 1 +trigger_word = "lego set in style of TOK, " +global_prompt = "A breathtaking beauty of Raja Ampat by the late-night moonlight , one beautiful woman from behind wearing a pale blue long dress with soft glow, sitting at the top of a cliff looking towards the beach,pastell light colors, a group of small distant birds flying in far sky, a boat sailing on the sea, best quality, realistic, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, maximalist style, photorealistic, concept art, sharp focus, harmony, serenity, tranquility, soft pastell colors,ambient occlusion, cozy ambient lighting, masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning, view from above\n" +global_prompt = trigger_word + global_prompt +entity_prompts = ["cliff", "sea", "moon", "sailing boat", "a seated beautiful woman", "pale blue long dress with soft glow"] +example(pipe, [0], 1, global_prompt, entity_prompts) + +# example 2 +global_prompt = "samurai girl wearing a kimono, she's holding a sword glowing with red flame, her long hair is flowing in the wind, she is looking at a small bird perched on the back of her hand. ultra realist style. maximum image detail. maximum realistic render." +global_prompt = trigger_word + global_prompt +entity_prompts = ["flowing hair", "sword glowing with red flame", "A cute bird", "blue belt"] +example(pipe, [0], 2, global_prompt, entity_prompts) + +# example 3 +global_prompt = "Image of a neverending staircase up to a mysterious palace in the sky, The ancient palace stood majestically atop a mist-shrouded mountain, sunrise, two traditional monk walk in the stair looking at the sunrise, fog,see-through, best quality, whimsical, fantastic, splash art, intricate detailed, hyperdetailed, photorealistic, concept art, harmony, serenity, tranquility, ambient occlusion, halation, cozy ambient lighting, dynamic lighting,masterpiece, liiv1, linquivera, metix, mentixis, masterpiece, award winning," +global_prompt = trigger_word + global_prompt +entity_prompts = ["ancient palace", "stone staircase with railings", "a traditional monk", "a traditional monk"] +example(pipe, [27], 3, global_prompt, entity_prompts) + +# example 4 +global_prompt = "A beautiful girl wearing shirt and shorts in the street, holding a sign 'Entity Control'" +global_prompt = trigger_word + global_prompt +entity_prompts = ["A beautiful girl", "sign 'Entity Control'", "shorts", "shirt"] +example(pipe, [21], 4, global_prompt, entity_prompts) + +# example 5 +global_prompt = "A captivating, dramatic scene in a painting that exudes mystery and foreboding. A white sky, swirling blue clouds, and a crescent yellow moon illuminate a solitary woman standing near the water's edge. Her long dress flows in the wind, silhouetted against the eerie glow. The water mirrors the fiery sky and moonlight, amplifying the uneasy atmosphere." +global_prompt = trigger_word + global_prompt +entity_prompts = ["crescent yellow moon", "a solitary woman", "water", "swirling blue clouds"] +example(pipe, [0], 5, global_prompt, entity_prompts) + +# example 6 +global_prompt = "Snow White and the 6 Dwarfs." +global_prompt = trigger_word + global_prompt +entity_prompts = ["Dwarf 1", "Dwarf 2", "Dwarf 3", "Snow White", "Dwarf 4", "Dwarf 5", "Dwarf 6"] +example(pipe, [8], 6, global_prompt, entity_prompts) + +# example 7, same prompt with different seeds +seeds = range(5, 9) +global_prompt = "A beautiful woman wearing white dress, holding a mirror, with a warm light background;" +global_prompt = trigger_word + global_prompt +entity_prompts = ["A beautiful woman", "mirror", "necklace", "glasses", "earring", "white dress", "jewelry headpiece"] +example(pipe, seeds, 7, global_prompt, entity_prompts)