diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index 2184338..ad9c24b 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -76,6 +76,7 @@ model_loader_configs = [
     (None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
     (None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
     (None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
+    (None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
     (None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
     (None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
     (None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),
diff --git a/diffsynth/models/flux_dit.py b/diffsynth/models/flux_dit.py
index 080806c..3355f2e 100644
--- a/diffsynth/models/flux_dit.py
+++ b/diffsynth/models/flux_dit.py
@@ -281,11 +281,11 @@ class AdaLayerNormContinuous(torch.nn.Module):
 
 
 class FluxDiT(torch.nn.Module):
-    def __init__(self):
+    def __init__(self, disable_guidance_embedder=False):
         super().__init__()
         self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
         self.time_embedder = TimestepEmbeddings(256, 3072)
-        self.guidance_embedder = TimestepEmbeddings(256, 3072)
+        self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
         self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
         self.context_embedder = torch.nn.Linear(4096, 3072)
         self.x_embedder = torch.nn.Linear(64, 3072)
@@ -362,9 +362,9 @@
         if image_ids is None:
             image_ids = self.prepare_image_ids(hidden_states)
 
-        conditioning = self.time_embedder(timestep, hidden_states.dtype)\
-            + self.guidance_embedder(guidance, hidden_states.dtype)\
-            + self.pooled_text_embedder(pooled_prompt_emb)
+        conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
+        if self.guidance_embedder is not None:
+            conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
         prompt_emb = self.context_embedder(prompt_emb)
         image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
 
@@ -589,5 +589,7 @@ class FluxDiTStateDictConverter:
                 state_dict_[rename] = param
             else:
                 pass
-        return state_dict_
-        
\ No newline at end of file
+        if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
+            return state_dict_, {"disable_guidance_embedder": True}
+        else:
+            return state_dict_