update tensor parallel

This commit is contained in:
mi804
2025-03-25 12:38:17 +08:00
parent 3dc28f428f
commit 6d405b669c
2 changed files with 49 additions and 17 deletions

View File

@@ -183,6 +183,13 @@ class CrossAttention(nn.Module):
return self.o(x)
class GateModule(nn.Module):
def __init__(self,):
super().__init__()
def forward(self, x, gate, residual):
return x + gate * residual
class DiTBlock(nn.Module):
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6):
super().__init__()
@@ -199,16 +206,17 @@ class DiTBlock(nn.Module):
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
approximate='tanh'), nn.Linear(ffn_dim, dim))
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
self.gate = GateModule()
def forward(self, x, context, t_mod, freqs):
# msa: multi-head self-attention mlp: multi-layer perceptron
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
input_x = modulate(self.norm1(x), shift_msa, scale_msa)
x = x + gate_msa * self.self_attn(input_x, freqs)
x = self.gate(x, gate_msa, self.self_attn(input_x, freqs))
x = x + self.cross_attn(self.norm3(x), context)
input_x = modulate(self.norm2(x), shift_mlp, scale_mlp)
x = x + gate_mlp * self.ffn(input_x)
x = self.gate(x, gate_mlp, self.ffn(input_x))
return x