support flux ipadapter

This commit is contained in:
root
2024-11-26 18:08:50 +08:00
parent 5fc9e53eec
commit 4f40683fd8
6 changed files with 133 additions and 19 deletions

View File

@@ -6,16 +6,21 @@ from .tiler import TileWorker
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
def __init__(self, dim, eps, elementwise_affine=True):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
if elementwise_affine:
self.weight = torch.nn.Parameter(torch.ones((dim,)))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
hidden_states = hidden_states.to(input_dtype)
if self.weight is not None:
hidden_states = hidden_states * self.weight
return hidden_states