mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
support flux ipadapter
This commit is contained in:
@@ -6,16 +6,21 @@ from .tiler import TileWorker
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps):
|
||||
def __init__(self, dim, eps, elementwise_affine=True):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * self.weight
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.weight is not None:
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user