|
|
|
|
@@ -62,10 +62,11 @@ class TimestepEmbeddings(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AdaLayerNorm(torch.nn.Module):
    """Adaptive LayerNorm: a parameter-free LayerNorm whose scale/shift (and
    gates) are produced from a conditioning embedding by a single Linear layer.

    Modes:
        single=True: embedding is split into 2 chunks (scale, shift).
        dual=True:   embedding is split into 9 chunks (two attention paths + MLP).
        default:     embedding is split into 6 chunks (one attention path + MLP).
    """

    def __init__(self, dim, single=False, dual=False):
        super().__init__()
        self.single = single
        self.dual = dual
        # 9 chunks for dual-attention blocks, 2 for the "single" (final) block,
        # 6 for a standard joint block. Written as an explicit conditional
        # instead of the original bool-indexed nested list, which computed the
        # same values but was hard to read.
        num_chunks = 9 if dual else (2 if single else 6)
        self.linear = torch.nn.Linear(dim, dim * num_chunks)
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
|
|
|
def forward(self, x, emb):
|
|
|
|
|
@@ -74,6 +75,12 @@ class AdaLayerNorm(torch.nn.Module):
|
|
|
|
|
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
|
|
|
|
x = self.norm(x) * (1 + scale) + shift
|
|
|
|
|
return x
|
|
|
|
|
elif self.dual:
|
|
|
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
|
|
|
|
norm_x = self.norm(x)
|
|
|
|
|
x = norm_x * (1 + scale_msa) + shift_msa
|
|
|
|
|
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
|
|
|
|
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
|
|
|
|
else:
|
|
|
|
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
|
|
|
|
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
|
|
|
|
@@ -138,16 +145,58 @@ class JointAttention(torch.nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
hidden_states_b = self.b_to_out(hidden_states_b)
|
|
|
|
|
return hidden_states_a, hidden_states_b
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SingleAttention(torch.nn.Module):
    """Multi-head self-attention over a single token stream.

    Used as the second attention path of a dual transformer block (the joint
    blocks attend over two streams; this one attends over stream A only).
    """

    def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
        self.a_to_out = torch.nn.Linear(dim_a, dim_a)
        if use_rms_norm:
            # Optional per-head QK normalization (RMSNorm is defined elsewhere
            # in this file).
            self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
            self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
        else:
            self.norm_q_a = None
            self.norm_k_a = None

    def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
        """Project to fused QKV, split per head, optionally QK-normalize.

        Returns q, k, v, each shaped (batch, num_heads, seq_len, head_dim).
        """
        batch_size = hidden_states.shape[0]
        qkv = to_qkv(hidden_states)
        # (batch, seq, 3 * heads * head_dim) -> (batch, 3 * heads, seq, head_dim)
        qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = qkv.chunk(3, dim=1)
        if norm_q is not None:
            q = norm_q(q)
        if norm_k is not None:
            k = norm_k(k)
        return q, k, v

    def forward(self, hidden_states_a):
        batch_size = hidden_states_a.shape[0]
        q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
        hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        # Match the query dtype before the output projection (SDPA backends may
        # compute in a different precision).
        hidden_states = hidden_states.to(q.dtype)
        hidden_states = self.a_to_out(hidden_states)
        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DualTransformerBlock(torch.nn.Module):
|
|
|
|
|
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.norm1_a = AdaLayerNorm(dim)
|
|
|
|
|
self.norm1_a = AdaLayerNorm(dim, dual=True)
|
|
|
|
|
self.norm1_b = AdaLayerNorm(dim)
|
|
|
|
|
|
|
|
|
|
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
|
|
|
|
self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
|
|
|
|
|
|
|
|
|
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
|
|
|
|
self.ff_a = torch.nn.Sequential(
|
|
|
|
|
@@ -165,7 +214,7 @@ class JointTransformerBlock(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, hidden_states_a, hidden_states_b, temb):
|
|
|
|
|
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
|
|
|
|
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
|
|
|
|
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
|
|
|
|
|
|
|
|
|
# Attention
|
|
|
|
|
@@ -173,6 +222,58 @@ class JointTransformerBlock(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
# Part A
|
|
|
|
|
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
|
|
|
|
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
|
|
|
|
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
|
|
|
|
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
|
|
|
|
|
|
|
|
|
# Part B
|
|
|
|
|
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
|
|
|
|
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
|
|
|
|
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
|
|
|
|
|
|
|
|
|
return hidden_states_a, hidden_states_b
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class JointTransformerBlock(torch.nn.Module):
    """Joint transformer block over two token streams (A: image, B: text).

    With dual=True, stream A gets a second, independent self-attention path
    (SingleAttention) gated by an extra set of AdaLayerNorm outputs.
    """

    def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
        super().__init__()
        self.norm1_a = AdaLayerNorm(dim, dual=dual)
        self.norm1_b = AdaLayerNorm(dim)

        self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
        if dual:
            self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)

        self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_a = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim * 4, dim)
        )

        self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_b = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.GELU(approximate="tanh"),
            torch.nn.Linear(dim * 4, dim)
        )

    def forward(self, hidden_states_a, hidden_states_b, temb):
        # AdaLayerNorm returns 7 values in dual mode (extra normed states and
        # gate for the second attention path), 5 otherwise.
        if self.norm1_a.dual:
            norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
        else:
            norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
        norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)

        # Attention
        attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)

        # Part A
        hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
        if self.norm1_a.dual:
            hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
        norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
        hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)

        # Part B
        # NOTE(review): this tail falls outside the visible diff hunk; it is
        # reconstructed from the identical unchanged Part-B context shown
        # earlier in the patch — confirm against the full file.
        hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
        norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
        hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)

        return hidden_states_a, hidden_states_b
