Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-20 07:18:14 +00:00)

Commit: flux
@@ -1,9 +1,38 @@
-from .svd_image_encoder import SVDImageEncoder
-from .sd3_dit import RMSNorm
-from transformers import CLIPImageProcessor
+from .general_modules import RMSNorm
+from transformers import SiglipVisionModel, SiglipVisionConfig
 import torch


+class SiglipVisionModelSO400M(SiglipVisionModel):
+    def __init__(self):
+        config = SiglipVisionConfig(**{
+            "architectures": [
+                "SiglipModel"
+            ],
+            "initializer_factor": 1.0,
+            "model_type": "siglip",
+            "text_config": {
+                "hidden_size": 1152,
+                "intermediate_size": 4304,
+                "model_type": "siglip_text_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27
+            },
+            "torch_dtype": "float32",
+            "transformers_version": "4.37.0.dev0",
+            "vision_config": {
+                "hidden_size": 1152,
+                "image_size": 384,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "patch_size": 14
+            }
+        })
+        super().__init__(config)
+
+
 class MLPProjModel(torch.nn.Module):
     def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
         super().__init__()
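A quick sanity check on the vision_config above (an illustrative sketch, not part of this commit): a 384x384 input split into 14x14 patches gives a 27x27 token grid, which is the SigLIP SO400M geometry this encoder mirrors.

    # Illustrative only: derive the token count from the config values
    # ("image_size": 384, "patch_size": 14, "hidden_size": 1152).
    image_size, patch_size, hidden_size = 384, 14, 1152
    num_patches = (image_size // patch_size) ** 2  # SigLIP ViTs use no class token
    print(num_patches, hidden_size)  # 729 1152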
@@ -106,7 +106,7 @@ class TileWorker:
         return model_output


-class Attention(torch.nn.Module):
+class ConvAttention(torch.nn.Module):

     def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
         super().__init__()
@@ -115,10 +115,10 @@ class Attention(torch.nn.Module):
         self.num_heads = num_heads
         self.head_dim = head_dim

-        self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
-        self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
-        self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
-        self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
+        self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
+        self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
+        self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        if encoder_hidden_states is None:
@@ -126,9 +126,14 @@ class Attention(torch.nn.Module):

        batch_size = encoder_hidden_states.shape[0]

-       q = self.to_q(hidden_states)
-       k = self.to_k(encoder_hidden_states)
-       v = self.to_v(encoder_hidden_states)
+       conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+       q = self.to_q(conv_input)
+       q = rearrange(q[:, :, :, 0], "B C L -> B L C")
+       conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
+       k = self.to_k(conv_input)
+       v = self.to_v(conv_input)
+       k = rearrange(k[:, :, :, 0], "B C L -> B L C")
+       v = rearrange(v[:, :, :, 0], "B C L -> B L C")

        q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -138,7 +143,9 @@ class Attention(torch.nn.Module):
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
        hidden_states = hidden_states.to(q.dtype)

-       hidden_states = self.to_out(hidden_states)
+       conv_input = rearrange(hidden_states, "B L C -> B C L 1")
+       hidden_states = self.to_out(conv_input)
+       hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")

        return hidden_states

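The Linear-to-Conv2d swap above is numerically a no-op: a 1x1 Conv2d applied to a tensor rearranged to "B C L 1" computes the same projection as a Linear layer applied to "B L C". A self-contained sketch checking this (not part of the commit; the layer sizes are arbitrary placeholders):

    # Illustrative only: a 1x1 Conv2d with copied weights reproduces the Linear output.
    import torch
    from einops import rearrange

    torch.manual_seed(0)
    linear = torch.nn.Linear(8, 16, bias=False)
    conv = torch.nn.Conv2d(8, 16, kernel_size=(1, 1), bias=False)
    conv.weight.data = linear.weight.data[:, :, None, None].clone()  # (16, 8) -> (16, 8, 1, 1)

    x = torch.randn(2, 5, 8)                                  # B L C
    y_linear = linear(x)                                      # B L C'
    y_conv = conv(rearrange(x, "B L C -> B C L 1"))           # B C' L 1
    y_conv = rearrange(y_conv[:, :, :, 0], "B C L -> B L C")  # back to B L C'
    print(torch.allclose(y_linear, y_conv, atol=1e-6))        # True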
@@ -152,7 +159,7 @@ class VAEAttentionBlock(torch.nn.Module):
        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)

        self.transformer_blocks = torch.nn.ModuleList([
-           Attention(
+           ConvAttention(
                inner_dim,
                num_attention_heads,
                attention_head_dim,
@@ -236,7 +243,7 @@ class DownSampler(torch.nn.Module):
        return hidden_states, time_emb, text_emb, res_stack


-class SD3VAEDecoder(torch.nn.Module):
+class FluxVAEDecoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.3611
@@ -308,7 +315,7 @@ class SD3VAEDecoder(torch.nn.Module):
        return hidden_states


-class SD3VAEEncoder(torch.nn.Module):
+class FluxVAEEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scaling_factor = 0.3611
@@ -1,10 +1,12 @@
 import torch
-from diffsynth.models.svd_unet import TemporalTimesteps
+from .general_modules import TemporalTimesteps


 class MultiValueEncoder(torch.nn.Module):
     def __init__(self, encoders=()):
         super().__init__()
+        if not isinstance(encoders, list):
+            encoders = [encoders]
         self.encoders = torch.nn.ModuleList(encoders)

     def __call__(self, values, dtype):
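The two added lines let MultiValueEncoder accept either a single encoder module or a list of them. A hypothetical usage sketch (not from the repository; the Linear modules are placeholders, and the class above is assumed to be importable):

    # Illustrative only: both call styles end up in a torch.nn.ModuleList.
    import torch

    single = MultiValueEncoder(encoders=torch.nn.Linear(4, 8))   # wrapped into a one-element list
    several = MultiValueEncoder(encoders=[torch.nn.Linear(4, 8), torch.nn.Linear(4, 8)])
    print(len(single.encoders), len(several.encoders))           # 1 2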