From b7a1ac66716f3a562fe2e7bae4414e699770b1ec Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Apr 2025 14:51:40 +0800 Subject: [PATCH 1/4] flex t2i --- diffsynth/configs/model_config.py | 1 + diffsynth/models/flux_dit.py | 10 ++++-- diffsynth/pipelines/flux_image.py | 36 +++++++++++++++++-- .../image_synthesis/flex_text_to_image.py | 22 ++++++++++++ 4 files changed, 64 insertions(+), 5 deletions(-) create mode 100644 examples/image_synthesis/flex_text_to_image.py diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index c0bb673..f70f3cb 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -98,6 +98,7 @@ model_loader_configs = [ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"), (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"), (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"), + (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"), (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"), (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6d3100d..31dde9c 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module): class FluxDiT(torch.nn.Module): - def __init__(self, disable_guidance_embedder=False): + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): super().__init__() self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) self.time_embedder = TimestepEmbeddings(256, 3072) self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) self.context_embedder = torch.nn.Linear(4096, 3072) - self.x_embedder = torch.nn.Linear(64, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) - self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)]) + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) self.final_norm_out = AdaLayerNormContinuous(3072) self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = input_dim def patchify(self, hidden_states): @@ -738,5 +740,7 @@ class FluxDiTStateDictConverter: pass if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_: return state_dict_, {"disable_guidance_embedder": True} + elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_: + return state_dict_, {"input_dim": 196, "num_blocks": 8} else: return state_dict_ diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index c0729fc..647ff2d 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -360,6 +360,27 @@ class FluxImagePipeline(BasePipeline): return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width) else: return {}, controlnet_image + + + def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None): + if self.dit.input_dim == 196: + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + pass # TODO + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + pass # TODO + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + pass # TODO + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_kwargs = {"flex_condition": flex_condition} + else: + flex_kwargs = {} + return flex_kwargs @torch.no_grad() @@ -398,6 +419,10 @@ class FluxImagePipeline(BasePipeline): # InfiniteYou infinityou_id_image=None, infinityou_guidance=1.0, + # Flex + flex_inpaint_image=None, + flex_inpaint_mask=None, + flex_control_image=None, # TeaCache tea_cache_l1_thresh=None, # Tile @@ -436,6 +461,9 @@ class FluxImagePipeline(BasePipeline): # ControlNets controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) + + # Flex + flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image) # TeaCache tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} @@ -449,7 +477,7 @@ class FluxImagePipeline(BasePipeline): inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, @@ -466,7 +494,7 @@ class FluxImagePipeline(BasePipeline): noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -602,6 +630,7 @@ def lets_dance_flux( ipadapter_kwargs_list={}, id_emb=None, infinityou_guidance=None, + flex_condition=None, tea_cache: TeaCache = None, **kwargs ): @@ -652,6 +681,9 @@ def lets_dance_flux( controlnet_res_stack, controlnet_single_res_stack = controlnet( controlnet_frames, **controlnet_extra_kwargs ) + + if flex_condition is not None: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) if image_ids is None: image_ids = dit.prepare_image_ids(hidden_states) diff --git a/examples/image_synthesis/flex_text_to_image.py b/examples/image_synthesis/flex_text_to_image.py new file mode 100644 index 0000000..288d4a6 --- /dev/null +++ b/examples/image_synthesis/flex_text_to_image.py @@ -0,0 +1,22 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_models + + +download_models(["FLUX.1-dev"]) +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/ostris/Flex.2-preview/Flex.2-preview.safetensors" +]) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." + +torch.manual_seed(9) +image = pipe( + prompt=prompt, + num_inference_steps=50, embedded_guidance=3.5 +) +image.save("image_1024.jpg") From ccf24c363fcac755ffb24ede784af2f437b9c9cf Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 24 Apr 2025 19:18:54 +0800 Subject: [PATCH 2/4] flex control --- diffsynth/pipelines/flux_image.py | 14 +++++--- .../image_synthesis/flex_text_to_image.py | 34 +++++++++++++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 647ff2d..5fa807d 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -362,20 +362,24 @@ class FluxImagePipeline(BasePipeline): return {}, controlnet_image - def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None): + def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32): if self.dit.input_dim == 196: if flex_inpaint_image is None: flex_inpaint_image = torch.zeros_like(latents) else: - pass # TODO + flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype) + flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) if flex_inpaint_mask is None: flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] else: - pass # TODO + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 if flex_control_image is None: flex_control_image = torch.zeros_like(latents) else: - pass # TODO + flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype) + flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) flex_kwargs = {"flex_condition": flex_condition} else: @@ -463,7 +467,7 @@ class FluxImagePipeline(BasePipeline): controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) # Flex - flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image) + flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs) # TeaCache tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} diff --git a/examples/image_synthesis/flex_text_to_image.py b/examples/image_synthesis/flex_text_to_image.py index 288d4a6..5a56923 100644 --- a/examples/image_synthesis/flex_text_to_image.py +++ b/examples/image_synthesis/flex_text_to_image.py @@ -1,5 +1,8 @@ import torch from diffsynth import ModelManager, FluxImagePipeline, download_models +from diffsynth.controlnets.processors import Annotator +import numpy as np +from PIL import Image download_models(["FLUX.1-dev"]) @@ -12,11 +15,30 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) -prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." - -torch.manual_seed(9) image = pipe( - prompt=prompt, - num_inference_steps=50, embedded_guidance=3.5 + prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + seed=0 ) -image.save("image_1024.jpg") +image.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[0:300, 300:800] = 255 +mask = Image.fromarray(mask) +mask.save("image_mask.jpg") + +inpaint_image = np.array(image) +inpaint_image[0:300, 300:800] = 0 +inpaint_image = Image.fromarray(inpaint_image) +inpaint_image.save("image_inpaint.jpg") + +control_image = Annotator("canny")(image) +control_image.save("image_control.jpg") + +image = pipe( + prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_control_image=control_image, + seed=4 +) +image.save("image_2.jpg") From 419ace37f3429ea99d12d562ea060fe6f04c9156 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 25 Apr 2025 11:32:13 +0800 Subject: [PATCH 3/4] flex full support --- diffsynth/pipelines/flux_image.py | 19 +++++++++++++++---- .../image_synthesis/flex_text_to_image.py | 19 ++++++++++++------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 5fa807d..c17e182 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -362,7 +362,7 @@ class FluxImagePipeline(BasePipeline): return {}, controlnet_image - def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, tiled=False, tile_size=64, tile_stride=32): + def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32): if self.dit.input_dim == 196: if flex_inpaint_image is None: flex_inpaint_image = torch.zeros_like(latents) @@ -375,13 +375,16 @@ class FluxImagePipeline(BasePipeline): flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype) flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) if flex_control_image is None: flex_control_image = torch.zeros_like(latents) else: flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype) - flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) - flex_kwargs = {"flex_condition": flex_condition} + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))] + flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep} else: flex_kwargs = {} return flex_kwargs @@ -427,6 +430,8 @@ class FluxImagePipeline(BasePipeline): flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, + flex_control_strength=0.5, + flex_control_stop=0.5, # TeaCache tea_cache_l1_thresh=None, # Tile @@ -635,6 +640,8 @@ def lets_dance_flux( id_emb=None, infinityou_guidance=None, flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, tea_cache: TeaCache = None, **kwargs ): @@ -686,8 +693,12 @@ def lets_dance_flux( controlnet_frames, **controlnet_extra_kwargs ) + # Flex if flex_condition is not None: - hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) if image_ids is None: image_ids = dit.prepare_image_ids(hidden_states) diff --git a/examples/image_synthesis/flex_text_to_image.py b/examples/image_synthesis/flex_text_to_image.py index 5a56923..3770764 100644 --- a/examples/image_synthesis/flex_text_to_image.py +++ b/examples/image_synthesis/flex_text_to_image.py @@ -23,22 +23,27 @@ image = pipe( image.save("image_1.jpg") mask = np.zeros((1024, 1024, 3), dtype=np.uint8) -mask[0:300, 300:800] = 255 +mask[200:400, 400:700] = 255 mask = Image.fromarray(mask) mask.save("image_mask.jpg") -inpaint_image = np.array(image) -inpaint_image[0:300, 300:800] = 0 -inpaint_image = Image.fromarray(inpaint_image) -inpaint_image.save("image_inpaint.jpg") +inpaint_image = image + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask, + seed=4 +) +image.save("image_2.jpg") control_image = Annotator("canny")(image) control_image.save("image_control.jpg") image = pipe( - prompt="portrait of a beautiful Asian girl, long hair, yellow t-shirt, sunshine, beach", + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach", num_inference_steps=50, embedded_guidance=3.5, flex_control_image=control_image, seed=4 ) -image.save("image_2.jpg") +image.save("image_3.jpg") From cc6306136cd129ee04635edf40cbbe730d5bce86 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 25 Apr 2025 12:23:29 +0800 Subject: [PATCH 4/4] flex full support --- diffsynth/models/flux_dit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 31dde9c..ea5ce21 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -740,7 +740,7 @@ class FluxDiTStateDictConverter: pass if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_: return state_dict_, {"disable_guidance_embedder": True} - elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_: + elif "blocks.8.attn.norm_k_a.weight" not in state_dict_: return state_dict_, {"input_dim": 196, "num_blocks": 8} else: return state_dict_