support flux training

This commit is contained in:
Artiprocher
2024-09-06 10:37:28 +08:00
parent 4654aa0cab
commit 416b73b8c0
5 changed files with 218 additions and 46 deletions

View File

@@ -218,9 +218,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
self.dim = dim
self.norm = AdaLayerNormSingle(dim)
# self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh"))
# self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.linear = torch.nn.Linear(dim, dim * (3 + 4))
self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
@@ -253,7 +251,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.linear(norm_hidden_states)
hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb)
@@ -295,8 +293,8 @@ class FluxDiT(torch.nn.Module):
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.norm_out = AdaLayerNormContinuous(3072)
self.proj_out = torch.nn.Linear(3072, 64)
self.final_norm_out = AdaLayerNormContinuous(3072)
self.final_proj_out = torch.nn.Linear(3072, 64)
def patchify(self, hidden_states):
@@ -350,6 +348,7 @@ class FluxDiT(torch.nn.Module):
hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64,
use_gradient_checkpointing=False,
**kwargs
):
if tiled:
@@ -373,16 +372,35 @@ class FluxDiT(torch.nn.Module):
hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, conditioning)
hidden_states = self.proj_out(hidden_states)
hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states
@@ -399,7 +417,7 @@ class FluxDiTStateDictConverter:
pass
def from_diffusers(self, state_dict):
rename_dict = {
global_rename_dict = {
"context_embedder": "context_embedder",
"x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
@@ -408,9 +426,11 @@ class FluxDiTStateDictConverter:
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "norm_out.linear",
"norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out",
"norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q",
@@ -442,13 +462,11 @@ class FluxDiTStateDictConverter:
}
state_dict_ = {}
for name, param in state_dict.items():
if name in rename_dict:
state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"):
if name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)]
if prefix in rename_dict:
state_dict_[rename_dict[prefix] + suffix] = param
if prefix in global_rename_dict:
state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."):
names = prefix.split(".")
names[0] = "blocks"
@@ -469,7 +487,7 @@ class FluxDiTStateDictConverter:
pass
for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".linear.")
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
@@ -508,16 +526,16 @@ class FluxDiTStateDictConverter:
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "proj_out.bias",
"final_layer.linear.weight": "proj_out.weight",
"final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
}
suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
@@ -545,8 +563,8 @@ class FluxDiTStateDictConverter:
"txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "linear.bias",
"linear1.weight": "linear.weight",
"linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias",