mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
Mova (#1337)
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
This commit is contained in:
@@ -99,18 +99,30 @@ def rope_apply(x, freqs, num_heads):
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
def set_to_torch_norm(models):
|
||||
for model in models:
|
||||
for module in model.modules():
|
||||
if isinstance(module, RMSNorm):
|
||||
module.use_torch_norm = True
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.use_torch_norm = False
|
||||
self.normalized_shape = (dim,)
|
||||
|
||||
def norm(self, x):
|
||||
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
dtype = x.dtype
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
if self.use_torch_norm:
|
||||
return F.rms_norm(x, self.normalized_shape, self.weight, self.eps)
|
||||
else:
|
||||
return self.norm(x.float()).to(dtype) * self.weight
|
||||
|
||||
|
||||
class AttentionModule(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user