mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-25 02:38:10 +00:00
support flux training
This commit is contained in:
@@ -218,9 +218,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
self.dim = dim
|
||||
|
||||
self.norm = AdaLayerNormSingle(dim)
|
||||
# self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh"))
|
||||
# self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
|
||||
self.linear = torch.nn.Linear(dim, dim * (3 + 4))
|
||||
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
|
||||
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
|
||||
|
||||
@@ -253,7 +251,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
|
||||
residual = hidden_states_a
|
||||
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
|
||||
hidden_states_a = self.linear(norm_hidden_states)
|
||||
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
|
||||
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
|
||||
|
||||
attn_output = self.process_attention(attn_output, image_rotary_emb)
|
||||
@@ -295,8 +293,8 @@ class FluxDiT(torch.nn.Module):
|
||||
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
|
||||
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(3072)
|
||||
self.proj_out = torch.nn.Linear(3072, 64)
|
||||
self.final_norm_out = AdaLayerNormContinuous(3072)
|
||||
self.final_proj_out = torch.nn.Linear(3072, 64)
|
||||
|
||||
|
||||
def patchify(self, hidden_states):
|
||||
@@ -350,6 +348,7 @@ class FluxDiT(torch.nn.Module):
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
|
||||
tiled=False, tile_size=128, tile_stride=64,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
@@ -373,16 +372,35 @@ class FluxDiT(torch.nn.Module):
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.final_proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
@@ -399,7 +417,7 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
@@ -408,9 +426,11 @@ class FluxDiTStateDictConverter:
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "norm_out.linear",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
@@ -442,13 +462,11 @@ class FluxDiTStateDictConverter:
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in rename_dict:
|
||||
state_dict_[rename_dict[prefix] + suffix] = param
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
@@ -469,7 +487,7 @@ class FluxDiTStateDictConverter:
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if ".proj_in_besides_attn." in name:
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".linear.")
|
||||
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||
@@ -508,16 +526,16 @@ class FluxDiTStateDictConverter:
|
||||
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
||||
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
||||
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
||||
"final_layer.linear.bias": "proj_out.bias",
|
||||
"final_layer.linear.weight": "proj_out.weight",
|
||||
"final_layer.linear.bias": "final_proj_out.bias",
|
||||
"final_layer.linear.weight": "final_proj_out.weight",
|
||||
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
||||
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
||||
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
||||
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
||||
"img_in.bias": "x_embedder.bias",
|
||||
"img_in.weight": "x_embedder.weight",
|
||||
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
||||
@@ -545,8 +563,8 @@ class FluxDiTStateDictConverter:
|
||||
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
||||
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
||||
|
||||
"linear1.bias": "linear.bias",
|
||||
"linear1.weight": "linear.weight",
|
||||
"linear1.bias": "to_qkv_mlp.bias",
|
||||
"linear1.weight": "to_qkv_mlp.weight",
|
||||
"linear2.bias": "proj_out.bias",
|
||||
"linear2.weight": "proj_out.weight",
|
||||
"modulation.lin.bias": "norm.linear.bias",
|
||||
|
||||
Reference in New Issue
Block a user