mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
@@ -312,7 +312,107 @@ flux_series = [
|
|||||||
"model_hash": "0629116fce1472503a66992f96f3eb1a",
|
"model_hash": "0629116fce1472503a66992f96f3eb1a",
|
||||||
"model_name": "flux_value_controller",
|
"model_name": "flux_value_controller",
|
||||||
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
|
"model_class": "diffsynth.models.flux_value_control.SingleValueEncoder",
|
||||||
}
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "52357cb26250681367488a8954c271e8",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "78d18b9101345ff695f312e7e62538c0",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors")
|
||||||
|
"model_hash": "b001c89139b5f053c715fe772362dd2a",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_single_blocks": 0},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin")
|
||||||
|
"model_hash": "c07c0f04f5ff55e86b4e937c7a40d481",
|
||||||
|
"model_name": "infiniteyou_image_projector",
|
||||||
|
"model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors")
|
||||||
|
"model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16",
|
||||||
|
"model_name": "flux_controlnet",
|
||||||
|
"model_class": "diffsynth.models.flux_controlnet.FluxControlNet",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter",
|
||||||
|
"extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "77c2e4dd2440269eb33bfaa0d004f6ab",
|
||||||
|
"model_name": "flux_lora_encoder",
|
||||||
|
"model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "30143afb2dea73d1ac580e0787628f8c",
|
||||||
|
"model_name": "flux_lora_patcher",
|
||||||
|
"model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors")
|
||||||
|
"model_hash": "2bd19e845116e4f875a0a048e27fc219",
|
||||||
|
"model_name": "nexus_gen_llm",
|
||||||
|
"model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
||||||
|
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
||||||
|
"model_name": "nexus_gen_editing_adapter",
|
||||||
|
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin")
|
||||||
|
"model_hash": "63c969fd37cce769a90aa781fbff5f81",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
||||||
|
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
||||||
|
"model_name": "nexus_gen_generation_adapter",
|
||||||
|
"model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin")
|
||||||
|
"model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d",
|
||||||
|
"model_name": "flux_dit",
|
||||||
|
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin")
|
||||||
|
"model_hash": "4daaa66cc656a8fe369908693dad0a35",
|
||||||
|
"model_name": "flux_ipadapter",
|
||||||
|
"model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "04d8c1e20a1f1b25f7434f111992a33f",
|
||||||
|
"model_name": "siglip_vision_model",
|
||||||
|
"model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series
|
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series
|
||||||
|
|||||||
@@ -1,9 +1,62 @@
|
|||||||
import torch
|
import torch
|
||||||
from einops import rearrange, repeat
|
from einops import rearrange, repeat
|
||||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||||
from .utils import hash_state_dict_keys, init_weights_on_device
|
# from .utils import hash_state_dict_keys, init_weights_on_device
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||||
|
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||||
|
|
||||||
|
old_register_parameter = torch.nn.Module.register_parameter
|
||||||
|
if include_buffers:
|
||||||
|
old_register_buffer = torch.nn.Module.register_buffer
|
||||||
|
|
||||||
|
def register_empty_parameter(module, name, param):
|
||||||
|
old_register_parameter(module, name, param)
|
||||||
|
if param is not None:
|
||||||
|
param_cls = type(module._parameters[name])
|
||||||
|
kwargs = module._parameters[name].__dict__
|
||||||
|
kwargs["requires_grad"] = param.requires_grad
|
||||||
|
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||||
|
|
||||||
|
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||||
|
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||||
|
if buffer is not None:
|
||||||
|
module._buffers[name] = module._buffers[name].to(device)
|
||||||
|
|
||||||
|
def patch_tensor_constructor(fn):
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
kwargs["device"] = device
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
if include_buffers:
|
||||||
|
tensor_constructors_to_patch = {
|
||||||
|
torch_function_name: getattr(torch, torch_function_name)
|
||||||
|
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
tensor_constructors_to_patch = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
torch.nn.Module.register_parameter = register_empty_parameter
|
||||||
|
if include_buffers:
|
||||||
|
torch.nn.Module.register_buffer = register_empty_buffer
|
||||||
|
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||||
|
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
torch.nn.Module.register_parameter = old_register_parameter
|
||||||
|
if include_buffers:
|
||||||
|
torch.nn.Module.register_buffer = old_register_buffer
|
||||||
|
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||||
|
setattr(torch, torch_function_name, old_torch_function)
|
||||||
|
|
||||||
class FluxControlNet(torch.nn.Module):
|
class FluxControlNet(torch.nn.Module):
|
||||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||||
@@ -102,9 +155,9 @@ class FluxControlNet(torch.nn.Module):
|
|||||||
return controlnet_res_stack, controlnet_single_res_stack
|
return controlnet_res_stack, controlnet_single_res_stack
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
# @staticmethod
|
||||||
def state_dict_converter():
|
# def state_dict_converter():
|
||||||
return FluxControlNetStateDictConverter()
|
# return FluxControlNetStateDictConverter()
|
||||||
|
|
||||||
def quantize(self):
|
def quantize(self):
|
||||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||||
|
|||||||
@@ -5,34 +5,21 @@ import torch
|
|||||||
|
|
||||||
class SiglipVisionModelSO400M(SiglipVisionModel):
|
class SiglipVisionModelSO400M(SiglipVisionModel):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
config = SiglipVisionConfig(**{
|
config = SiglipVisionConfig(
|
||||||
"architectures": [
|
hidden_size=1152,
|
||||||
"SiglipModel"
|
image_size=384,
|
||||||
],
|
intermediate_size=4304,
|
||||||
"initializer_factor": 1.0,
|
model_type="siglip_vision_model",
|
||||||
"model_type": "siglip",
|
num_attention_heads=16,
|
||||||
"text_config": {
|
num_hidden_layers=27,
|
||||||
"hidden_size": 1152,
|
patch_size=14,
|
||||||
"intermediate_size": 4304,
|
architectures=["SiglipModel"],
|
||||||
"model_type": "siglip_text_model",
|
initializer_factor=1.0,
|
||||||
"num_attention_heads": 16,
|
torch_dtype="float32",
|
||||||
"num_hidden_layers": 27
|
transformers_version="4.37.0.dev0"
|
||||||
},
|
)
|
||||||
"torch_dtype": "float32",
|
|
||||||
"transformers_version": "4.37.0.dev0",
|
|
||||||
"vision_config": {
|
|
||||||
"hidden_size": 1152,
|
|
||||||
"image_size": 384,
|
|
||||||
"intermediate_size": 4304,
|
|
||||||
"model_type": "siglip_vision_model",
|
|
||||||
"num_attention_heads": 16,
|
|
||||||
"num_hidden_layers": 27,
|
|
||||||
"patch_size": 14
|
|
||||||
}
|
|
||||||
})
|
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
class MLPProjModel(torch.nn.Module):
|
class MLPProjModel(torch.nn.Module):
|
||||||
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@@ -1,5 +1,415 @@
|
|||||||
import torch
|
import torch
|
||||||
from .sd_text_encoder import CLIPEncoderLayer
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def low_version_attention(query, key, value, attn_bias=None):
|
||||||
|
scale = 1 / query.shape[-1] ** 0.5
|
||||||
|
query = query * scale
|
||||||
|
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||||
|
if attn_bias is not None:
|
||||||
|
attn = attn + attn_bias
|
||||||
|
attn = attn.softmax(-1)
|
||||||
|
return attn @ value
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
|
super().__init__()
|
||||||
|
dim_inner = head_dim * num_heads
|
||||||
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||||
|
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||||
|
|
||||||
|
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||||
|
batch_size = q.shape[0]
|
||||||
|
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||||
|
hidden_states = hidden_states + scale * ip_hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
batch_size = encoder_hidden_states.shape[0]
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if qkv_preprocessor is not None:
|
||||||
|
q, k, v = qkv_preprocessor(q, k, v)
|
||||||
|
|
||||||
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
if ipadapter_kwargs is not None:
|
||||||
|
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||||
|
else:
|
||||||
|
import xformers.ops as xops
|
||||||
|
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||||
|
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||||
|
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||||
|
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||||
|
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||||
|
|
||||||
|
self.use_quick_gelu = use_quick_gelu
|
||||||
|
|
||||||
|
def quickGELU(self, x):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attn_mask=None):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
if self.use_quick_gelu:
|
||||||
|
hidden_states = self.quickGELU(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# token_embedding
|
||||||
|
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||||
|
|
||||||
|
# position_embeds (This is a fixed tensor)
|
||||||
|
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||||
|
|
||||||
|
# encoders
|
||||||
|
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||||
|
|
||||||
|
# attn_mask
|
||||||
|
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||||
|
|
||||||
|
# final_layer_norm
|
||||||
|
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
def attention_mask(self, length):
|
||||||
|
mask = torch.empty(length, length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self, input_ids, clip_skip=1):
|
||||||
|
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||||
|
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||||
|
for encoder_id, encoder in enumerate(self.encoders):
|
||||||
|
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||||
|
if encoder_id + clip_skip == len(self.encoders):
|
||||||
|
break
|
||||||
|
embeds = self.final_layer_norm(embeds)
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return SDTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||||
|
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||||
|
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||||
|
}
|
||||||
|
attn_rename_dict = {
|
||||||
|
"self_attn.q_proj": "attn.to_q",
|
||||||
|
"self_attn.k_proj": "attn.to_k",
|
||||||
|
"self_attn.v_proj": "attn.to_v",
|
||||||
|
"self_attn.out_proj": "attn.to_out",
|
||||||
|
"layer_norm1": "layer_norm1",
|
||||||
|
"layer_norm2": "layer_norm2",
|
||||||
|
"mlp.fc1": "fc1",
|
||||||
|
"mlp.fc2": "fc2",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
elif name.startswith("text_model.encoder.layers."):
|
||||||
|
param = state_dict[name]
|
||||||
|
names = name.split(".")
|
||||||
|
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||||
|
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class LoRALayerBlock(torch.nn.Module):
|
class LoRALayerBlock(torch.nn.Module):
|
||||||
@@ -63,8 +473,8 @@ class LoRAEmbedder(torch.nn.Module):
|
|||||||
lora_emb = []
|
lora_emb = []
|
||||||
for lora_pattern in self.lora_patterns:
|
for lora_pattern in self.lora_patterns:
|
||||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||||
lora_A = lora[name + ".lora_A.default.weight"]
|
lora_A = lora[name + ".lora_A.weight"]
|
||||||
lora_B = lora[name + ".lora_B.default.weight"]
|
lora_B = lora[name + ".lora_B.weight"]
|
||||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||||
lora_emb.append(lora_out)
|
lora_emb.append(lora_out)
|
||||||
|
|||||||
306
diffsynth/models/flux_lora_patcher.py
Normal file
306
diffsynth/models/flux_lora_patcher.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
import torch, math
|
||||||
|
from ..core.loader import load_state_dict
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
class GeneralLoRALoader:
|
||||||
|
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||||
|
self.device = device
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
|
|
||||||
|
def get_name_dict(self, lora_state_dict):
|
||||||
|
lora_name_dict = {}
|
||||||
|
for key in lora_state_dict:
|
||||||
|
if ".lora_B." not in key:
|
||||||
|
continue
|
||||||
|
keys = key.split(".")
|
||||||
|
if len(keys) > keys.index("lora_B") + 2:
|
||||||
|
keys.pop(keys.index("lora_B") + 1)
|
||||||
|
keys.pop(keys.index("lora_B"))
|
||||||
|
if keys[0] == "diffusion_model":
|
||||||
|
keys.pop(0)
|
||||||
|
keys.pop(-1)
|
||||||
|
target_name = ".".join(keys)
|
||||||
|
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||||
|
return lora_name_dict
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
|
updated_num = 0
|
||||||
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if name in lora_name_dict:
|
||||||
|
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||||
|
if len(weight_up.shape) == 4:
|
||||||
|
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||||
|
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||||
|
else:
|
||||||
|
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||||
|
state_dict = module.state_dict()
|
||||||
|
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||||
|
module.load_state_dict(state_dict)
|
||||||
|
updated_num += 1
|
||||||
|
print(f"{updated_num} tensors are updated by LoRA.")
|
||||||
|
|
||||||
|
class FluxLoRALoader(GeneralLoRALoader):
|
||||||
|
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||||
|
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
self.diffusers_rename_dict = {
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||||
|
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||||
|
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||||
|
}
|
||||||
|
|
||||||
|
self.civitai_rename_dict = {
|
||||||
|
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||||
|
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||||
|
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||||
|
}
|
||||||
|
|
||||||
|
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||||
|
super().load(model, state_dict_lora, alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_state_dict(self,state_dict):
|
||||||
|
|
||||||
|
def guess_block_id(name,model_resource):
|
||||||
|
if model_resource == 'civitai':
|
||||||
|
names = name.split("_")
|
||||||
|
for i in names:
|
||||||
|
if i.isdigit():
|
||||||
|
return i, name.replace(f"_{i}_", "_blockid_")
|
||||||
|
if model_resource == 'diffusers':
|
||||||
|
names = name.split(".")
|
||||||
|
for i in names:
|
||||||
|
if i.isdigit():
|
||||||
|
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def guess_resource(state_dict):
|
||||||
|
for k in state_dict:
|
||||||
|
if "lora_unet_" in k:
|
||||||
|
return 'civitai'
|
||||||
|
elif k.startswith("transformer."):
|
||||||
|
return 'diffusers'
|
||||||
|
else:
|
||||||
|
None
|
||||||
|
|
||||||
|
model_resource = guess_resource(state_dict)
|
||||||
|
if model_resource is None:
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
||||||
|
def guess_alpha(state_dict):
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
if ".alpha" in name:
|
||||||
|
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
||||||
|
name_ = name.replace(".alpha", suffix)
|
||||||
|
if name_ in state_dict:
|
||||||
|
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||||
|
lora_alpha = math.sqrt(lora_alpha)
|
||||||
|
return lora_alpha
|
||||||
|
|
||||||
|
return 1
|
||||||
|
|
||||||
|
alpha = guess_alpha(state_dict)
|
||||||
|
|
||||||
|
state_dict_ = {}
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
block_id, source_name = guess_block_id(name,model_resource)
|
||||||
|
if alpha != 1:
|
||||||
|
param *= alpha
|
||||||
|
if source_name in rename_dict:
|
||||||
|
target_name = rename_dict[source_name]
|
||||||
|
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||||
|
state_dict_[target_name] = param
|
||||||
|
else:
|
||||||
|
state_dict_[name] = param
|
||||||
|
|
||||||
|
if model_resource == 'diffusers':
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
if "single_blocks." in name and ".a_to_q." in name:
|
||||||
|
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||||
|
if mlp is None:
|
||||||
|
dim = 4
|
||||||
|
if 'lora_A' in name:
|
||||||
|
dim = 1
|
||||||
|
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||||
|
*state_dict_[name].shape[1:],
|
||||||
|
dtype=state_dict_[name].dtype)
|
||||||
|
else:
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||||
|
if 'lora_A' in name:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_.pop(name),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||||
|
mlp,
|
||||||
|
], dim=0)
|
||||||
|
elif 'lora_B' in name:
|
||||||
|
d, r = state_dict_[name].shape
|
||||||
|
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||||
|
param[:d, :r] = state_dict_.pop(name)
|
||||||
|
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||||
|
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||||
|
param[3*d:, 3*r:] = mlp
|
||||||
|
else:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_.pop(name),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||||
|
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||||
|
mlp,
|
||||||
|
], dim=0)
|
||||||
|
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||||
|
state_dict_[name_] = param
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
for component in ["a", "b"]:
|
||||||
|
if f".{component}_to_q." in name:
|
||||||
|
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||||
|
concat_dim = 0
|
||||||
|
if 'lora_A' in name:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||||
|
], dim=0)
|
||||||
|
elif 'lora_B' in name:
|
||||||
|
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||||
|
d, r = origin.shape
|
||||||
|
# print(d, r)
|
||||||
|
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||||
|
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||||
|
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||||
|
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||||
|
else:
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||||
|
], dim=0)
|
||||||
|
state_dict_[name_] = param
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
class LoraMerger(torch.nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||||
|
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
self.activation = torch.nn.Sigmoid()
|
||||||
|
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||||
|
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||||
|
|
||||||
|
def forward(self, base_output, lora_outputs):
|
||||||
|
norm_base_output = self.norm_base(base_output)
|
||||||
|
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||||
|
gate = self.activation(
|
||||||
|
norm_base_output * self.weight_base \
|
||||||
|
+ norm_lora_outputs * self.weight_lora \
|
||||||
|
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||||
|
)
|
||||||
|
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||||
|
return output
|
||||||
|
|
||||||
|
class FluxLoraPatcher(torch.nn.Module):
|
||||||
|
def __init__(self, lora_patterns=None):
|
||||||
|
super().__init__()
|
||||||
|
if lora_patterns is None:
|
||||||
|
lora_patterns = self.default_lora_patterns()
|
||||||
|
model_dict = {}
|
||||||
|
for lora_pattern in lora_patterns:
|
||||||
|
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||||
|
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||||
|
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||||
|
|
||||||
|
def default_lora_patterns(self):
|
||||||
|
lora_patterns = []
|
||||||
|
lora_dict = {
|
||||||
|
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||||
|
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||||
|
}
|
||||||
|
for i in range(19):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix]
|
||||||
|
})
|
||||||
|
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||||
|
for i in range(38):
|
||||||
|
for suffix in lora_dict:
|
||||||
|
lora_patterns.append({
|
||||||
|
"name": f"single_blocks.{i}.{suffix}",
|
||||||
|
"dim": lora_dict[suffix]
|
||||||
|
})
|
||||||
|
return lora_patterns
|
||||||
|
|
||||||
|
def forward(self, base_output, lora_outputs, name):
|
||||||
|
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||||
161
diffsynth/models/nexus_gen.py
Normal file
161
diffsynth/models/nexus_gen.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAutoregressiveModel(torch.nn.Module):
|
||||||
|
def __init__(self, max_length=1024, max_pixels=262640):
|
||||||
|
super(NexusGenAutoregressiveModel, self).__init__()
|
||||||
|
from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration
|
||||||
|
from transformers import Qwen2_5_VLConfig
|
||||||
|
self.max_length = max_length
|
||||||
|
self.max_pixels = max_pixels
|
||||||
|
model_config = Qwen2_5_VLConfig(**{
|
||||||
|
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||||
|
"architectures": [
|
||||||
|
"Qwen2_5_VLForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"auto_map": {
|
||||||
|
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||||
|
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||||
|
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||||
|
},
|
||||||
|
"bos_token_id": 151643,
|
||||||
|
"eos_token_id": 151645,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 3584,
|
||||||
|
"image_token_id": 151655,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 18944,
|
||||||
|
"max_position_embeddings": 128000,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"num_attention_heads": 28,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 4,
|
||||||
|
"pad_token_id": 151643,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": {
|
||||||
|
"mrope_section": [
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
24
|
||||||
|
],
|
||||||
|
"rope_type": "default",
|
||||||
|
"type": "default"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"sliding_window": 32768,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.49.0",
|
||||||
|
"use_cache": False,
|
||||||
|
"use_sliding_window": False,
|
||||||
|
"video_token_id": 151656,
|
||||||
|
"vision_config": {
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"in_chans": 3,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"spatial_patch_size": 14,
|
||||||
|
"tokens_per_second": 2,
|
||||||
|
"torch_dtype": "bfloat16"
|
||||||
|
},
|
||||||
|
"vision_end_token_id": 151653,
|
||||||
|
"vision_start_token_id": 151652,
|
||||||
|
"vision_token_id": 151654,
|
||||||
|
"vocab_size": 152064
|
||||||
|
})
|
||||||
|
self.model = Qwen2_5_VLForConditionalGeneration(model_config)
|
||||||
|
self.processor = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_processor(self, path):
|
||||||
|
from .nexus_gen_ar_model import Qwen2_5_VLProcessor
|
||||||
|
self.processor = Qwen2_5_VLProcessor.from_pretrained(path)
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return NexusGenAutoregressiveModelStateDictConverter()
|
||||||
|
|
||||||
|
def bound_image(self, image, max_pixels=262640):
|
||||||
|
from qwen_vl_utils import smart_resize
|
||||||
|
resized_height, resized_width = smart_resize(
|
||||||
|
image.height,
|
||||||
|
image.width,
|
||||||
|
max_pixels=max_pixels,
|
||||||
|
)
|
||||||
|
return image.resize((resized_width, resized_height))
|
||||||
|
|
||||||
|
def get_editing_msg(self, instruction):
|
||||||
|
if '<image>' not in instruction:
|
||||||
|
instruction = '<image> ' + instruction
|
||||||
|
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: <image>"}]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def get_generation_msg(self, instruction):
|
||||||
|
instruction = "Generate an image according to the following description: {}".format(instruction)
|
||||||
|
messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: <image>"}]
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def forward(self, instruction, ref_image=None, num_img_tokens=81):
|
||||||
|
"""
|
||||||
|
Generate target embeddings for the given instruction and reference image.
|
||||||
|
"""
|
||||||
|
if ref_image is not None:
|
||||||
|
messages = self.get_editing_msg(instruction)
|
||||||
|
images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||||
|
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||||
|
else:
|
||||||
|
messages = self.get_generation_msg(instruction)
|
||||||
|
images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))]
|
||||||
|
output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens)
|
||||||
|
|
||||||
|
return output_image_embeddings
|
||||||
|
|
||||||
|
def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81):
|
||||||
|
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
|
||||||
|
text = text.replace('<image>', '<|vision_start|><|image_pad|><|vision_end|>')
|
||||||
|
inputs = processor(
|
||||||
|
text=[text],
|
||||||
|
images=images,
|
||||||
|
padding=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
inputs = inputs.to(model.device)
|
||||||
|
|
||||||
|
input_embeds = model.model.embed_tokens(inputs['input_ids'])
|
||||||
|
image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw'])
|
||||||
|
ground_truth_image_embeds = image_embeds[-num_img_tokens:]
|
||||||
|
input_image_embeds = image_embeds[:-num_img_tokens]
|
||||||
|
|
||||||
|
image_mask = inputs['input_ids'] == model.config.image_token_id
|
||||||
|
indices = image_mask.cumsum(dim=1)
|
||||||
|
input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask)
|
||||||
|
gt_image_mask = torch.logical_and(image_mask, ~input_image_mask)
|
||||||
|
input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds)
|
||||||
|
input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds)
|
||||||
|
|
||||||
|
image_prefill_embeds = model.image_prefill_embeds(
|
||||||
|
torch.arange(81, device=model.device).long()
|
||||||
|
)
|
||||||
|
input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds)
|
||||||
|
|
||||||
|
position_ids, _ = model.get_rope_index(
|
||||||
|
inputs['input_ids'],
|
||||||
|
inputs['image_grid_thw'],
|
||||||
|
attention_mask=inputs['attention_mask'])
|
||||||
|
position_ids = position_ids.contiguous()
|
||||||
|
outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True)
|
||||||
|
output_image_embeddings = outputs.image_embeddings[:, :-1, :]
|
||||||
|
output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]]
|
||||||
|
return output_image_embeddings, input_image_embeds, inputs['image_grid_thw']
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAutoregressiveModelStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
state_dict = {"model." + key: value for key, value in state_dict.items()}
|
||||||
|
return state_dict
|
||||||
1143
diffsynth/models/nexus_gen_ar_model.py
Normal file
1143
diffsynth/models/nexus_gen_ar_model.py
Normal file
File diff suppressed because it is too large
Load Diff
417
diffsynth/models/nexus_gen_projector.py
Normal file
417
diffsynth/models/nexus_gen_projector.py
Normal file
@@ -0,0 +1,417 @@
|
|||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
||||||
|
mrope_section = mrope_section * 2
|
||||||
|
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||||
|
unsqueeze_dim
|
||||||
|
)
|
||||||
|
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
||||||
|
unsqueeze_dim
|
||||||
|
)
|
||||||
|
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLRotaryEmbedding(nn.Module):
|
||||||
|
def __init__(self, config, device=None):
|
||||||
|
super().__init__()
|
||||||
|
# BC: "rope_type" was originally "type"
|
||||||
|
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
|
||||||
|
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
||||||
|
else:
|
||||||
|
self.rope_type = "default"
|
||||||
|
self.max_seq_len_cached = config.max_position_embeddings
|
||||||
|
self.original_max_seq_len = config.max_position_embeddings
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
from transformers.modeling_rope_utils import _compute_default_rope_parameters
|
||||||
|
self.rope_init_fn = _compute_default_rope_parameters
|
||||||
|
|
||||||
|
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
self.original_inv_freq = self.inv_freq
|
||||||
|
|
||||||
|
|
||||||
|
def _dynamic_frequency_update(self, position_ids, device):
|
||||||
|
"""
|
||||||
|
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
||||||
|
1 - growing beyond the cached sequence length (allow scaling)
|
||||||
|
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
||||||
|
"""
|
||||||
|
seq_len = torch.max(position_ids) + 1
|
||||||
|
if seq_len > self.max_seq_len_cached: # growth
|
||||||
|
inv_freq, self.attention_scaling = self.rope_init_fn(
|
||||||
|
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
||||||
|
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
||||||
|
self.max_seq_len_cached = self.original_max_seq_len
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def forward(self, x, position_ids):
|
||||||
|
if "dynamic" in self.rope_type:
|
||||||
|
self._dynamic_frequency_update(position_ids, device=x.device)
|
||||||
|
|
||||||
|
# Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids
|
||||||
|
# So we expand the inv_freq to shape (3, ...)
|
||||||
|
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
||||||
|
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
||||||
|
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
||||||
|
device_type = x.device.type
|
||||||
|
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
||||||
|
with torch.autocast(device_type=device_type, enabled=False):
|
||||||
|
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
cos = emb.cos()
|
||||||
|
sin = emb.sin()
|
||||||
|
|
||||||
|
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
||||||
|
cos = cos * self.attention_scaling
|
||||||
|
sin = sin * self.attention_scaling
|
||||||
|
|
||||||
|
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
|
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLAttention(nn.Module):
|
||||||
|
def __init__(self, config, layer_idx: Optional[int] = None):
|
||||||
|
super().__init__()
|
||||||
|
self.config = config
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.num_heads = config.num_attention_heads
|
||||||
|
self.head_dim = self.hidden_size // self.num_heads
|
||||||
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
|
self.is_causal = True
|
||||||
|
self.attention_dropout = config.attention_dropout
|
||||||
|
self.rope_scaling = config.rope_scaling
|
||||||
|
|
||||||
|
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
||||||
|
f" and `num_heads`: {self.num_heads})."
|
||||||
|
)
|
||||||
|
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
||||||
|
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||||
|
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
||||||
|
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
cos, sin = position_embeddings
|
||||||
|
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||||
|
|
||||||
|
# Fix precision issues in Qwen2-VL float16 inference
|
||||||
|
# Replace inf values with zeros in attention weights to prevent NaN propagation
|
||||||
|
if query_states.dtype == torch.float16:
|
||||||
|
attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights)
|
||||||
|
|
||||||
|
# upcast attention to fp32
|
||||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||||
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
|
|
||||||
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||||
|
raise ValueError(
|
||||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||||
|
f" {attn_output.size()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, -1)
|
||||||
|
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2MLP(nn.Module):
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__()
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
self.config = config
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
self.intermediate_size = config.intermediate_size
|
||||||
|
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
||||||
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||||
|
return down_proj
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2RMSNorm(nn.Module):
|
||||||
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
|
"""
|
||||||
|
Qwen2RMSNorm is equivalent to T5LayerNorm
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||||
|
self.variance_epsilon = eps
|
||||||
|
|
||||||
|
def forward(self, hidden_states):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
hidden_states = hidden_states.to(torch.float32)
|
||||||
|
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||||
|
return self.weight * hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
def extra_repr(self):
|
||||||
|
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||||
|
|
||||||
|
|
||||||
|
class Qwen2_5_VLDecoderLayer(nn.Module):
|
||||||
|
def __init__(self, config, layer_idx):
|
||||||
|
super().__init__()
|
||||||
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
|
self.self_attn = Qwen2_5_VLAttention(config, layer_idx)
|
||||||
|
|
||||||
|
self.mlp = Qwen2MLP(config)
|
||||||
|
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
position_embeddings=position_embeddings,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states = self.mlp(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenImageEmbeddingMerger(nn.Module):
|
||||||
|
def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'):
|
||||||
|
super().__init__()
|
||||||
|
from transformers import Qwen2_5_VLConfig
|
||||||
|
from transformers.activations import ACT2FN
|
||||||
|
config = Qwen2_5_VLConfig(**{
|
||||||
|
"_name_or_path": "DiffSynth-Studio/Nexus-GenV2",
|
||||||
|
"architectures": [
|
||||||
|
"Qwen2_5_VLForConditionalGeneration"
|
||||||
|
],
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"auto_map": {
|
||||||
|
"AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig",
|
||||||
|
"AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel",
|
||||||
|
"AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration"
|
||||||
|
},
|
||||||
|
"bos_token_id": 151643,
|
||||||
|
"eos_token_id": 151645,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 3584,
|
||||||
|
"image_token_id": 151655,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 18944,
|
||||||
|
"max_position_embeddings": 128000,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"num_attention_heads": 28,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 4,
|
||||||
|
"pad_token_id": 151643,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": {
|
||||||
|
"mrope_section": [
|
||||||
|
16,
|
||||||
|
24,
|
||||||
|
24
|
||||||
|
],
|
||||||
|
"rope_type": "default",
|
||||||
|
"type": "default"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000.0,
|
||||||
|
"sliding_window": 32768,
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.49.0",
|
||||||
|
"use_cache": False,
|
||||||
|
"use_sliding_window": False,
|
||||||
|
"video_token_id": 151656,
|
||||||
|
"vision_config": {
|
||||||
|
"hidden_size": 1280,
|
||||||
|
"in_chans": 3,
|
||||||
|
"model_type": "qwen2_5_vl",
|
||||||
|
"spatial_patch_size": 14,
|
||||||
|
"tokens_per_second": 2,
|
||||||
|
"torch_dtype": "bfloat16"
|
||||||
|
},
|
||||||
|
"vision_end_token_id": 151653,
|
||||||
|
"vision_start_token_id": 151652,
|
||||||
|
"vision_token_id": 151654,
|
||||||
|
"vocab_size": 152064
|
||||||
|
})
|
||||||
|
self.config = config
|
||||||
|
self.num_layers = num_layers
|
||||||
|
self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)])
|
||||||
|
self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps),
|
||||||
|
nn.Linear(config.hidden_size, out_channel * expand_ratio),
|
||||||
|
Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps),
|
||||||
|
ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel),
|
||||||
|
Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps))
|
||||||
|
self.base_grid = torch.tensor([[1, 72, 72]], device=device)
|
||||||
|
self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device)
|
||||||
|
|
||||||
|
def get_position_ids(self, image_grid_thw):
|
||||||
|
"""
|
||||||
|
Generates position ids for the input embeddings grid.
|
||||||
|
modified from the qwen2_vl mrope.
|
||||||
|
"""
|
||||||
|
batch_size = image_grid_thw.shape[0]
|
||||||
|
spatial_merge_size = self.config.vision_config.spatial_merge_size
|
||||||
|
t, h, w = (
|
||||||
|
image_grid_thw[0][0],
|
||||||
|
image_grid_thw[0][1],
|
||||||
|
image_grid_thw[0][2],
|
||||||
|
)
|
||||||
|
llm_grid_t, llm_grid_h, llm_grid_w = (
|
||||||
|
t.item(),
|
||||||
|
h.item() // spatial_merge_size,
|
||||||
|
w.item() // spatial_merge_size,
|
||||||
|
)
|
||||||
|
scale_h = self.base_grid[0][1].item() / h.item()
|
||||||
|
scale_w = self.base_grid[0][2].item() / w.item()
|
||||||
|
|
||||||
|
range_tensor = torch.arange(llm_grid_t).view(-1, 1)
|
||||||
|
expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)
|
||||||
|
time_tensor = expanded_range * self.config.vision_config.tokens_per_second
|
||||||
|
t_index = time_tensor.long().flatten().to(image_grid_thw.device)
|
||||||
|
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h
|
||||||
|
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w
|
||||||
|
# 3, B, L
|
||||||
|
position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2)
|
||||||
|
return position_ids
|
||||||
|
|
||||||
|
def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None):
|
||||||
|
position_ids = self.get_position_ids(embeds_grid)
|
||||||
|
hidden_states = embeds
|
||||||
|
if ref_embeds is not None:
|
||||||
|
position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid)
|
||||||
|
position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1)
|
||||||
|
hidden_states = torch.cat((embeds, ref_embeds), dim=1)
|
||||||
|
|
||||||
|
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||||
|
for layer in self.layers:
|
||||||
|
hidden_states = layer(hidden_states, position_embeddings)
|
||||||
|
|
||||||
|
hidden_states = self.projector(hidden_states)
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return NexusGenMergerStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenMergerStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
merger_state_dict = {key.replace("embedding_merger.", ""): value for key, value in state_dict.items() if key.startswith('embedding_merger.')}
|
||||||
|
return merger_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAdapter(nn.Module):
|
||||||
|
"""
|
||||||
|
Adapter for Nexus-Gen generation decoder.
|
||||||
|
"""
|
||||||
|
def __init__(self, input_dim=3584, output_dim=4096):
|
||||||
|
super(NexusGenAdapter, self).__init__()
|
||||||
|
self.adapter = nn.Sequential(nn.Linear(input_dim, output_dim),
|
||||||
|
nn.LayerNorm(output_dim), nn.ReLU(),
|
||||||
|
nn.Linear(output_dim, output_dim),
|
||||||
|
nn.LayerNorm(output_dim))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.adapter(x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return NexusGenAdapterStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class NexusGenAdapterStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
adapter_state_dict = {key: value for key, value in state_dict.items() if key.startswith('adapter.')}
|
||||||
|
return adapter_state_dict
|
||||||
412
diffsynth/models/sd_text_encoder.py
Normal file
412
diffsynth/models/sd_text_encoder.py
Normal file
@@ -0,0 +1,412 @@
|
|||||||
|
import torch
|
||||||
|
from .attention import Attention
|
||||||
|
from einops import rearrange
|
||||||
|
|
||||||
|
|
||||||
|
def low_version_attention(query, key, value, attn_bias=None):
|
||||||
|
scale = 1 / query.shape[-1] ** 0.5
|
||||||
|
query = query * scale
|
||||||
|
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||||
|
if attn_bias is not None:
|
||||||
|
attn = attn + attn_bias
|
||||||
|
attn = attn.softmax(-1)
|
||||||
|
return attn @ value
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(torch.nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||||
|
super().__init__()
|
||||||
|
dim_inner = head_dim * num_heads
|
||||||
|
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
|
||||||
|
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||||
|
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||||
|
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||||
|
|
||||||
|
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||||
|
batch_size = q.shape[0]
|
||||||
|
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||||
|
hidden_states = hidden_states + scale * ip_hidden_states
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
batch_size = encoder_hidden_states.shape[0]
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||||
|
|
||||||
|
if qkv_preprocessor is not None:
|
||||||
|
q, k, v = qkv_preprocessor(q, k, v)
|
||||||
|
|
||||||
|
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||||
|
if ipadapter_kwargs is not None:
|
||||||
|
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||||
|
if encoder_hidden_states is None:
|
||||||
|
encoder_hidden_states = hidden_states
|
||||||
|
|
||||||
|
q = self.to_q(hidden_states)
|
||||||
|
k = self.to_k(encoder_hidden_states)
|
||||||
|
v = self.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
|
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||||
|
else:
|
||||||
|
import xformers.ops as xops
|
||||||
|
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||||
|
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||||
|
|
||||||
|
hidden_states = hidden_states.to(q.dtype)
|
||||||
|
hidden_states = self.to_out(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||||
|
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class CLIPEncoderLayer(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||||
|
super().__init__()
|
||||||
|
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||||
|
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||||
|
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||||
|
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||||
|
|
||||||
|
self.use_quick_gelu = use_quick_gelu
|
||||||
|
|
||||||
|
def quickGELU(self, x):
|
||||||
|
return x * torch.sigmoid(1.702 * x)
|
||||||
|
|
||||||
|
def forward(self, hidden_states, attn_mask=None):
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
|
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
|
hidden_states = self.fc1(hidden_states)
|
||||||
|
if self.use_quick_gelu:
|
||||||
|
hidden_states = self.quickGELU(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||||
|
hidden_states = self.fc2(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoder(torch.nn.Module):
|
||||||
|
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# token_embedding
|
||||||
|
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||||
|
|
||||||
|
# position_embeds (This is a fixed tensor)
|
||||||
|
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||||
|
|
||||||
|
# encoders
|
||||||
|
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||||
|
|
||||||
|
# attn_mask
|
||||||
|
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||||
|
|
||||||
|
# final_layer_norm
|
||||||
|
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||||
|
|
||||||
|
def attention_mask(self, length):
|
||||||
|
mask = torch.empty(length, length)
|
||||||
|
mask.fill_(float("-inf"))
|
||||||
|
mask.triu_(1)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def forward(self, input_ids, clip_skip=1):
|
||||||
|
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||||
|
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||||
|
for encoder_id, encoder in enumerate(self.encoders):
|
||||||
|
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||||
|
if encoder_id + clip_skip == len(self.encoders):
|
||||||
|
break
|
||||||
|
embeds = self.final_layer_norm(embeds)
|
||||||
|
return embeds
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def state_dict_converter():
|
||||||
|
return SDTextEncoderStateDictConverter()
|
||||||
|
|
||||||
|
|
||||||
|
class SDTextEncoderStateDictConverter:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def from_diffusers(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||||
|
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||||
|
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||||
|
}
|
||||||
|
attn_rename_dict = {
|
||||||
|
"self_attn.q_proj": "attn.to_q",
|
||||||
|
"self_attn.k_proj": "attn.to_k",
|
||||||
|
"self_attn.v_proj": "attn.to_v",
|
||||||
|
"self_attn.out_proj": "attn.to_out",
|
||||||
|
"layer_norm1": "layer_norm1",
|
||||||
|
"layer_norm2": "layer_norm2",
|
||||||
|
"mlp.fc1": "fc1",
|
||||||
|
"mlp.fc2": "fc2",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
elif name.startswith("text_model.encoder.layers."):
|
||||||
|
param = state_dict[name]
|
||||||
|
names = name.split(".")
|
||||||
|
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||||
|
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
def from_civitai(self, state_dict):
|
||||||
|
rename_dict = {
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||||
|
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name in rename_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||||
|
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||||
|
state_dict_[rename_dict[name]] = param
|
||||||
|
return state_dict_
|
||||||
@@ -16,7 +16,7 @@ from ..models.flux_text_encoder_clip import FluxTextEncoderClip
|
|||||||
from ..models.flux_text_encoder_t5 import FluxTextEncoderT5
|
from ..models.flux_text_encoder_t5 import FluxTextEncoderT5
|
||||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||||
from ..models.flux_value_control import MultiValueEncoder
|
from ..models.flux_value_control import MultiValueEncoder
|
||||||
|
from ..core.vram.layers import AutoWrappedLinear
|
||||||
|
|
||||||
class MultiControlNet(torch.nn.Module):
|
class MultiControlNet(torch.nn.Module):
|
||||||
def __init__(self, models: list[torch.nn.Module]):
|
def __init__(self, models: list[torch.nn.Module]):
|
||||||
@@ -102,8 +102,15 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
]
|
]
|
||||||
self.model_fn = model_fn_flux_image
|
self.model_fn = model_fn_flux_image
|
||||||
self.lora_loader = FluxLoRALoader
|
self.lora_loader = FluxLoRALoader
|
||||||
|
|
||||||
|
def enable_lora_magic(self):
|
||||||
|
if self.lora_patcher is not None:
|
||||||
|
for name, module in self.dit.named_modules():
|
||||||
|
if isinstance(module, AutoWrappedLinear):
|
||||||
|
merger_name = name.replace(".", "___")
|
||||||
|
if merger_name in self.lora_patcher.model_dict:
|
||||||
|
module.lora_merger = self.lora_patcher.model_dict[merger_name]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
|||||||
104
diffsynth/utils/state_dict_converters/flux_controlnet.py
Normal file
104
diffsynth/utils/state_dict_converters/flux_controlnet.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
import torch
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
|
||||||
|
def FluxControlNetStateDictConverter(state_dict):
|
||||||
|
global_rename_dict = {
|
||||||
|
"context_embedder": "context_embedder",
|
||||||
|
"x_embedder": "x_embedder",
|
||||||
|
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||||
|
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||||
|
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||||
|
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||||
|
"norm_out.linear": "final_norm_out.linear",
|
||||||
|
"proj_out": "final_proj_out",
|
||||||
|
}
|
||||||
|
rename_dict = {
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
"norm1.linear": "norm1_a.linear",
|
||||||
|
"norm1_context.linear": "norm1_b.linear",
|
||||||
|
"attn.to_q": "attn.a_to_q",
|
||||||
|
"attn.to_k": "attn.a_to_k",
|
||||||
|
"attn.to_v": "attn.a_to_v",
|
||||||
|
"attn.to_out.0": "attn.a_to_out",
|
||||||
|
"attn.add_q_proj": "attn.b_to_q",
|
||||||
|
"attn.add_k_proj": "attn.b_to_k",
|
||||||
|
"attn.add_v_proj": "attn.b_to_v",
|
||||||
|
"attn.to_add_out": "attn.b_to_out",
|
||||||
|
"ff.net.0.proj": "ff_a.0",
|
||||||
|
"ff.net.2": "ff_a.2",
|
||||||
|
"ff_context.net.0.proj": "ff_b.0",
|
||||||
|
"ff_context.net.2": "ff_b.2",
|
||||||
|
"attn.norm_q": "attn.norm_q_a",
|
||||||
|
"attn.norm_k": "attn.norm_k_a",
|
||||||
|
"attn.norm_added_q": "attn.norm_q_b",
|
||||||
|
"attn.norm_added_k": "attn.norm_k_b",
|
||||||
|
}
|
||||||
|
rename_dict_single = {
|
||||||
|
"attn.to_q": "a_to_q",
|
||||||
|
"attn.to_k": "a_to_k",
|
||||||
|
"attn.to_v": "a_to_v",
|
||||||
|
"attn.norm_q": "norm_q_a",
|
||||||
|
"attn.norm_k": "norm_k_a",
|
||||||
|
"norm.linear": "norm.linear",
|
||||||
|
"proj_mlp": "proj_in_besides_attn",
|
||||||
|
"proj_out": "proj_out",
|
||||||
|
}
|
||||||
|
state_dict_ = {}
|
||||||
|
|
||||||
|
for name in state_dict:
|
||||||
|
param = state_dict[name]
|
||||||
|
if name.endswith(".weight") or name.endswith(".bias"):
|
||||||
|
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||||
|
prefix = name[:-len(suffix)]
|
||||||
|
if prefix in global_rename_dict:
|
||||||
|
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||||
|
elif prefix.startswith("transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
elif prefix.startswith("single_transformer_blocks."):
|
||||||
|
names = prefix.split(".")
|
||||||
|
names[0] = "single_blocks"
|
||||||
|
middle = ".".join(names[2:])
|
||||||
|
if middle in rename_dict_single:
|
||||||
|
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||||
|
state_dict_[name_] = param
|
||||||
|
else:
|
||||||
|
state_dict_[name] = param
|
||||||
|
else:
|
||||||
|
state_dict_[name] = param
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
if ".proj_in_besides_attn." in name:
|
||||||
|
name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
|
||||||
|
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
|
||||||
|
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_v.")],
|
||||||
|
state_dict_[name],
|
||||||
|
], dim=0)
|
||||||
|
state_dict_[name_] = param
|
||||||
|
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_q."))
|
||||||
|
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_k."))
|
||||||
|
state_dict_.pop(name.replace(".proj_in_besides_attn.", f".a_to_v."))
|
||||||
|
state_dict_.pop(name)
|
||||||
|
for name in list(state_dict_.keys()):
|
||||||
|
for component in ["a", "b"]:
|
||||||
|
if f".{component}_to_q." in name:
|
||||||
|
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||||
|
param = torch.concat([
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||||
|
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||||
|
], dim=0)
|
||||||
|
state_dict_[name_] = param
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||||
|
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||||
|
|
||||||
|
return state_dict_
|
||||||
@@ -1,4 +1,41 @@
|
|||||||
|
import torch
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||||
|
keys = []
|
||||||
|
all_keys = sorted(list(state_dict))
|
||||||
|
|
||||||
|
for key in all_keys:
|
||||||
|
value = state_dict[key]
|
||||||
|
if isinstance(key, str):
|
||||||
|
if isinstance(value, torch.Tensor):
|
||||||
|
if with_shape:
|
||||||
|
shape = "_".join(map(str, list(value.shape)))
|
||||||
|
keys.append(key + ":" + shape)
|
||||||
|
keys.append(key)
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||||
|
keys.sort()
|
||||||
|
keys_str = ",".join(keys)
|
||||||
|
return keys_str
|
||||||
|
|
||||||
|
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||||
|
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||||
|
keys_str = keys_str.encode(encoding="UTF-8")
|
||||||
|
return hashlib.md5(keys_str).hexdigest()
|
||||||
|
|
||||||
def FluxDiTStateDictConverter(state_dict):
|
def FluxDiTStateDictConverter(state_dict):
|
||||||
|
model_hash = hash_state_dict_keys(state_dict, with_shape=True)
|
||||||
|
|
||||||
|
if model_hash in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
|
||||||
|
dit_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith('pipe.dit.'):
|
||||||
|
value = state_dict[key]
|
||||||
|
new_key = key.replace("pipe.dit.", "")
|
||||||
|
dit_state_dict[new_key] = value
|
||||||
|
return dit_state_dict
|
||||||
|
|
||||||
rename_dict = {
|
rename_dict = {
|
||||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def FluxInfiniteYouImageProjectorStateDictConverter(state_dict):
|
||||||
|
return state_dict['image_proj']
|
||||||
34
diffsynth/utils/state_dict_converters/flux_ipadapter.py
Normal file
34
diffsynth/utils/state_dict_converters/flux_ipadapter.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def FluxIpAdapterStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
|
||||||
|
if "ip_adapter" in state_dict and isinstance(state_dict["ip_adapter"], dict):
|
||||||
|
for name, param in state_dict["ip_adapter"].items():
|
||||||
|
name_ = 'ipadapter_modules.' + name
|
||||||
|
state_dict_[name_] = param
|
||||||
|
|
||||||
|
if "image_proj" in state_dict:
|
||||||
|
for name, param in state_dict["image_proj"].items():
|
||||||
|
name_ = "image_proj." + name
|
||||||
|
state_dict_[name_] = param
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
for key, value in state_dict.items():
|
||||||
|
if key.startswith("image_proj."):
|
||||||
|
state_dict_[key] = value
|
||||||
|
elif key.startswith("ip_adapter."):
|
||||||
|
new_key = key.replace("ip_adapter.", "ipadapter_modules.")
|
||||||
|
state_dict_[new_key] = value
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def SiglipStateDictConverter(state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith("vision_model."):
|
||||||
|
new_state_dict[key] = state_dict[key]
|
||||||
|
return new_state_dict
|
||||||
8
diffsynth/utils/state_dict_converters/nexus_gen.py
Normal file
8
diffsynth/utils/state_dict_converters/nexus_gen.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def NexusGenAutoregressiveModelStateDictConverter(state_dict):
|
||||||
|
new_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
value = state_dict[key]
|
||||||
|
new_state_dict["model." + key] = value
|
||||||
|
return new_state_dict
|
||||||
17
diffsynth/utils/state_dict_converters/nexus_gen_projector.py
Normal file
17
diffsynth/utils/state_dict_converters/nexus_gen_projector.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
def NexusGenMergerStateDictConverter(state_dict):
|
||||||
|
merger_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith('embedding_merger.'):
|
||||||
|
value = state_dict[key]
|
||||||
|
new_key = key.replace("embedding_merger.", "")
|
||||||
|
merger_state_dict[new_key] = value
|
||||||
|
return merger_state_dict
|
||||||
|
|
||||||
|
def NexusGenAdapterStateDictConverter(state_dict):
|
||||||
|
adapter_state_dict = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith('adapter.'):
|
||||||
|
adapter_state_dict[key] = state_dict[key]
|
||||||
|
return adapter_state_dict
|
||||||
@@ -13,10 +13,8 @@ pipe = FluxImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
|
ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
pipe.enable_lora_magic()
|
|
||||||
|
|
||||||
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors")
|
||||||
pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA.
|
||||||
|
|
||||||
# Empty prompt can automatically activate LoRA capabilities.
|
# Empty prompt can automatically activate LoRA capabilities.
|
||||||
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
image = pipe(prompt="", seed=0, lora_encoder_inputs=lora)
|
||||||
|
|||||||
@@ -18,12 +18,10 @@ pipe.enable_lora_magic()
|
|||||||
pipe.load_lora(
|
pipe.load_lora(
|
||||||
pipe.dit,
|
pipe.dit,
|
||||||
ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
|
ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
|
||||||
hotload=True,
|
|
||||||
)
|
)
|
||||||
pipe.load_lora(
|
pipe.load_lora(
|
||||||
pipe.dit,
|
pipe.dit,
|
||||||
ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
|
ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
|
||||||
hotload=True,
|
|
||||||
)
|
)
|
||||||
image = pipe(prompt="a cat", seed=0)
|
image = pipe(prompt="a cat", seed=0)
|
||||||
image.save("image_fused.jpg")
|
image.save("image_fused.jpg")
|
||||||
|
|||||||
Reference in New Issue
Block a user