support flux-schnell

This commit is contained in:
Artiprocher
2024-09-14 11:16:03 +08:00
parent c21ed1e478
commit cde1f81df6
2 changed files with 10 additions and 7 deletions

View File

@@ -76,6 +76,7 @@ model_loader_configs = [
(None, "1aafa3cc91716fb6b300cc1cd51b85a3", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "diffusers"),
(None, "21ea55f476dfc4fd135587abb59dfe5d", ["flux_vae_encoder", "flux_vae_decoder"], [FluxVAEEncoder, FluxVAEDecoder], "civitai"),
(None, "a29710fea6dddb0314663ee823598e50", ["flux_dit"], [FluxDiT], "civitai"),
(None, "57b02550baab820169365b3ee3afa2c9", ["flux_dit"], [FluxDiT], "civitai"),
(None, "280189ee084bca10f70907bf6ce1649d", ["cog_vae_encoder", "cog_vae_decoder"], [CogVAEEncoder, CogVAEDecoder], "diffusers"),
(None, "9b9313d104ac4df27991352fec013fd4", ["rife"], [IFNet], "civitai"),
(None, "6b7116078c4170bfbeaedc8fe71f6649", ["esrgan"], [RRDBNet], "civitai"),

View File

@@ -281,11 +281,11 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
def __init__(self):
def __init__(self, disable_guidance_embedder=False):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
@@ -362,9 +362,9 @@ class FluxDiT(torch.nn.Module):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
@@ -589,5 +589,7 @@ class FluxDiTStateDictConverter:
state_dict_[rename] = param
else:
pass
return state_dict_
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
return state_dict_, {"disable_guidance_embedder": True}
else:
return state_dict_