mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
flex t2i
This commit is contained in:
@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
|
||||
class FluxDiT(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False):
|
||||
def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
|
||||
super().__init__()
|
||||
self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
|
||||
self.time_embedder = TimestepEmbeddings(256, 3072)
|
||||
self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.x_embedder = torch.nn.Linear(64, 3072)
|
||||
self.x_embedder = torch.nn.Linear(input_dim, 3072)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
self.input_dim = input_dim
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_:
|
||||
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||
else:
|
||||
return state_dict_
|
||||
|
||||
Reference in New Issue
Block a user