Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
DiffSynth-Studio 2.0 major update
diffsynth/models/general_modules.py (new file, 139 lines)
@@ -0,0 +1,139 @@
import torch, math


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
    computation_device = None,
    align_dtype_to_timestep = False,
):
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent).to(timesteps.device)
    if align_dtype_to_timestep:
        emb = emb.to(timesteps.dtype)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

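For orientation, a minimal sketch of what get_timestep_embedding produces; the batch of timesteps and the 320-dim width below are illustrative values, not taken from the commit:

# hypothetical usage sketch, not part of the diff
ts = torch.tensor([0, 250, 500, 999])
emb = get_timestep_embedding(ts, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0)
print(emb.shape)  # torch.Size([4, 320]); cosine half comes first because flip_sin_to_cos=True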
class TemporalTimesteps(torch.nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.computation_device = computation_device
        self.scale = scale
        self.align_dtype_to_timestep = align_dtype_to_timestep

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            computation_device=self.computation_device,
            scale=self.scale,
            align_dtype_to_timestep=self.align_dtype_to_timestep,
        )
        return t_emb


class DiffusersCompatibleTimestepProj(torch.nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.linear_1 = torch.nn.Linear(dim_in, dim_out)
        self.act = torch.nn.SiLU()
        self.linear_2 = torch.nn.Linear(dim_out, dim_out)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.act(x)
        x = self.linear_2(x)
        return x


class TimestepEmbeddings(torch.nn.Module):
    def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
        super().__init__()
        self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
        if diffusers_compatible_format:
            self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
        else:
            self.timestep_embedder = torch.nn.Sequential(
                torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
            )

    def forward(self, timestep, dtype):
        time_emb = self.time_proj(timestep).to(dtype)
        time_emb = self.timestep_embedder(time_emb)
        return time_emb

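Note that both projection variants compute the same Linear -> SiLU -> Linear MLP; diffusers_compatible_format only changes the submodule names (linear_1/linear_2 instead of an anonymous Sequential), presumably so that state-dict keys line up with diffusers-style checkpoints. A hypothetical end-to-end sketch with illustrative dimensions, not part of the commit:

# hypothetical usage sketch, not part of the diff
embedder = TimestepEmbeddings(dim_in=320, dim_out=1280)
t = torch.tensor([500])                  # one denoising timestep
cond = embedder(t, dtype=torch.float32)  # sinusoidal features, then the MLP
print(cond.shape)                        # torch.Size([1, 1280])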
class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps, elementwise_affine=True):
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = torch.nn.Parameter(torch.ones((dim,)))
        else:
            self.weight = None

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        hidden_states = hidden_states.to(input_dtype)
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        return hidden_states

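As defined above, RMSNorm rescales by the root mean square of the last dimension, computed in float32 for stability (y = x / sqrt(mean(x^2) + eps) * weight), with no mean subtraction as in LayerNorm. A hypothetical sanity check, not part of the commit:

# hypothetical sanity check, not part of the diff
norm = RMSNorm(dim=4, eps=1e-6)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = norm(x)
print(y.square().mean(-1))  # ~1.0: unit RMS after normalization (weight is all ones)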
class AdaLayerNorm(torch.nn.Module):
    def __init__(self, dim, single=False, dual=False):
        super().__init__()
        self.single = single
        self.dual = dual
        self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
        self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        emb = self.linear(torch.nn.functional.silu(emb))
        if self.single:
            scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
            x = self.norm(x) * (1 + scale) + shift
            return x
        elif self.dual:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
            norm_x = self.norm(x)
            x = norm_x * (1 + scale_msa) + shift_msa
            norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
        else:
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
            x = self.norm(x) * (1 + scale_msa) + shift_msa
            return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
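The expression [[6, 2][single], 9][dual] indexes lists with booleans (False is 0, True is 1): the linear layer emits dim*9 modulation parameters when dual=True, dim*2 when single=True, and dim*6 otherwise (the usual DiT shift/scale/gate pairs for attention and MLP). A hypothetical sketch of the default 6-way mode, not part of the commit:

# hypothetical usage sketch, not part of the diff
adaln = AdaLayerNorm(dim=64)        # default: 6-way modulation
x = torch.randn(1, 16, 64)          # (batch, tokens, dim)
emb = torch.randn(1, 64)            # per-sample conditioning, e.g. a timestep embedding
x_mod, gate_msa, shift_mlp, scale_mlp, gate_mlp = adaln(x, emb)
print(x_mod.shape, gate_msa.shape)  # torch.Size([1, 16, 64]) torch.Size([1, 1, 64])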