commit b7a1ac6671 (parent e54c0a8468)
Author: Artiprocher
Date: 2025-04-24 14:51:40 +08:00

4 changed files with 64 additions and 5 deletions


@@ -276,20 +276,22 @@ class AdaLayerNormContinuous(torch.nn.Module):
 class FluxDiT(torch.nn.Module):
-    def __init__(self, disable_guidance_embedder=False):
+    def __init__(self, disable_guidance_embedder=False, input_dim=64, num_blocks=19):
         super().__init__()
         self.pos_embedder = RoPEEmbedding(3072, 10000, [16, 56, 56])
         self.time_embedder = TimestepEmbeddings(256, 3072)
         self.guidance_embedder = None if disable_guidance_embedder else TimestepEmbeddings(256, 3072)
         self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(768, 3072), torch.nn.SiLU(), torch.nn.Linear(3072, 3072))
         self.context_embedder = torch.nn.Linear(4096, 3072)
-        self.x_embedder = torch.nn.Linear(64, 3072)
-        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
+        self.x_embedder = torch.nn.Linear(input_dim, 3072)
+        self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(num_blocks)])
         self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
         self.final_norm_out = AdaLayerNormContinuous(3072)
         self.final_proj_out = torch.nn.Linear(3072, 64)
+        self.input_dim = input_dim

     def patchify(self, hidden_states):
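
This hunk parameterizes the input width and the number of joint (double) transformer blocks, so checkpoints whose tensor shapes differ from the stock FLUX layout can still be loaded. A standalone sketch of why `input_dim` must match the checkpoint (illustrative shapes, not code from this repo):

import torch

# The x_embedder is a Linear layer whose in_features must equal the width
# of the packed latent patches stored in the checkpoint.
x_embedder = torch.nn.Linear(196, 3072)   # variant checkpoint: 196-dim patches
patches = torch.randn(1, 1024, 196)       # (batch, tokens, input_dim)
hidden_states = x_embedder(patches)       # projected to the 3072-dim model width
print(hidden_states.shape)                # torch.Size([1, 1024, 3072])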
@@ -738,5 +740,7 @@ class FluxDiTStateDictConverter:
             pass
         if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
             return state_dict_, {"disable_guidance_embedder": True}
+        elif "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict_:
+            return state_dict_, {"input_dim": 196, "num_blocks": 8}
         else:
             return state_dict_
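
The converter infers the architecture by probing for parameter names: if block index 8 is absent, the checkpoint has only 8 double blocks and uses 196-dim inputs, so the matching `__init__` kwargs are returned alongside the state dict. A minimal standalone sketch of that detection idea (`detect_flux_config` is an illustrative name, not the repo's API):

# Mirrors the key-probing checks shown in the hunk above.
def detect_flux_config(state_dict):
    if "guidance_embedder.timestep_embedder.0.weight" not in state_dict:
        # no guidance embedder weights -> guidance-free checkpoint
        return {"disable_guidance_embedder": True}
    if "double_blocks.8.img_attn.norm.key_norm.scale" not in state_dict:
        # block index 8 missing -> the small 8-block, 196-dim variant
        return {"input_dim": 196, "num_blocks": 8}
    return {}

# Example: a fake checkpoint with guidance weights but only 8 double blocks.
fake_checkpoint = {"guidance_embedder.timestep_embedder.0.weight": None}
print(detect_flux_config(fake_checkpoint))  # {'input_dim': 196, 'num_blocks': 8}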