diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index c0bb673..f70f3cb 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -98,6 +98,7 @@ model_loader_configs = [ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"), (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"), (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"), + (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"), (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"), (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6d3100d..ea5ce21 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module): class FluxDiT(torch.nn.Module): - def __init__(self, disable_guidance_embedder=False): + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): super().__init__() self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) self.time_embedder = TimestepEmbeddings(256, 3072) self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) self.context_embedder = torch.nn.Linear(4096, 3072) - self.x_embedder = torch.nn.Linear(64, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) - self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)]) + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) self.final_norm_out = AdaLayerNormContinuous(3072) self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = input_dim def patchify(self, hidden_states): @@ -738,5 +740,7 @@ class FluxDiTStateDictConverter: pass if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_: return state_dict_, {"disable_guidance_embedder": True} + elif "blocks.8.attn.norm_k_a.weight" not in state_dict_: + return state_dict_, {"input_dim": 196, "num_blocks": 8} else: return state_dict_ diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index c0729fc..c17e182 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -360,6 +360,34 @@ class FluxImagePipeline(BasePipeline): return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width) else: return {}, controlnet_image + + + def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None, flex_control_strength=0.5, flex_control_stop=0.5, tiled=False, tile_size=64, tile_stride=32): + if self.dit.input_dim == 196: + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + flex_inpaint_image = self.preprocess_image(flex_inpaint_image).to(device=self.device, dtype=self.torch_dtype) + flex_inpaint_image = self.encode_image(flex_inpaint_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + flex_inpaint_mask = flex_inpaint_mask.resize((latents.shape[3], latents.shape[2])) + flex_inpaint_mask = self.preprocess_image(flex_inpaint_mask).to(device=self.device, dtype=self.torch_dtype) + flex_inpaint_mask = (flex_inpaint_mask[:, 0:1, :, :] + 1) / 2 + flex_inpaint_image = flex_inpaint_image * (1 - flex_inpaint_mask) + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + flex_control_image = self.preprocess_image(flex_control_image).to(device=self.device, dtype=self.torch_dtype) + flex_control_image = self.encode_image(flex_control_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) * flex_control_strength + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_uncondition = torch.concat([flex_inpaint_image, flex_inpaint_mask, torch.zeros_like(flex_control_image)], dim=1) + flex_control_stop_timestep = self.scheduler.timesteps[int(flex_control_stop * (len(self.scheduler.timesteps) - 1))] + flex_kwargs = {"flex_condition": flex_condition, "flex_uncondition": flex_uncondition, "flex_control_stop_timestep": flex_control_stop_timestep} + else: + flex_kwargs = {} + return flex_kwargs @torch.no_grad() @@ -398,6 +426,12 @@ class FluxImagePipeline(BasePipeline): # InfiniteYou infinityou_id_image=None, infinityou_guidance=1.0, + # Flex + flex_inpaint_image=None, + flex_inpaint_mask=None, + flex_control_image=None, + flex_control_strength=0.5, + flex_control_stop=0.5, # TeaCache tea_cache_l1_thresh=None, # Tile @@ -436,6 +470,9 @@ class FluxImagePipeline(BasePipeline): # ControlNets controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) + + # Flex + flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image, **tiler_kwargs) # TeaCache tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} @@ -449,7 +486,7 @@ class FluxImagePipeline(BasePipeline): inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, @@ -466,7 +503,7 @@ class FluxImagePipeline(BasePipeline): noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -602,6 +639,9 @@ def lets_dance_flux( ipadapter_kwargs_list={}, id_emb=None, infinityou_guidance=None, + flex_condition=None, + flex_uncondition=None, + flex_control_stop_timestep=None, tea_cache: TeaCache = None, **kwargs ): @@ -652,6 +692,13 @@ def lets_dance_flux( controlnet_res_stack, controlnet_single_res_stack = controlnet( controlnet_frames, **controlnet_extra_kwargs ) + + # Flex + if flex_condition is not None: + if timestep.tolist()[0] >= flex_control_stop_timestep: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) + else: + hidden_states = torch.concat([hidden_states, flex_uncondition], dim=1) if image_ids is None: image_ids = dit.prepare_image_ids(hidden_states) diff --git a/examples/image_synthesis/flex_text_to_image.py b/examples/image_synthesis/flex_text_to_image.py new file mode 100644 index 0000000..3770764 --- /dev/null +++ b/examples/image_synthesis/flex_text_to_image.py @@ -0,0 +1,49 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_models +from diffsynth.controlnets.processors import Annotator +import numpy as np +from PIL import Image + + +download_models(["FLUX.1-dev"]) +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/ostris/Flex.2-preview/Flex.2-preview.safetensors" +]) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +image = pipe( + prompt="portrait of a beautiful Asian girl, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + seed=0 +) +image.save("image_1.jpg") + +mask = np.zeros((1024, 1024, 3), dtype=np.uint8) +mask[200:400, 400:700] = 255 +mask = Image.fromarray(mask) +mask.save("image_mask.jpg") + +inpaint_image = image + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, red t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_inpaint_image=inpaint_image, flex_inpaint_mask=mask, + seed=4 +) +image.save("image_2.jpg") + +control_image = Annotator("canny")(image) +control_image.save("image_control.jpg") + +image = pipe( + prompt="portrait of a beautiful Asian girl with sunglasses, long hair, yellow t-shirt, sunshine, beach", + num_inference_steps=50, embedded_guidance=3.5, + flex_control_image=control_image, + seed=4 +) +image.save("image_3.jpg")