import torch
from diffsynth import SDTextEncoder
from diffsynth.models.sd3_text_encoder import SD3TextEncoder1StateDictConverter
from diffsynth.models.sd_text_encoder import CLIPEncoderLayer


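# A learnable probe of shape (1, L, dim_in): pushed through one layer's low-rank
# update (lora_A, lora_B), it summarizes that LoRA layer as L feature vectors.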
class LoRALayerBlock(torch.nn.Module):
    def __init__(self, L, dim_in):
        super().__init__()
        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))

    def forward(self, lora_A, lora_B):
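        # lora_A: (rank, dim_in), lora_B: (dim_out, rank) -> out: (1, L, dim_out)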
        out = self.x @ lora_A.T @ lora_B.T
        return out


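# Turns a LoRA state dict into a token sequence: every adapted layer is probed by
# its own LoRALayerBlock, and layers of the same type share a linear projection
# to out_dim, so each pattern contributes L tokens of size out_dim.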
class LoRAEmbedder(torch.nn.Module):
    def __init__(self, lora_patterns=None, L=1, out_dim=2048):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()

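        # One probe block per adapted layer; "." is escaped because ModuleDict
        # keys may not contain dots.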
        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"][0]
            model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim)
        self.model_dict = torch.nn.ModuleDict(model_dict)

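        # One projection to out_dim per layer type, shared across blocks.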
        proj_dict = {}
        for lora_pattern in lora_patterns:
            layer_type, dim = lora_pattern["type"], lora_pattern["dim"][1]
            if layer_type.replace(".", "___") not in proj_dict:
                proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim, out_dim)
        self.proj_dict = torch.nn.ModuleDict(proj_dict)

        self.lora_patterns = lora_patterns

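    # Default pattern list: 19 "blocks.{i}" entries with 10 adapted layers each
    # and 38 "single_blocks.{i}" entries with 3 adapted layers each (304 patterns
    # in total). "dim" stores (in_features, out_features) of each adapted linear
    # layer; the shapes correspond to the FLUX-style DiT used elsewhere in
    # DiffSynth-Studio.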
    def default_lora_patterns(self):
        lora_patterns = []
        lora_dict = {
            "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
            "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
        }
        for i in range(19):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
|
|
for i in range(38):
|
|
for suffix in lora_dict:
|
|
lora_patterns.append({
|
|
"name": f"single_blocks.{i}.{suffix}",
|
|
"dim": lora_dict[suffix],
|
|
"type": suffix,
|
|
})
|
|
        return lora_patterns

    def forward(self, lora):
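        # Expects keys of the form "<name>.lora_A.default.weight" /
        # "<name>.lora_B.default.weight"; returns (1, len(lora_patterns) * L, out_dim).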
        lora_emb = []
        for lora_pattern in self.lora_patterns:
            name, layer_type = lora_pattern["name"], lora_pattern["type"]
            lora_A = lora[name + ".lora_A.default.weight"]
            lora_B = lora[name + ".lora_B.default.weight"]
            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
            lora_emb.append(lora_out)
        lora_emb = torch.concat(lora_emb, dim=1)
        return lora_emb


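# CLIP-style causal text encoder (768-dim, 12 layers, 77 positions); its weights
# are loaded via SD3TextEncoder1StateDictConverter (see state_dict_converter below).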
class TextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
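        # Causal mask: -inf above the diagonal, 0 on and below it.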
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
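        # clip_skip=n stops n-1 layers before the last encoder layer; only the
        # pooled embedding (hidden state at the end-of-text token position, found
        # via argmax over the token ids) is returned.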
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        embeds = self.final_layer_norm(embeds)
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        return pooled_embeds

    @staticmethod
    def state_dict_converter():
        return SD3TextEncoder1StateDictConverter()


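# Stacks LoRAEmbedder with a small CLIP-style transformer and mean-pools the
# output, producing a single embed_dim vector for an entire LoRA state dict.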
class LoRAEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, max_position_embeddings=304, num_encoder_layers=2, encoder_intermediate_size=3072, L=1):
        super().__init__()
        max_position_embeddings *= L

        # Embedder
        self.embedder = LoRAEmbedder(L=L, out_dim=embed_dim)

        # position_embeds (This is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, lora):
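        # Encode the LoRA tokens, run them through the transformer, then mean-pool
        # over the sequence dimension to get a single (batch, embed_dim) vector.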
        embeds = self.embedder(lora) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
        embeds = self.final_layer_norm(embeds)
        embeds = embeds.mean(dim=1)
        return embeds
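

# Minimal smoke-test sketch, not part of the original module: it assumes a LoRA
# state dict whose keys follow the "<name>.lora_A.default.weight" convention used
# above, fills it with random matrices matching the default patterns (the rank
# value is an arbitrary illustrative assumption), and encodes it.
if __name__ == "__main__":
    encoder = LoRAEncoder(L=1)
    rank = 4  # illustrative LoRA rank, not prescribed by this module
    dummy_lora = {}
    for pattern in encoder.embedder.lora_patterns:
        dim_in, dim_out = pattern["dim"]
        dummy_lora[pattern["name"] + ".lora_A.default.weight"] = torch.randn(rank, dim_in)
        dummy_lora[pattern["name"] + ".lora_B.default.weight"] = torch.randn(dim_out, rank)
    with torch.no_grad():
        lora_embedding = encoder(dummy_lora)
    print(lora_embedding.shape)  # expected: torch.Size([1, 768])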