diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py index c0bb673..f70f3cb 100644 --- a/diffsynth/configs/model_config.py +++ b/diffsynth/configs/model_config.py @@ -98,6 +98,7 @@ model_loader_configs = [ (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"), (None, "3394f306c4cbf04334b712bf5aaed95f", ["flux_dit"], [FluxDiT], "civitai"), (None, "023f054d918a84ccf503481fd1e3379e", ["flux_dit"], [FluxDiT], "civitai"), + (None, "d02f41c13549fa5093d3521f62a5570a", ["flux_dit"], [FluxDiT], "civitai"), (None, "605c56eab23e9e2af863ad8f0813a25d", ["flux_dit"], [FluxDiT], "diffusers"), (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"), (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"), diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py index 6d3100d..31dde9c 100644 --- a/diffsynth/models/flux_dit.py +++ b/diffsynth/models/flux_dit.py @@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module): class FluxDiT(torch.nn.Module): - def __init__(self, disable_guidance_embedder=False): + def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19): super().__init__() self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56]) self.time_embedder = TimestepEmbeddings(256, 3072) self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072) self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072)) self.context_embedder = torch.nn.Linear(4096, 3072) - self.x_embedder = torch.nn.Linear(64, 3072) + self.x_embedder = torch.nn.Linear(input_dim, 3072) - self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)]) + self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)]) self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) self.final_norm_out = AdaLayerNormContinuous(3072) self.final_proj_out = torch.nn.Linear(3072, 64) + + self.input_dim = input_dim def patchify(self, hidden_states): @@ -738,5 +740,7 @@ class FluxDiTStateDictConverter: pass if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_: return state_dict_, {"disable_guidance_embedder": True} + elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_: + return state_dict_, {"input_dim": 196, "num_blocks": 8} else: return state_dict_ diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index c0729fc..647ff2d 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -360,6 +360,27 @@ class FluxImagePipeline(BasePipeline): return self.infinityou_processor.prepare_infinite_you(self.image_proj_model, id_image, controlnet_image, infinityou_guidance, height, width) else: return {}, controlnet_image + + + def prepare_flex_kwargs(self, latents, flex_inpaint_image=None, flex_inpaint_mask=None, flex_control_image=None): + if self.dit.input_dim == 196: + if flex_inpaint_image is None: + flex_inpaint_image = torch.zeros_like(latents) + else: + pass # TODO + if flex_inpaint_mask is None: + flex_inpaint_mask = torch.ones_like(latents)[:, 0:1, :, :] + else: + pass # TODO + if flex_control_image is None: + flex_control_image = torch.zeros_like(latents) + else: + pass # TODO + flex_condition = torch.concat([flex_inpaint_image, flex_inpaint_mask, flex_control_image], dim=1) + flex_kwargs = {"flex_condition": flex_condition} + else: + flex_kwargs = {} + return flex_kwargs @torch.no_grad() @@ -398,6 +419,10 @@ class FluxImagePipeline(BasePipeline): # InfiniteYou infinityou_id_image=None, infinityou_guidance=1.0, + # Flex + flex_inpaint_image=None, + flex_inpaint_mask=None, + flex_control_image=None, # TeaCache tea_cache_l1_thresh=None, # Tile @@ -436,6 +461,9 @@ class FluxImagePipeline(BasePipeline): # ControlNets controlnet_kwargs_posi, controlnet_kwargs_nega, local_controlnet_kwargs = self.prepare_controlnet(controlnet_image, masks, controlnet_inpaint_mask, tiler_kwargs, enable_controlnet_on_negative) + + # Flex + flex_kwargs = self.prepare_flex_kwargs(latents, flex_inpaint_image, flex_inpaint_mask, flex_control_image) # TeaCache tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} @@ -449,7 +477,7 @@ class FluxImagePipeline(BasePipeline): inference_callback = lambda prompt_emb_posi, controlnet_kwargs: lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs + **prompt_emb_posi, **tiler_kwargs, **extra_input, **controlnet_kwargs, **ipadapter_kwargs_list_posi, **eligen_kwargs_posi, **tea_cache_kwargs, **infiniteyou_kwargs, **flex_kwargs, ) noise_pred_posi = self.control_noise_via_local_prompts( prompt_emb_posi, prompt_emb_locals, masks, mask_scales, inference_callback, @@ -466,7 +494,7 @@ class FluxImagePipeline(BasePipeline): noise_pred_nega = lets_dance_flux( dit=self.dit, controlnet=self.controlnet, hidden_states=latents, timestep=timestep, - **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, + **prompt_emb_nega, **tiler_kwargs, **extra_input, **controlnet_kwargs_nega, **ipadapter_kwargs_list_nega, **eligen_kwargs_nega, **infiniteyou_kwargs, **flex_kwargs, ) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: @@ -602,6 +630,7 @@ def lets_dance_flux( ipadapter_kwargs_list={}, id_emb=None, infinityou_guidance=None, + flex_condition=None, tea_cache: TeaCache = None, **kwargs ): @@ -652,6 +681,9 @@ def lets_dance_flux( controlnet_res_stack, controlnet_single_res_stack = controlnet( controlnet_frames, **controlnet_extra_kwargs ) + + if flex_condition is not None: + hidden_states = torch.concat([hidden_states, flex_condition], dim=1) if image_ids is None: image_ids = dit.prepare_image_ids(hidden_states) diff --git a/examples/image_synthesis/flex_text_to_image.py b/examples/image_synthesis/flex_text_to_image.py new file mode 100644 index 0000000..288d4a6 --- /dev/null +++ b/examples/image_synthesis/flex_text_to_image.py @@ -0,0 +1,22 @@ +import torch +from diffsynth import ModelManager, FluxImagePipeline, download_models + + +download_models(["FLUX.1-dev"]) +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda") +model_manager.load_models([ + "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", + "models/FLUX/FLUX.1-dev/text_encoder_2", + "models/FLUX/FLUX.1-dev/ae.safetensors", + "models/ostris/Flex.2-preview/Flex.2-preview.safetensors" +]) +pipe = FluxImagePipeline.from_model_manager(model_manager) + +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." + +torch.manual_seed(9) +image = pipe( + prompt=prompt, + num_inference_steps=50, embedded_guidance=3.5 +) +image.save("image_1024.jpg")