support flux-schnell

This commit is contained in:
Artiprocher
2024-09-14 11:16:03 +08:00
parent c21ed1e478
commit cde1f81df6
2 changed files with 10 additions and 7 deletions

View File

@@ -281,11 +281,11 @@ class AdaLayerNormContinuous(torch.nn.Module):
class FluxDiT(torch.nn.Module):
def __init__(self):
def __init__(self, disable_guidance_embedder=False):
super().__init__()
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
self.time_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = TimestepEmbeddings(256, 3072)
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
self.context_embedder = torch.nn.Linear(4096, 3072)
self.x_embedder = torch.nn.Linear(64, 3072)
@@ -362,9 +362,9 @@ class FluxDiT(torch.nn.Module):
if image_ids is None:
image_ids = self.prepare_image_ids(hidden_states)
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
+ self.guidance_embedder(guidance, hidden_states.dtype)\
+ self.pooled_text_embedder(pooled_prompt_emb)
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
if self.guidance_embedder is not None:
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
prompt_emb = self.context_embedder(prompt_emb)
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
@@ -589,5 +589,7 @@ class FluxDiTStateDictConverter:
state_dict_[rename] = param
else:
pass
return state_dict_
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
return state_dict_, {"disable_guidance_embedder": True}
else:
return state_dict_