support sd3.5

This commit is contained in:
Artiprocher
2024-11-06 19:57:01 +08:00
parent 344cbd3286
commit 39ddb7c3e3
5 changed files with 175 additions and 514 deletions

View File

@@ -1,5 +1,5 @@
import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
from einops import rearrange
from .tiler import TileWorker
from .utils import init_weights_on_device
@@ -37,21 +37,6 @@ class RoPEEmbedding(torch.nn.Module):
class RMSNorm(torch.nn.Module):
def __init__(self, dim, eps):
super().__init__()
self.weight = torch.nn.Parameter(torch.ones((dim,)))
self.eps = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
hidden_states = hidden_states.to(input_dtype) * self.weight
return hidden_states
class FluxJointAttention(torch.nn.Module):
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False):
super().__init__()