|
|
|
|
|
|
|
|
|
|
@@ -218,13 +319,14 @@ class JointTransformerFinalBlock(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SD3DiT(torch.nn.Module):
    """SD3 diffusion transformer.

    Stacks num_layers joint transformer blocks: the first num_dual_blocks use
    the dual-attention variant, the last block is the final (single-mode)
    block, and the remainder are standard joint blocks.
    """

    def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
        super().__init__()
        self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
        self.time_embedder = TimestepEmbeddings(256, embed_dim)
        self.pooled_text_embedder = torch.nn.Sequential(
            torch.nn.Linear(2048, embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(embed_dim, embed_dim)
        )
        self.context_embedder = torch.nn.Linear(4096, embed_dim)
        # Dual blocks first, then standard joint blocks, then the final block.
        # embed_dim // 64 attention heads (64-dim heads throughout).
        self.blocks = torch.nn.ModuleList(
            [JointTransformerBlock(embed_dim, embed_dim // 64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
            + [JointTransformerBlock(embed_dim, embed_dim // 64, use_rms_norm=use_rms_norm) for _ in range(num_layers - 1 - num_dual_blocks)]
            + [JointTransformerFinalBlock(embed_dim, embed_dim // 64, use_rms_norm=use_rms_norm)]
        )
        self.norm_out = AdaLayerNorm(embed_dim, single=True)
        self.proj_out = torch.nn.Linear(embed_dim, 64)
|
|
|
|
|
@@ -286,7 +388,17 @@ class SD3DiTStateDictConverter:
|
|
|
|
|
while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
|
|
|
|
|
num_layers -= 1
|
|
|
|
|
use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
|
|
|
|
|
return {"embed_dim": embed_dim, "num_layers": num_layers, "use_rms_norm": use_rms_norm}
|
|
|
|
|
num_dual_blocks = 0
|
|
|
|
|
while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
|
|
|
|
|
num_dual_blocks += 1
|
|
|
|
|
pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
|
|
|
|
|
return {
|
|
|
|
|
"embed_dim": embed_dim,
|
|
|
|
|
"num_layers": num_layers,
|
|
|
|
|
"use_rms_norm": use_rms_norm,
|
|
|
|
|
"num_dual_blocks": num_dual_blocks,
|
|
|
|
|
"pos_embed_max_size": pos_embed_max_size
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def from_diffusers(self, state_dict):
|
|
|
|
|
rename_dict = {
|
|
|
|
|
@@ -402,13 +514,21 @@ class SD3DiTStateDictConverter:
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
|
|
|
|
|
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
|
|
|
|
|
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
|
|
|
|
|
})
|
|
|
|
|
state_dict_ = {}
|
|
|
|
|
for name in state_dict:
|
|
|
|
|
if name in rename_dict:
|
|
|
|
|
param = state_dict[name]
|
|
|
|
|
if name == "model.diffusion_model.pos_embed":
|
|
|
|
|
param = param.reshape((1, 192, 192, param.shape[-1]))
|
|
|
|
|
pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
|
|
|
|
|
param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
|
|
|
|
|
if isinstance(rename_dict[name], str):
|
|
|
|
|
state_dict_[rename_dict[name]] = param
|
|
|
|
|
else:
|
|
|
|
|
|