mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:39:43 +00:00
support flux-schnell
This commit is contained in:
@@ -281,11 +281,11 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
@@ -362,9 +362,9 @@ class FluxDiT(torch.nn.Module):
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype)\
|
||||
+ self.guidance_embedder(guidance, hidden_states.dtype)\
|
||||
+ self.pooled_text_embedder(pooled_prompt_emb)
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
|
||||
@@ -589,5 +589,7 @@ class FluxDiTStateDictConverter:
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
pass
|
||||
return state_dict_
|
||||
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
Reference in New Issue
Block a user