import math

import torch


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
    computation_device=None,
    align_dtype_to_timestep=False,
):
    """Build sinusoidal timestep embeddings, following the DDPM / diffusers convention."""
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0,
        end=half_dim,
        dtype=torch.float32,
        device=timesteps.device if computation_device is None else computation_device,
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    if align_dtype_to_timestep:
        emb = emb.to(timesteps.dtype)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


class TemporalTimesteps(torch.nn.Module):
    """Thin module wrapper around `get_timestep_embedding`."""

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool,
        downscale_freq_shift: float,
        computation_device=None,
        scale=1,
        align_dtype_to_timestep=False,
    ):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.computation_device = computation_device
        self.scale = scale
        self.align_dtype_to_timestep = align_dtype_to_timestep

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            computation_device=self.computation_device,
            scale=self.scale,
            align_dtype_to_timestep=self.align_dtype_to_timestep,
        )
        return t_emb


class DiffusersCompatibleTimestepProj(torch.nn.Module):
    """Two-layer SiLU MLP whose parameter names match the diffusers timestep embedder."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear_1 = torch.nn.Linear(dim_in, dim_out)
        self.act = torch.nn.SiLU()
        self.linear_2 = torch.nn.Linear(dim_out, dim_out)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        return x


class TimestepEmbeddings(torch.nn.Module):
    """Sinusoidal projection followed by an MLP, with an optional binary additional condition."""

    def __init__(
        self,
        dim_in,
        dim_out,
        computation_device=None,
        diffusers_compatible_format=False,
        scale=1,
        align_dtype_to_timestep=False,
        use_additional_t_cond=False,
    ):
        super().__init__()
        self.time_proj = TemporalTimesteps(
            num_channels=dim_in,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
            computation_device=computation_device,
            scale=scale,
            align_dtype_to_timestep=align_dtype_to_timestep,
        )
        if diffusers_compatible_format:
            self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
        else:
            self.timestep_embedder = torch.nn.Sequential(
                torch.nn.Linear(dim_in, dim_out),
                torch.nn.SiLU(),
                torch.nn.Linear(dim_out, dim_out),
            )
        self.use_additional_t_cond = use_additional_t_cond
        if use_additional_t_cond:
            self.addition_t_embedding = torch.nn.Embedding(2, dim_out)

    def forward(self, timestep, dtype, addition_t_cond=None):
        time_emb = self.time_proj(timestep).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        if addition_t_cond is not None:
            addition_t_emb = self.addition_t_embedding(addition_t_cond)
            addition_t_emb = addition_t_emb.to(dtype=dtype)
            time_emb = time_emb + addition_t_emb
        return time_emb


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with an optional learned per-channel scale."""

    def __init__(self, dim, eps, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.ones((dim,)))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # compute the mean square in float32 for numerical stability
        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        hidden_states = hidden_states.to(input_dtype)
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        return hidden_states


class AdaLayerNorm(torch.nn.Module):
    """Adaptive LayerNorm: predicts shift/scale/gate parameters from a conditioning embedding.

    The projection width depends on the variant: 2 * dim (single), 9 * dim (dual),
    or 6 * dim (default).
    """

    def __init__(self, dim, single=False, dual=False):
        super().__init__()
        self.single = single
        self.dual = dual
        # [[6, 2][single], 9][dual] -> 2 if single, 9 if dual, otherwise 6
        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        elif self.dual:
            (
                shift_msa, scale_msa, gate_msa,
                shift_mlp, scale_mlp, gate_mlp,
                shift_msa2, scale_msa2, gate_msa2,
            ) = emb.unsqueeze(1).chunk(9, dim=2)
            norm_x = self.norm(x)
            x = norm_x * (1 + scale_msa) + shift_msa
            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
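

# Minimal smoke test, added as an illustrative sketch of how these modules are typically
# wired together in a DiT-style diffusion model. The dimensions below (256 -> 1024, 77 tokens)
# are arbitrary examples, not values taken from the original module.
if __name__ == "__main__":
    timesteps = torch.tensor([0, 250, 500, 999])

    # Sinusoidal projection + MLP: (4,) integer timesteps -> (4, 1024) conditioning vectors.
    embedder = TimestepEmbeddings(dim_in=256, dim_out=1024)
    t_emb = embedder(timesteps, dtype=torch.float32)
    print(t_emb.shape)  # torch.Size([4, 1024])

    # The conditioning vector then modulates hidden states through AdaLayerNorm (default variant).
    hidden_states = torch.randn(4, 77, 1024)
    ada_norm = AdaLayerNorm(1024)
    modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_norm(hidden_states, t_emb)
    print(modulated.shape)  # torch.Size([4, 77, 1024])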