From 7b756a518ed61f37ac5e22049b14ae8f69dbdce0 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 20:19:37 +0800 Subject: [PATCH 1/4] flux --- diffsynth/configs/model_configs.py | 53 +- diffsynth/models/flux_controlnet.py | 61 ++- diffsynth/models/flux_lora_encoder.py | 483 +++++++++++++++++- diffsynth/models/flux_lora_patcher.py | 60 +++ diffsynth/models/sd_text_encoder.py | 412 +++++++++++++++ diffsynth/pipelines/flux_image.py | 118 ++++- .../state_dict_converters/flux_controlnet.py | 104 ++++ .../state_dict_converters/flux_infiniteyou.py | 4 + .../FLUX.1-dev-LoRA-Encoder.py | 4 +- 9 files changed, 1286 insertions(+), 13 deletions(-) create mode 100644 diffsynth/models/flux_lora_patcher.py create mode 100644 diffsynth/models/sd_text_encoder.py create mode 100644 diffsynth/utils/state_dict_converters/flux_controlnet.py create mode 100644 diffsynth/utils/state_dict_converters/flux_infiniteyou.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 1fade41..3aab5c7 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -312,7 +312,58 @@ flux_series = [ "model_hash": "0629116fce1472503a66992f96f3eb1a", "model_name": "flux_value_controller", "model_class": "diffsynth.models.flux_value_control.SingleValueEncoder", - } + }, + { + # Example: ModelConfig(model_id="alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "52357cb26250681367488a8954c271e8", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 6, "num_single_blocks": 0, "additional_input_dim": 4}, + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-Controlnet-Union-alpha", 
origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "78d18b9101345ff695f312e7e62538c0", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_mode": 10, "mode_dict": {"canny": 0, "tile": 1, "depth": 2, "blur": 3, "pose": 4, "gray": 5, "lq": 6}}, + }, + { + # Example: ModelConfig(model_id="jasperai/Flux.1-dev-Controlnet-Upscaler", origin_file_pattern="diffusion_pytorch_model.safetensors") + "model_hash": "b001c89139b5f053c715fe772362dd2a", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_single_blocks": 0}, + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/image_proj_model.bin") + "model_hash": "c07c0f04f5ff55e86b4e937c7a40d481", + "model_name": "infiniteyou_image_projector", + "model_class": "diffsynth.models.flux_infiniteyou.InfiniteYouImageProjector", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_infiniteyou.FluxInfiniteYouImageProjectorStateDictConverter", + }, + { + # Example: ModelConfig(model_id="ByteDance/InfiniteYou", origin_file_pattern="infu_flux_v1.0/aes_stage2/InfuseNetModel/*.safetensors") + "model_hash": "7f9583eb8ba86642abb9a21a4b2c9e16", + "model_name": "flux_controlnet", + "model_class": "diffsynth.models.flux_controlnet.FluxControlNet", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_controlnet.FluxControlNetStateDictConverter", + "extra_kwargs": {"num_joint_blocks": 4, "num_single_blocks": 10}, + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors") + "model_hash": 
"77c2e4dd2440269eb33bfaa0d004f6ab", + "model_name": "flux_lora_encoder", + "model_class": "diffsynth.models.flux_lora_encoder.FluxLoRAEncoder", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev", origin_file_pattern="model.safetensors") + "model_hash": "30143afb2dea73d1ac580e0787628f8c", + "model_name": "flux_lora_patcher", + "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher", + }, ] MODEL_CONFIGS = qwen_image_series + wan_series + flux_series diff --git a/diffsynth/models/flux_controlnet.py b/diffsynth/models/flux_controlnet.py index 85fccd7..7fb1138 100644 --- a/diffsynth/models/flux_controlnet.py +++ b/diffsynth/models/flux_controlnet.py @@ -1,9 +1,62 @@ import torch from einops import rearrange, repeat from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm -from .utils import hash_state_dict_keys, init_weights_on_device +# from .utils import hash_state_dict_keys, init_weights_on_device +from contextlib import contextmanager +def hash_state_dict_keys(state_dict, with_shape=True): + keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape) + keys_str = keys_str.encode(encoding="UTF-8") + return hashlib.md5(keys_str).hexdigest() +@contextmanager +def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False): + + old_register_parameter = torch.nn.Module.register_parameter + if include_buffers: + old_register_buffer = torch.nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, 
name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + try: + torch.nn.Module.register_parameter = register_empty_parameter + if include_buffers: + torch.nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + torch.nn.Module.register_parameter = old_register_parameter + if include_buffers: + torch.nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) class FluxControlNet(torch.nn.Module): def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0): @@ -102,9 +155,9 @@ class FluxControlNet(torch.nn.Module): return controlnet_res_stack, controlnet_single_res_stack - @staticmethod - def state_dict_converter(): - return FluxControlNetStateDictConverter() + # @staticmethod + # def state_dict_converter(): + # return FluxControlNetStateDictConverter() def quantize(self): def cast_to(weight, dtype=None, device=None, copy=False): diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py index 695640a..2be5dbb 100644 --- a/diffsynth/models/flux_lora_encoder.py +++ b/diffsynth/models/flux_lora_encoder.py @@ -1,5 +1,415 @@ import torch -from .sd_text_encoder import CLIPEncoderLayer +from einops import 
rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = 
torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + 
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + 
if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": 
"encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": 
"encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": 
"encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ + class LoRALayerBlock(torch.nn.Module): @@ -58,13 +468,80 @@ class 
LoRAEmbedder(torch.nn.Module): "type": suffix, }) return lora_patterns + + def get_lora_param_pair(self, lora, name, dim, device, dtype): + key_A = name + ".lora_A.default.weight" + key_B = name + ".lora_B.default.weight" + if key_A in lora and key_B in lora: + return lora[key_A], lora[key_B] + if "to_qkv" in name: + base_name = name.replace("to_qkv", "") + suffixes = ["to_q", "to_k", "to_v"] + + found_As = [] + found_Bs = [] + + all_found = True + for suffix in suffixes: + sub_name = base_name + suffix + k_A = sub_name + ".lora_A.default.weight" + k_B = sub_name + ".lora_B.default.weight" + + if k_A in lora and k_B in lora: + found_As.append(lora[k_A]) + found_Bs.append(lora[k_B]) + else: + all_found = False + break + if all_found: + pass + + rank = 16 + for k, v in lora.items(): + if "lora_A" in k: + rank = v.shape[0] + device = v.device + dtype = v.dtype + break + + lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype) + lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype) + + return lora_A, lora_B + def forward(self, lora): lora_emb = [] + device = None + dtype = None + for v in lora.values(): + device = v.device + dtype = v.dtype + break + for lora_pattern in self.lora_patterns: name, layer_type = lora_pattern["name"], lora_pattern["type"] - lora_A = lora[name + ".lora_A.default.weight"] - lora_B = lora[name + ".lora_B.default.weight"] + dim = lora_pattern["dim"] + + lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype) + + if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))): + base_name = name.replace("to_qkv", "") + try: + q_name = base_name + "to_q" + k_name = base_name + "to_k" + v_name = base_name + "to_v" + + real_A = lora[q_name + ".lora_A.default.weight"] + B_q = lora[q_name + ".lora_B.default.weight"] + B_k = lora[k_name + ".lora_B.default.weight"] + B_v = lora[v_name + ".lora_B.default.weight"] + real_B = torch.cat([B_q, B_k, B_v], dim=0) + + lora_A, lora_B = real_A, 
real_B + except KeyError: + pass + lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) lora_emb.append(lora_out) diff --git a/diffsynth/models/flux_lora_patcher.py b/diffsynth/models/flux_lora_patcher.py new file mode 100644 index 0000000..a249feb --- /dev/null +++ b/diffsynth/models/flux_lora_patcher.py @@ -0,0 +1,60 @@ +import torch + +class LoraMerger(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.weight_base = torch.nn.Parameter(torch.randn((dim,))) + self.weight_lora = torch.nn.Parameter(torch.randn((dim,))) + self.weight_cross = torch.nn.Parameter(torch.randn((dim,))) + self.weight_out = torch.nn.Parameter(torch.ones((dim,))) + self.bias = torch.nn.Parameter(torch.randn((dim,))) + self.activation = torch.nn.Sigmoid() + self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5) + self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5) + + def forward(self, base_output, lora_outputs): + norm_base_output = self.norm_base(base_output) + norm_lora_outputs = self.norm_lora(lora_outputs) + gate = self.activation( + norm_base_output * self.weight_base \ + + norm_lora_outputs * self.weight_lora \ + + norm_base_output * norm_lora_outputs * self.weight_cross + self.bias + ) + output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0) + return output + +class FluxLoraPatcher(torch.nn.Module): + def __init__(self, lora_patterns=None): + super().__init__() + if lora_patterns is None: + lora_patterns = self.default_lora_patterns() + model_dict = {} + for lora_pattern in lora_patterns: + name, dim = lora_pattern["name"], lora_pattern["dim"] + model_dict[name.replace(".", "___")] = LoraMerger(dim) + self.model_dict = torch.nn.ModuleDict(model_dict) + + def default_lora_patterns(self): + lora_patterns = [] + lora_dict = { + "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432, + "attn.b_to_qkv": 9216, 
"attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432, + } + for i in range(19): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216} + for i in range(38): + for suffix in lora_dict: + lora_patterns.append({ + "name": f"single_blocks.{i}.{suffix}", + "dim": lora_dict[suffix] + }) + return lora_patterns + + def forward(self, base_output, lora_outputs, name): + return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) + diff --git a/diffsynth/models/sd_text_encoder.py b/diffsynth/models/sd_text_encoder.py new file mode 100644 index 0000000..b0a1171 --- /dev/null +++ b/diffsynth/models/sd_text_encoder.py @@ -0,0 +1,412 @@ +import torch +from .attention import Attention +from einops import rearrange + + +def low_version_attention(query, key, value, attn_bias=None): + scale = 1 / query.shape[-1] ** 0.5 + query = query * scale + attn = torch.matmul(query, key.transpose(-2, -1)) + if attn_bias is not None: + attn = attn + attn_bias + attn = attn.softmax(-1) + return attn @ value + + +class Attention(torch.nn.Module): + + def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False): + super().__init__() + dim_inner = head_dim * num_heads + kv_dim = kv_dim if kv_dim is not None else q_dim + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q) + self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv) + self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out) + + def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0): + batch_size = q.shape[0] + ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + ip_v = ip_v.view(batch_size, -1, self.num_heads, 
self.head_dim).transpose(1, 2) + ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v) + hidden_states = hidden_states + scale * ip_hidden_states + return hidden_states + + def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + batch_size = encoder_hidden_states.shape[0] + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if qkv_preprocessor is not None: + q, k, v = qkv_preprocessor(q, k, v) + + hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask) + if ipadapter_kwargs is not None: + hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim) + hidden_states = hidden_states.to(q.dtype) + + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None): + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + + q = self.to_q(hidden_states) + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads) + k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads) + v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads) + + if attn_mask is not None: + hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask) + else: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(q, k, v) + hidden_states = 
rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads) + + hidden_states = hidden_states.to(q.dtype) + hidden_states = self.to_out(hidden_states) + + return hidden_states + + def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None): + return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor) + + + + + +class CLIPEncoderLayer(torch.nn.Module): + def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True): + super().__init__() + self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True) + self.layer_norm1 = torch.nn.LayerNorm(embed_dim) + self.layer_norm2 = torch.nn.LayerNorm(embed_dim) + self.fc1 = torch.nn.Linear(embed_dim, intermediate_size) + self.fc2 = torch.nn.Linear(intermediate_size, embed_dim) + + self.use_quick_gelu = use_quick_gelu + + def quickGELU(self, x): + return x * torch.sigmoid(1.702 * x) + + def forward(self, hidden_states, attn_mask=None): + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.attn(hidden_states, attn_mask=attn_mask) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.fc1(hidden_states) + if self.use_quick_gelu: + hidden_states = self.quickGELU(hidden_states) + else: + hidden_states = torch.nn.functional.gelu(hidden_states) + hidden_states = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class SDTextEncoder(torch.nn.Module): + def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072): + super().__init__() + + # token_embedding + self.token_embedding = 
torch.nn.Embedding(vocab_size, embed_dim) + + # position_embeds (This is a fixed tensor) + self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim)) + + # encoders + self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)]) + + # attn_mask + self.attn_mask = self.attention_mask(max_position_embeddings) + + # final_layer_norm + self.final_layer_norm = torch.nn.LayerNorm(embed_dim) + + def attention_mask(self, length): + mask = torch.empty(length, length) + mask.fill_(float("-inf")) + mask.triu_(1) + return mask + + def forward(self, input_ids, clip_skip=1): + embeds = self.token_embedding(input_ids) + self.position_embeds + attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype) + for encoder_id, encoder in enumerate(self.encoders): + embeds = encoder(embeds, attn_mask=attn_mask) + if encoder_id + clip_skip == len(self.encoders): + break + embeds = self.final_layer_norm(embeds) + return embeds + + @staticmethod + def state_dict_converter(): + return SDTextEncoderStateDictConverter() + + +class SDTextEncoderStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + rename_dict = { + "text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "text_model.embeddings.position_embedding.weight": "position_embeds", + "text_model.final_layer_norm.weight": "final_layer_norm.weight", + "text_model.final_layer_norm.bias": "final_layer_norm.bias" + } + attn_rename_dict = { + "self_attn.q_proj": "attn.to_q", + "self_attn.k_proj": "attn.to_k", + "self_attn.v_proj": "attn.to_v", + "self_attn.out_proj": "attn.to_out", + "layer_norm1": "layer_norm1", + "layer_norm2": "layer_norm2", + "mlp.fc1": "fc1", + "mlp.fc2": "fc2", + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "text_model.embeddings.position_embedding.weight": + param = param.reshape((1, 
param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + elif name.startswith("text_model.encoder.layers."): + param = state_dict[name] + names = name.split(".") + layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1] + name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail]) + state_dict_[name_] = param + return state_dict_ + + def from_civitai(self, state_dict): + rename_dict = { + "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": 
"encoders.1.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": 
"encoders.2.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": 
"encoders.3.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": 
"encoders.4.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": 
"encoders.5.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": 
"encoders.6.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias", + 
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": 
"encoders.8.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight", + 
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias", + "cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight", + "cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias", + "cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight", + "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds" + } + state_dict_ = {} + for name in state_dict: + if name in rename_dict: + param = state_dict[name] + if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": + param = param.reshape((1, param.shape[0], param.shape[1])) + state_dict_[rename_dict[name]] = param + return state_dict_ diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 3c0151d..78f0cb1 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -102,8 +102,122 @@ class FluxImagePipeline(BasePipeline): ] self.model_fn = model_fn_flux_image self.lora_loader = FluxLoRALoader - - + + def enable_lora_magic(self): + pass + + # def load_lora(self, model, lora_config, alpha=1, hotload=False): + # if isinstance(lora_config, str): + # path = lora_config + # else: + # lora_config.download_if_necessary() + # path = lora_config.path + + # state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") + # loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + # state_dict = loader.convert_state_dict(state_dict) + # loaded_count = 0 + # for key in tqdm(state_dict, desc="Applying LoRA"): + # if ".lora_A." 
in key: + # layer_name = key.split(".lora_A.")[0] + # module = model + # try: + # parts = layer_name.split(".") + # for part in parts: + # if part.isdigit(): + # module = module[int(part)] + # else: + # module = getattr(module, part) + # except AttributeError: + # continue + + # w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) + # w_b_key = key.replace("lora_A", "lora_B") + # if w_b_key not in state_dict: continue + # w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) + # delta_w = torch.mm(w_b, w_a) + # module.weight.data += delta_w * alpha + # loaded_count += 1 + + + def load_lora(self, model, lora_config, alpha=1.0, hotload=False): + if isinstance(lora_config, str): + path = lora_config + else: + lora_config.download_if_necessary() + path = lora_config.path + + state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") + loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) + state_dict = loader.convert_state_dict(state_dict) + + print(f"Merging LoRA weights from {path}...") + loaded_count = 0 + + # [新增] 键名映射表,处理 FW2 Loader 与 DiT 模型名称不一致的情况 + # 针对 Single Blocks 常见的命名差异进行修正 + key_mapping = { + ".linear1.": ".to_qkv_mlp.", # 常见差异点 1 + ".linear2.": ".proj_out.", # 常见差异点 2 + ".modulation.lin.": ".norm.linear." # 常见差异点 3 + } + + for key in tqdm(state_dict, desc="Applying LoRA"): + if ".lora_A." 
in key: + layer_name = key.split(".lora_A.")[0] + + # [新增] 尝试应用键名修正 + target_layer_name = layer_name + for src, dst in key_mapping.items(): + if src in target_layer_name: + target_layer_name = target_layer_name.replace(src, dst) + + # 在模型中查找层 + module = model + try: + parts = target_layer_name.split(".") + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + except AttributeError: + # 如果修正后还是找不到,尝试原始名称(作为保底) + try: + module = model + parts = layer_name.split(".") + for part in parts: + if part.isdigit(): + module = module[int(part)] + else: + module = getattr(module, part) + except AttributeError: + # 确实找不到,跳过并打印警告(可选) + # print(f"Warning: Could not find layer for {layer_name}") + continue + + # 获取 LoRA 参数并计算增量 + try: + w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) + w_b_key = key.replace("lora_A", "lora_B") + if w_b_key not in state_dict: continue + w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) + + # 检查形状是否匹配 (非常重要,防止 broadcasting 错误掩盖问题) + # Linear weight: (out, in). 
B@A: (out, in) + delta_w = torch.mm(w_b, w_a) + if delta_w.shape != module.weight.shape: + # 形状不匹配通常意味着 QKV 融合/分离状态不一致 + # 简单跳过或尝试转置(视具体情况,这里保守跳过) + continue + + module.weight.data += delta_w * alpha + loaded_count += 1 + except Exception as e: + continue + + print(f"Applied LoRA to {loaded_count} layers.") + @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, diff --git a/diffsynth/utils/state_dict_converters/flux_controlnet.py b/diffsynth/utils/state_dict_converters/flux_controlnet.py new file mode 100644 index 0000000..926590c --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_controlnet.py @@ -0,0 +1,104 @@ +import torch +import hashlib +import json + +def FluxControlNetStateDictConverter(state_dict): + global_rename_dict = { + "context_embedder": "context_embedder", + "x_embedder": "x_embedder", + "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", + "time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2", + "time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0", + "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", + "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", + "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", + "norm_out.linear": "final_norm_out.linear", + "proj_out": "final_proj_out", + } + rename_dict = { + "proj_out": "proj_out", + "norm1.linear": "norm1_a.linear", + "norm1_context.linear": "norm1_b.linear", + "attn.to_q": "attn.a_to_q", + "attn.to_k": "attn.a_to_k", + "attn.to_v": "attn.a_to_v", + "attn.to_out.0": "attn.a_to_out", + "attn.add_q_proj": "attn.b_to_q", + "attn.add_k_proj": "attn.b_to_k", + "attn.add_v_proj": "attn.b_to_v", + "attn.to_add_out": "attn.b_to_out", + "ff.net.0.proj": "ff_a.0", + "ff.net.2": "ff_a.2", + "ff_context.net.0.proj": "ff_b.0", + "ff_context.net.2": "ff_b.2", + "attn.norm_q": "attn.norm_q_a", + "attn.norm_k": 
"attn.norm_k_a", + "attn.norm_added_q": "attn.norm_q_b", + "attn.norm_added_k": "attn.norm_k_b", + } + rename_dict_single = { + "attn.to_q": "a_to_q", + "attn.to_k": "a_to_k", + "attn.to_v": "a_to_v", + "attn.norm_q": "norm_q_a", + "attn.norm_k": "norm_k_a", + "norm.linear": "norm.linear", + "proj_mlp": "proj_in_besides_attn", + "proj_out": "proj_out", + } + state_dict_ = {} + + for name in state_dict: + param = state_dict[name] + if name.endswith(".weight") or name.endswith(".bias"): + suffix = ".weight" if name.endswith(".weight") else ".bias" + prefix = name[:-len(suffix)] + if prefix in global_rename_dict: + state_dict_[global_rename_dict[prefix] + suffix] = param + elif prefix.startswith("transformer_blocks."): + names = prefix.split(".") + names[0] = "blocks" + middle = ".".join(names[2:]) + if middle in rename_dict: + name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]]) + state_dict_[name_] = param + elif prefix.startswith("single_transformer_blocks."): + names = prefix.split(".") + names[0] = "single_blocks" + middle = ".".join(names[2:]) + if middle in rename_dict_single: + name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]]) + state_dict_[name_] = param + else: + state_dict_[name] = param + else: + state_dict_[name] = param + for name in list(state_dict_.keys()): + if ".proj_in_besides_attn." 
in name: + name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.") + param = torch.concat([ + state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_q.")], + state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_k.")], + state_dict_[name.replace(".proj_in_besides_attn.", ".a_to_v.")], + state_dict_[name], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_q.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_k.")) + state_dict_.pop(name.replace(".proj_in_besides_attn.", ".a_to_v.")) + state_dict_.pop(name) + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + param = torch.concat([ + state_dict_[name], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + + return state_dict_ \ No newline at end of file diff --git a/diffsynth/utils/state_dict_converters/flux_infiniteyou.py b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py new file mode 100644 index 0000000..826f996 --- /dev/null +++ b/diffsynth/utils/state_dict_converters/flux_infiniteyou.py @@ -0,0 +1,4 @@ +import torch + +def FluxInfiniteYouImageProjectorStateDictConverter(state_dict): + return state_dict['image_proj'] \ No newline at end of file diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py index 9e3d74b..75f1bc8 100644 --- a/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py +++ 
b/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py @@ -13,10 +13,8 @@ pipe = FluxImagePipeline.from_pretrained( ModelConfig(model_id="DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev", origin_file_pattern="model.safetensors"), ], ) -pipe.enable_lora_magic() - lora = ModelConfig(model_id="VoidOc/flux_animal_forest1", origin_file_pattern="20.safetensors") -pipe.load_lora(pipe.dit, lora, hotload=True) # Use `pipe.clear_lora()` to drop the loaded LoRA. +pipe.load_lora(pipe.dit, lora) # Use `pipe.clear_lora()` to drop the loaded LoRA. # Empty prompt can automatically activate LoRA capabilities. image = pipe(prompt="", seed=0, lora_encoder_inputs=lora) From 3f9e9cad9d6d8e2fb3b0ed1e4e1da518861f18ac Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 20:37:14 +0800 Subject: [PATCH 2/4] fix:flux --- diffsynth/pipelines/flux_image.py | 100 +++--------------------------- 1 file changed, 10 insertions(+), 90 deletions(-) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 78f0cb1..9ac0373 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -106,41 +106,7 @@ class FluxImagePipeline(BasePipeline): def enable_lora_magic(self): pass - # def load_lora(self, model, lora_config, alpha=1, hotload=False): - # if isinstance(lora_config, str): - # path = lora_config - # else: - # lora_config.download_if_necessary() - # path = lora_config.path - - # state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") - # loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) - # state_dict = loader.convert_state_dict(state_dict) - # loaded_count = 0 - # for key in tqdm(state_dict, desc="Applying LoRA"): - # if ".lora_A." 
in key: - # layer_name = key.split(".lora_A.")[0] - # module = model - # try: - # parts = layer_name.split(".") - # for part in parts: - # if part.isdigit(): - # module = module[int(part)] - # else: - # module = getattr(module, part) - # except AttributeError: - # continue - - # w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - # w_b_key = key.replace("lora_A", "lora_B") - # if w_b_key not in state_dict: continue - # w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - # delta_w = torch.mm(w_b, w_a) - # module.weight.data += delta_w * alpha - # loaded_count += 1 - - - def load_lora(self, model, lora_config, alpha=1.0, hotload=False): + def load_lora(self, model, lora_config, alpha=1, hotload=False): if isinstance(lora_config, str): path = lora_config else: @@ -150,74 +116,28 @@ class FluxImagePipeline(BasePipeline): state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) state_dict = loader.convert_state_dict(state_dict) - - print(f"Merging LoRA weights from {path}...") loaded_count = 0 - - # [新增] 键名映射表,处理 FW2 Loader 与 DiT 模型名称不一致的情况 - # 针对 Single Blocks 常见的命名差异进行修正 - key_mapping = { - ".linear1.": ".to_qkv_mlp.", # 常见差异点 1 - ".linear2.": ".proj_out.", # 常见差异点 2 - ".modulation.lin.": ".norm.linear." # 常见差异点 3 - } - for key in tqdm(state_dict, desc="Applying LoRA"): if ".lora_A." 
in key: layer_name = key.split(".lora_A.")[0] - - # [新增] 尝试应用键名修正 - target_layer_name = layer_name - for src, dst in key_mapping.items(): - if src in target_layer_name: - target_layer_name = target_layer_name.replace(src, dst) - - # 在模型中查找层 module = model try: - parts = target_layer_name.split(".") + parts = layer_name.split(".") for part in parts: if part.isdigit(): module = module[int(part)] else: module = getattr(module, part) except AttributeError: - # 如果修正后还是找不到,尝试原始名称(作为保底) - try: - module = model - parts = layer_name.split(".") - for part in parts: - if part.isdigit(): - module = module[int(part)] - else: - module = getattr(module, part) - except AttributeError: - # 确实找不到,跳过并打印警告(可选) - # print(f"Warning: Could not find layer for {layer_name}") - continue - - # 获取 LoRA 参数并计算增量 - try: - w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - w_b_key = key.replace("lora_A", "lora_B") - if w_b_key not in state_dict: continue - w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - - # 检查形状是否匹配 (非常重要,防止 broadcasting 错误掩盖问题) - # Linear weight: (out, in). 
B@A: (out, in) - delta_w = torch.mm(w_b, w_a) - if delta_w.shape != module.weight.shape: - # 形状不匹配通常意味着 QKV 融合/分离状态不一致 - # 简单跳过或尝试转置(视具体情况,这里保守跳过) - continue - - module.weight.data += delta_w * alpha - loaded_count += 1 - except Exception as e: continue - - print(f"Applied LoRA to {loaded_count} layers.") - + + w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) + w_b_key = key.replace("lora_A", "lora_B") + if w_b_key not in state_dict: continue + w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) + delta_w = torch.mm(w_b, w_a) + module.weight.data += delta_w * alpha + loaded_count += 1 @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, From 2d23c897c20c60a877b834b689f80e786b5d7495 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Tue, 18 Nov 2025 21:29:35 +0800 Subject: [PATCH 3/4] add: LoRA Encoder --- diffsynth/models/flux_lora_encoder.py | 71 +------------------ diffsynth/pipelines/flux_image.py | 32 --------- .../model_inference/FLUX.1-dev-LoRA-Fusion.py | 2 - 3 files changed, 2 insertions(+), 103 deletions(-) diff --git a/diffsynth/models/flux_lora_encoder.py b/diffsynth/models/flux_lora_encoder.py index 2be5dbb..13589b0 100644 --- a/diffsynth/models/flux_lora_encoder.py +++ b/diffsynth/models/flux_lora_encoder.py @@ -468,80 +468,13 @@ class LoRAEmbedder(torch.nn.Module): "type": suffix, }) return lora_patterns - - def get_lora_param_pair(self, lora, name, dim, device, dtype): - key_A = name + ".lora_A.default.weight" - key_B = name + ".lora_B.default.weight" - if key_A in lora and key_B in lora: - return lora[key_A], lora[key_B] - if "to_qkv" in name: - base_name = name.replace("to_qkv", "") - suffixes = ["to_q", "to_k", "to_v"] - - found_As = [] - found_Bs = [] - - all_found = True - for suffix in suffixes: - sub_name = base_name + suffix - k_A = sub_name + ".lora_A.default.weight" - k_B = sub_name + ".lora_B.default.weight" - - if k_A in lora and k_B in 
lora: - found_As.append(lora[k_A]) - found_Bs.append(lora[k_B]) - else: - all_found = False - break - if all_found: - pass - - rank = 16 - for k, v in lora.items(): - if "lora_A" in k: - rank = v.shape[0] - device = v.device - dtype = v.dtype - break - - lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype) - lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype) - - return lora_A, lora_B - def forward(self, lora): lora_emb = [] - device = None - dtype = None - for v in lora.values(): - device = v.device - dtype = v.dtype - break - for lora_pattern in self.lora_patterns: name, layer_type = lora_pattern["name"], lora_pattern["type"] - dim = lora_pattern["dim"] - - lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype) - - if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))): - base_name = name.replace("to_qkv", "") - try: - q_name = base_name + "to_q" - k_name = base_name + "to_k" - v_name = base_name + "to_v" - - real_A = lora[q_name + ".lora_A.default.weight"] - B_q = lora[q_name + ".lora_B.default.weight"] - B_k = lora[k_name + ".lora_B.default.weight"] - B_v = lora[v_name + ".lora_B.default.weight"] - real_B = torch.cat([B_q, B_k, B_v], dim=0) - - lora_A, lora_B = real_A, real_B - except KeyError: - pass - + lora_A = lora[name + ".lora_A.weight"] + lora_B = lora[name + ".lora_B.weight"] lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B) lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out) lora_emb.append(lora_out) diff --git a/diffsynth/pipelines/flux_image.py b/diffsynth/pipelines/flux_image.py index 9ac0373..dcd9e8e 100644 --- a/diffsynth/pipelines/flux_image.py +++ b/diffsynth/pipelines/flux_image.py @@ -106,38 +106,6 @@ class FluxImagePipeline(BasePipeline): def enable_lora_magic(self): pass - def load_lora(self, model, lora_config, alpha=1, hotload=False): - if isinstance(lora_config, str): - path = lora_config - else: - 
lora_config.download_if_necessary() - path = lora_config.path - - state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu") - loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device) - state_dict = loader.convert_state_dict(state_dict) - loaded_count = 0 - for key in tqdm(state_dict, desc="Applying LoRA"): - if ".lora_A." in key: - layer_name = key.split(".lora_A.")[0] - module = model - try: - parts = layer_name.split(".") - for part in parts: - if part.isdigit(): - module = module[int(part)] - else: - module = getattr(module, part) - except AttributeError: - continue - - w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype) - w_b_key = key.replace("lora_A", "lora_B") - if w_b_key not in state_dict: continue - w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype) - delta_w = torch.mm(w_b, w_a) - module.weight.data += delta_w * alpha - loaded_count += 1 @staticmethod def from_pretrained( torch_dtype: torch.dtype = torch.bfloat16, diff --git a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py index 5339230..9d0b189 100644 --- a/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py +++ b/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py @@ -18,12 +18,10 @@ pipe.enable_lora_magic() pipe.load_lora( pipe.dit, ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"), - hotload=True, ) pipe.load_lora( pipe.dit, ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"), - hotload=True, ) image = pipe(prompt="a cat", seed=0) image.save("image_fused.jpg") From c119ce7e6481c8dccdec54d6736e65bea781aca7 Mon Sep 17 00:00:00 2001 From: yjy415 <2471352175@qq.com> Date: Wed, 19 Nov 2025 15:14:18 +0800 Subject: [PATCH 4/4] Add: FLUX --- diffsynth/configs/model_configs.py | 49 + diffsynth/models/flux_ipadapter.py | 39 +- 
diffsynth/models/flux_lora_patcher.py | 250 +++- diffsynth/models/nexus_gen.py | 161 +++ diffsynth/models/nexus_gen_ar_model.py | 1143 +++++++++++++++++ diffsynth/models/nexus_gen_projector.py | 417 ++++++ diffsynth/pipelines/flux_image.py | 9 +- .../utils/state_dict_converters/flux_dit.py | 37 + .../state_dict_converters/flux_ipadapter.py | 34 + .../utils/state_dict_converters/nexus_gen.py | 8 + .../nexus_gen_projector.py | 17 + 11 files changed, 2134 insertions(+), 30 deletions(-) create mode 100644 diffsynth/models/nexus_gen.py create mode 100644 diffsynth/models/nexus_gen_ar_model.py create mode 100644 diffsynth/models/nexus_gen_projector.py create mode 100644 diffsynth/utils/state_dict_converters/flux_ipadapter.py create mode 100644 diffsynth/utils/state_dict_converters/nexus_gen.py create mode 100644 diffsynth/utils/state_dict_converters/nexus_gen_projector.py diff --git a/diffsynth/configs/model_configs.py b/diffsynth/configs/model_configs.py index 3aab5c7..f89cffb 100644 --- a/diffsynth/configs/model_configs.py +++ b/diffsynth/configs/model_configs.py @@ -364,6 +364,55 @@ flux_series = [ "model_name": "flux_lora_patcher", "model_class": "diffsynth.models.flux_lora_patcher.FluxLoraPatcher", }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="model*.safetensors") + "model_hash": "2bd19e845116e4f875a0a048e27fc219", + "model_name": "nexus_gen_llm", + "model_class": "diffsynth.models.nexus_gen.NexusGenAutoregressiveModel", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen.NexusGenAutoregressiveModelStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "nexus_gen_editing_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenImageEmbeddingMerger", + "state_dict_converter": 
"diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenMergerStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="edit_decoder.bin") + "model_hash": "63c969fd37cce769a90aa781fbff5f81", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "nexus_gen_generation_adapter", + "model_class": "diffsynth.models.nexus_gen_projector.NexusGenAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.nexus_gen_projector.NexusGenAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="generation_decoder.bin") + "model_hash": "3e6c61b0f9471135fc9c6d6a98e98b6d", + "model_name": "flux_dit", + "model_class": "diffsynth.models.flux_dit.FluxDiT", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter", + }, + { + # Example: ModelConfig(model_id="InstantX/FLUX.1-dev-IP-Adapter", origin_file_pattern="ip-adapter.bin") + "model_hash": "4daaa66cc656a8fe369908693dad0a35", + "model_name": "flux_ipadapter", + "model_class": "diffsynth.models.flux_ipadapter.FluxIpAdapter", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.FluxIpAdapterStateDictConverter", + }, + { + # Example: ModelConfig(model_id="google/siglip-so400m-patch14-384", origin_file_pattern="model.safetensors") + "model_hash": "04d8c1e20a1f1b25f7434f111992a33f", + "model_name": "siglip_vision_model", + "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M", + "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter", + }, ] 
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series diff --git a/diffsynth/models/flux_ipadapter.py b/diffsynth/models/flux_ipadapter.py index 798c45a..31176fc 100644 --- a/diffsynth/models/flux_ipadapter.py +++ b/diffsynth/models/flux_ipadapter.py @@ -5,34 +5,21 @@ import torch class SiglipVisionModelSO400M(SiglipVisionModel): def __init__(self): - config = SiglipVisionConfig(**{ - "architectures": [ - "SiglipModel" - ], - "initializer_factor": 1.0, - "model_type": "siglip", - "text_config": { - "hidden_size": 1152, - "intermediate_size": 4304, - "model_type": "siglip_text_model", - "num_attention_heads": 16, - "num_hidden_layers": 27 - }, - "torch_dtype": "float32", - "transformers_version": "4.37.0.dev0", - "vision_config": { - "hidden_size": 1152, - "image_size": 384, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "patch_size": 14 - } - }) + config = SiglipVisionConfig( + hidden_size=1152, + image_size=384, + intermediate_size=4304, + model_type="siglip_vision_model", + num_attention_heads=16, + num_hidden_layers=27, + patch_size=14, + architectures=["SiglipModel"], + initializer_factor=1.0, + torch_dtype="float32", + transformers_version="4.37.0.dev0" + ) super().__init__(config) - class MLPProjModel(torch.nn.Module): def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4): super().__init__() diff --git a/diffsynth/models/flux_lora_patcher.py b/diffsynth/models/flux_lora_patcher.py index a249feb..5d8fc8c 100644 --- a/diffsynth/models/flux_lora_patcher.py +++ b/diffsynth/models/flux_lora_patcher.py @@ -1,4 +1,251 @@ -import torch +import torch, math +from ..core.loader import load_state_dict +from typing import Union + +class GeneralLoRALoader: + def __init__(self, device="cpu", torch_dtype=torch.float32): + self.device = device + self.torch_dtype = torch_dtype + + + def get_name_dict(self, lora_state_dict): + lora_name_dict = {} + for key in 
lora_state_dict: + if ".lora_B." not in key: + continue + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + if keys[0] == "diffusion_model": + keys.pop(0) + keys.pop(-1) + target_name = ".".join(keys) + lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A.")) + return lora_name_dict + + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + updated_num = 0 + lora_name_dict = self.get_name_dict(state_dict_lora) + for name, module in model.named_modules(): + if name in lora_name_dict: + weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype) + weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype) + if len(weight_up.shape) == 4: + weight_up = weight_up.squeeze(3).squeeze(2) + weight_down = weight_down.squeeze(3).squeeze(2) + weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) + else: + weight_lora = alpha * torch.mm(weight_up, weight_down) + state_dict = module.state_dict() + state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora + module.load_state_dict(state_dict) + updated_num += 1 + print(f"{updated_num} tensors are updated by LoRA.") + +class FluxLoRALoader(GeneralLoRALoader): + def __init__(self, device="cpu", torch_dtype=torch.float32): + super().__init__(device=device, torch_dtype=torch_dtype) + + self.diffusers_rename_dict = { + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight", + 
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight", + "transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight", + 
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight", + "transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight", + 
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight", + "transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight", + "transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight", + } + + self.civitai_rename_dict = { + "lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight", + 
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight", + "lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight", + "lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight", + "lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight", + } + + def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0): + super().load(model, state_dict_lora, alpha) + + + def 
convert_state_dict(self,state_dict): + + def guess_block_id(name,model_resource): + if model_resource == 'civitai': + names = name.split("_") + for i in names: + if i.isdigit(): + return i, name.replace(f"_{i}_", "_blockid_") + if model_resource == 'diffusers': + names = name.split(".") + for i in names: + if i.isdigit(): + return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.") + return None, None + + def guess_resource(state_dict): + for k in state_dict: + if "lora_unet_" in k: + return 'civitai' + elif k.startswith("transformer."): + return 'diffusers' + else: + None + + model_resource = guess_resource(state_dict) + if model_resource is None: + return state_dict + + rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict + def guess_alpha(state_dict): + for name, param in state_dict.items(): + if ".alpha" in name: + for suffix in [".lora_down.weight", ".lora_A.weight"]: + name_ = name.replace(".alpha", suffix) + if name_ in state_dict: + lora_alpha = param.item() / state_dict[name_].shape[0] + lora_alpha = math.sqrt(lora_alpha) + return lora_alpha + + return 1 + + alpha = guess_alpha(state_dict) + + state_dict_ = {} + for name, param in state_dict.items(): + block_id, source_name = guess_block_id(name,model_resource) + if alpha != 1: + param *= alpha + if source_name in rename_dict: + target_name = rename_dict[source_name] + target_name = target_name.replace(".blockid.", f".{block_id}.") + state_dict_[target_name] = param + else: + state_dict_[name] = param + + if model_resource == 'diffusers': + for name in list(state_dict_.keys()): + if "single_blocks." in name and ".a_to_q." 
in name: + mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None) + if mlp is None: + dim = 4 + if 'lora_A' in name: + dim = 1 + mlp = torch.zeros(dim * state_dict_[name].shape[0], + *state_dict_[name].shape[1:], + dtype=state_dict_[name].dtype) + else: + state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn.")) + if 'lora_A' in name: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + elif 'lora_B' in name: + d, r = state_dict_[name].shape + param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device) + param[:d, :r] = state_dict_.pop(name) + param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")) + param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")) + param[3*d:, 3*r:] = mlp + else: + param = torch.concat([ + state_dict_.pop(name), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")), + state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")), + mlp, + ], dim=0) + name_ = name.replace(".a_to_q.", ".to_qkv_mlp.") + state_dict_[name_] = param + for name in list(state_dict_.keys()): + for component in ["a", "b"]: + if f".{component}_to_q." 
in name: + name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.") + concat_dim = 0 + if 'lora_A' in name: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + elif 'lora_B' in name: + origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + d, r = origin.shape + # print(d, r) + param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device) + param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")] + param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")] + param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")] + else: + param = torch.concat([ + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")], + state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")], + ], dim=0) + state_dict_[name_] = param + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k.")) + state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v.")) + return state_dict_ + class LoraMerger(torch.nn.Module): def __init__(self, dim): @@ -57,4 +304,3 @@ class FluxLoraPatcher(torch.nn.Module): def forward(self, base_output, lora_outputs, name): return self.model_dict[name.replace(".", "___")](base_output, lora_outputs) - diff --git a/diffsynth/models/nexus_gen.py b/diffsynth/models/nexus_gen.py new file mode 100644 index 0000000..0110398 --- /dev/null +++ b/diffsynth/models/nexus_gen.py @@ -0,0 +1,161 @@ +import torch +from PIL import Image + + +class NexusGenAutoregressiveModel(torch.nn.Module): + def __init__(self, 
max_length=1024, max_pixels=262640): + super(NexusGenAutoregressiveModel, self).__init__() + from .nexus_gen_ar_model import Qwen2_5_VLForConditionalGeneration + from transformers import Qwen2_5_VLConfig + self.max_length = max_length + self.max_pixels = max_pixels + model_config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.model = Qwen2_5_VLForConditionalGeneration(model_config) + self.processor = None + + + def load_processor(self, path): + from .nexus_gen_ar_model import Qwen2_5_VLProcessor + self.processor = Qwen2_5_VLProcessor.from_pretrained(path) + + + @staticmethod + def 
state_dict_converter(): + return NexusGenAutoregressiveModelStateDictConverter() + + def bound_image(self, image, max_pixels=262640): + from qwen_vl_utils import smart_resize + resized_height, resized_width = smart_resize( + image.height, + image.width, + max_pixels=max_pixels, + ) + return image.resize((resized_width, resized_height)) + + def get_editing_msg(self, instruction): + if '' not in instruction: + instruction = ' ' + instruction + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is the image: "}] + return messages + + def get_generation_msg(self, instruction): + instruction = "Generate an image according to the following description: {}".format(instruction) + messages = [{"role":"user", "content":instruction}, {"role":"assistant", "content":"Here is an image based on the description: "}] + return messages + + def forward(self, instruction, ref_image=None, num_img_tokens=81): + """ + Generate target embeddings for the given instruction and reference image. 
+ """ + if ref_image is not None: + messages = self.get_editing_msg(instruction) + images = [self.bound_image(ref_image)] + [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + else: + messages = self.get_generation_msg(instruction) + images = [Image.new(mode='RGB', size=(252, 252), color=(255, 255, 255))] + output_image_embeddings = self.get_target_embeddings(images, messages, self.processor, self.model, num_img_tokens) + + return output_image_embeddings + + def get_target_embeddings(self, images, messages, processor, model, num_img_tokens=81): + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False) + text = text.replace('', '<|vision_start|><|image_pad|><|vision_end|>') + inputs = processor( + text=[text], + images=images, + padding=True, + return_tensors="pt", + ) + inputs = inputs.to(model.device) + + input_embeds = model.model.embed_tokens(inputs['input_ids']) + image_embeds = model.visual(inputs['pixel_values'], grid_thw=inputs['image_grid_thw']) + ground_truth_image_embeds = image_embeds[-num_img_tokens:] + input_image_embeds = image_embeds[:-num_img_tokens] + + image_mask = inputs['input_ids'] == model.config.image_token_id + indices = image_mask.cumsum(dim=1) + input_image_mask = torch.logical_and(indices <= (image_embeds.shape[0] - ground_truth_image_embeds.shape[0]), image_mask) + gt_image_mask = torch.logical_and(image_mask, ~input_image_mask) + input_image_mask = input_image_mask.unsqueeze(-1).expand_as(input_embeds) + input_embeds = input_embeds.masked_scatter(input_image_mask, input_image_embeds) + + image_prefill_embeds = model.image_prefill_embeds( + torch.arange(81, device=model.device).long() + ) + input_embeds = input_embeds.masked_scatter(gt_image_mask.unsqueeze(-1).expand_as(input_embeds), image_prefill_embeds) + + position_ids, _ = model.get_rope_index( + inputs['input_ids'], + 
inputs['image_grid_thw'], + attention_mask=inputs['attention_mask']) + position_ids = position_ids.contiguous() + outputs = model(inputs_embeds=input_embeds, position_ids=position_ids, attention_mask=inputs['attention_mask'], return_dict=True) + output_image_embeddings = outputs.image_embeddings[:, :-1, :] + output_image_embeddings = output_image_embeddings[gt_image_mask[:, 1:]] + return output_image_embeddings, input_image_embeds, inputs['image_grid_thw'] + + +class NexusGenAutoregressiveModelStateDictConverter: + def __init__(self): + pass + + def from_civitai(self, state_dict): + state_dict = {"model." + key: value for key, value in state_dict.items()} + return state_dict diff --git a/diffsynth/models/nexus_gen_ar_model.py b/diffsynth/models/nexus_gen_ar_model.py new file mode 100644 index 0000000..d5a2973 --- /dev/null +++ b/diffsynth/models/nexus_gen_ar_model.py @@ -0,0 +1,1143 @@ +import os +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss + +from transformers.cache_utils import Cache +from transformers.generation import GenerationMixin, LogitsProcessorList, StoppingCriteriaList, GenerationConfig, GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput +from transformers.utils import add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from transformers.modeling_outputs import ModelOutput +from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VisionTransformerPretrainedModel, + Qwen2_5_VLModel, + Qwen2_5_VLPreTrainedModel, + QWEN2_5_VL_INPUTS_DOCSTRING, + ) + +from transformers.feature_extraction_utils import BatchFeature +from transformers.image_utils import ImageInput, VideoInput +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, VideosKwargs +from 
transformers.tokenization_utils_base import PreTokenizedInput, TextInput + +GenerateNonBeamOutput = Union[GenerateDecoderOnlyOutput, GenerateEncoderDecoderOutput] + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "Qwen2_5_VLConfig" + + +@dataclass +class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput): + """ + Base class for Qwen2_5_VL causal language model (or autoregressive) outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) + + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*): + The rope index difference between sequence length and multimodal rope. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + image_embeddings: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + rope_deltas: Optional[torch.LongTensor] = None + + +class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + config_class = Qwen2_5_VLConfig + _no_split_modules = ["Qwen2_5_VLDecoderLayer", "Qwen2_5_VLVisionBlock"] + + def __init__(self, config): + super().__init__(config) + self.visual = Qwen2_5_VisionTransformerPretrainedModel._from_config(config.vision_config) + self.model = Qwen2_5_VLModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.vision_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.rope_deltas = None # cache rope_deltas here + self.image_prefill_embeds = nn.Embedding(81, config.hidden_size) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def get_rope_index( + self, + input_ids: Optional[torch.LongTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: 
Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculate the 3D rope index based on image and video's temporal, height and width in LLM. + + Explanation: + Each embedding sequence contains vision embedding and text embedding or just contains text embedding. + + For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs. + Examples: + input_ids: [T T T T T], here T is for text. + temporal position_ids: [0, 1, 2, 3, 4] + height position_ids: [0, 1, 2, 3, 4] + width position_ids: [0, 1, 2, 3, 4] + + For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part + and 1D rotary position embedding for text part. + Examples: + Temporal (Time): 3 patches, representing different segments of the video in time. + Height: 2 patches, dividing each frame vertically. + Width: 2 patches, dividing each frame horizontally. + We also have some important parameters: + fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second. + tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity. + temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames. + interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will be have a difference of 50 in the temporal position IDs. + input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision. 
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100] + vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1] + vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] + text temporal position_ids: [101, 102, 103, 104, 105] + text height position_ids: [101, 102, 103, 104, 105] + text width position_ids: [101, 102, 103, 104, 105] + Here we calculate the text start position_ids as the max vision position_ids plus 1. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*): + The temporal, height and width of feature shape of each image in LLM. + video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*): + The temporal, height and width of feature shape of each video in LLM. + second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*): + The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ + Returns: + position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`) + mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`) + """ + spatial_merge_size = self.config.vision_config.spatial_merge_size + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + mrope_position_deltas = [] + if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + second_per_grid_t = 0 + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + 
video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + if second_per_grid_ts is not None: + second_per_grid_t = second_per_grid_ts[video_index] + else: + second_per_grid_t = 1.0 + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + + time_tensor = expanded_range * second_per_grid_t * self.config.vision_config.tokens_per_second + + time_tensor_long = time_tensor.long() + t_index = time_tensor_long.flatten() + + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device) + mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) + mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = 
position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device) + max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0] + mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + @add_start_docstrings_to_model_forward(QWEN2_5_VL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Qwen2_5_VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, + image_embeddings: Optional[torch.Tensor] = None, + token_loss_weight: Optional[float] = 0.1, + img_loss_weight: Optional[float] = 1.0, + ) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + # test feature + inputs_embeds = self.model.embed_tokens(input_ids) + # for image encoding and training + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, 
inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + # position_ids [3, B, L] + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + image_embeds = self.vision_head(hidden_states) + + loss = None + if labels is not None: + # Upcast to float if we need to compute 
the loss to avoid potential precision issues + # prepare labels for logits + logits_labels = labels.clone().detach() + image_tokens = (labels == self.config.image_token_id) + logits_labels[image_tokens] = -100 + + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = logits_labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) * token_loss_weight + + shift_image_tokens_2d = (labels[..., 1:].contiguous() == self.config.image_token_id) # (B, L-1) + shifted_image_embeds = image_embeds[:, :-1, :].contiguous() # (B, L-1, D) + masked_image_embeds = shifted_image_embeds[shift_image_tokens_2d] # (num_image_tokens, D) + + mse_loss_fct = nn.MSELoss() + mse_loss_fct = mse_loss_fct.to(shift_logits.device) + if image_embeddings is None: + image_embeddings = torch.zeros_like(masked_image_embeds) + img_loss = mse_loss_fct(masked_image_embeds, image_embeddings) + + cos_sim = torch.cosine_similarity( + masked_image_embeds, + image_embeddings, + dim=-1 + ) + cos_loss = (1 - cos_sim).mean() + img_loss = 0.5 * img_loss + 0.5 * cos_loss + # fix nan for empty image tokens + if image_embeddings.size(0) == 0: + img_loss = img_loss.nan_to_num(0.0) + # combine the loss + loss = loss + img_loss_weight * img_loss + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + image_embeddings=image_embeds, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + + + def _sample( + self, + input_ids: torch.LongTensor, + logits_processor: 
LogitsProcessorList, + stopping_criteria: StoppingCriteriaList, + generation_config: GenerationConfig, + synced_gpus: bool, + streamer: Optional["BaseStreamer"], + **model_kwargs, + ) -> Union[GenerateNonBeamOutput, torch.LongTensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. + + Parameters: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + The sequence used as a prompt for the generation. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + stopping_criteria (`StoppingCriteriaList`): + An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`] + used to tell if the generation loop should stop. + generation_config ([`~generation.GenerationConfig`]): + The generation configuration to be used as parametrization of the decoding method. + synced_gpus (`bool`): + Whether to continue running the while loop until max_length (needed to avoid deadlocking with + `FullyShardedDataParallel` and DeepSpeed ZeRO Stage 3). + streamer (`BaseStreamer`, *optional*): + Streamer object that will be used to stream the generated sequences. Generated tokens are passed + through `streamer.put(token_ids)` and the streamer is responsible for any further processing. + model_kwargs: + Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is + an encoder-decoder model the kwargs should include `encoder_outputs`. 
+ + Return: + [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`: + A `torch.LongTensor` containing the generated tokens (default behaviour) or a + [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and + `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if + `model.config.is_encoder_decoder=True`. + """ + # init values + pad_token_id = generation_config._pad_token_tensor + output_attentions = generation_config.output_attentions + output_hidden_states = generation_config.output_hidden_states + output_scores = generation_config.output_scores + output_logits = generation_config.output_logits + return_dict_in_generate = generation_config.return_dict_in_generate + max_length = generation_config.max_length + has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria) + do_sample = generation_config.do_sample + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + raw_logits = () if (return_dict_in_generate and output_logits) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + batch_size, cur_len = input_ids.shape + this_peer_finished = False + unfinished_sequences = torch.ones(batch_size, 
dtype=torch.long, device=input_ids.device) + model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) + + model_forward = self.__call__ + if isinstance(model_kwargs.get("past_key_values"), Cache): + is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache + is_compileable = is_compileable and not self.generation_config.disable_compile + if is_compileable and ( + self.device.type == "cuda" or generation_config.compile_config._compile_all_devices + ): + os.environ["TOKENIZERS_PARALLELISM"] = "0" + model_forward = self.get_compiled_call(generation_config.compile_config) + + is_prefill = True + is_sampling_img = input_ids[:, -1] == self.config.vision_start_token_id + generation_image_grid_thw = model_kwargs.pop("generation_image_grid_thw", self.get_default_image_grid_thw()) + num_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) + output_image_embeddings = [] + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device, cur_len=cur_len, max_length=max_length + ): + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + + # prepare prefilled embeds + model_inputs.update(self.prepare_prefilled_image_embeds(len(output_image_embeddings), num_img_tokens, is_sampling_img, **model_kwargs)) + + # parse position_ids from model_kwargs + model_inputs.update(self.prepare_image_position_ids(input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs)) + + # prepare variable output controls (note: some models won't accept all output controls) + model_inputs.update({"output_attentions": output_attentions} if output_attentions else {}) + model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {}) + + if is_prefill: + outputs = self(**model_inputs, return_dict=True) + is_prefill = False + else: + outputs = model_forward(**model_inputs, return_dict=True) + + # synced_gpus: don't 
waste resources running the code we don't need; kwargs must be updated before skipping + model_kwargs = self._update_model_kwargs_for_generation( + outputs, + model_kwargs, + is_encoder_decoder=self.config.is_encoder_decoder, + ) + # TODO: support batch image sampling + if bool(is_sampling_img) and len(output_image_embeddings) < num_img_tokens: + output_image_embeddings.append(outputs.image_embeddings[:, -1, :].unsqueeze(1)) + + if synced_gpus and this_peer_finished: + continue + # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration + # (the clone itself is always small) + next_token_logits = outputs.logits[:, -1, :].clone().float() + next_token_logits = next_token_logits.to(input_ids.device) + + # do not sample token + next_token_logits[:, self.config.vision_end_token_id] = -float('inf') + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_logits: + raw_logits += (next_token_logits,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # token selection + if do_sample: + probs = nn.functional.softmax(next_token_scores, dim=-1) + # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + # while not bool(is_sampling_img) and torch.any(next_tokens == self.config.vision_end_token_id): + # probs[:, self.config.vision_end_token_id] = 0 + # next_tokens = 
torch.multinomial(probs, num_samples=1).squeeze(1) + else: + next_tokens = torch.argmax(next_token_scores, dim=-1) + + # finished sentences should have their next token be a padding token + if has_eos_stopping_criteria: + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + + #TODO: support batch image sample + if num_img_tokens is not None: + cur_img_tokens = (input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1) + # check whether is sampling images + is_end_img = torch.logical_and(cur_img_tokens == num_img_tokens, is_sampling_img) + is_sampling_img = torch.logical_and(is_sampling_img, cur_img_tokens < num_img_tokens) + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to end sampling images + next_tokens[is_end_img] = self.config.vision_end_token_id + else: + # check whether to end sampling images + is_sampling_img = torch.logical_and(is_sampling_img, (next_tokens != self.config.vision_end_token_id)) + # replace the next token with the image token if is sampling image + next_tokens[is_sampling_img] = self.config.image_token_id + # check whether to start sampling images + is_sampling_img = torch.logical_or(is_sampling_img, (next_tokens == self.config.vision_start_token_id)) + + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + + if streamer is not None: + streamer.put(next_tokens.cpu()) + + unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores) + this_peer_finished = unfinished_sequences.max() == 0 + cur_len += 1 + + # This is needed to properly delete outputs.logits which may be very large for first iteration + # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration + del outputs + + if streamer is not None: + streamer.end() + + # output the image embeddings + output_image_embeddings = torch.cat(output_image_embeddings, 
dim=1) if len(output_image_embeddings) > 0 else None + + if return_dict_in_generate: + return GenerateDecoderOnlyAll2AllOutput( + sequences=input_ids, + scores=scores, + logits=raw_logits, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + past_key_values=model_kwargs.get("past_key_values"), + output_image_embeddings=output_image_embeddings, + ) + else: + return input_ids + + + def prepare_prefilled_image_embeds(self, cur_image_tokens, num_img_tokens, is_sampling_img, **model_kwargs): + if cur_image_tokens == 0 or cur_image_tokens > num_img_tokens or not bool(is_sampling_img): + return {} + # TODO: support batch image sample + image_idx = torch.tensor([cur_image_tokens-1]).to(self.device).long().unsqueeze(0) + inputs_embeds = self.image_prefill_embeds(image_idx) + return {"inputs_embeds": inputs_embeds} + + + def get_default_image_grid_thw(self,): + return torch.tensor([[1, 18, 18]]).to(self.device) + + + def get_num_image_tokens(self, image_grid_thw): + return int(torch.prod(image_grid_thw, dim=1).sum() // 4) + + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + num_img_tokens = model_kwargs.pop("generation_image_grid_thw", None) + super()._validate_model_kwargs(model_kwargs) + model_kwargs["generation_image_grid_thw"] = num_img_tokens + + def prepare_image_position_ids(self, input_ids, generation_image_grid_thw, is_sampling_img, **model_kwargs): + # Overwritten -- prepare position_ids for image tokens + cur_img_tokens = int((input_ids == self.config.vision_start_token_id).flip(dims=[1]).float().argmax(dim=1)) + # TODO: support batch image sample + if cur_img_tokens > 0 and bool(is_sampling_img): + image_grid_thw = generation_image_grid_thw + if model_kwargs.get('image_grid_thw') is not None: + image_grid_thw = torch.cat([model_kwargs.get('image_grid_thw'), image_grid_thw]) + remaining_img_tokens = self.get_num_image_tokens(generation_image_grid_thw) - cur_img_tokens + padding_ids = input_ids.new_full((1, 
remaining_img_tokens), fill_value=self.config.image_token_id) + padded_ids = torch.cat([input_ids, padding_ids], dim=1) + position_ids, _ = self.get_rope_index(padded_ids, image_grid_thw, None, None) + if model_kwargs.get("use_cache", True): + position_ids = position_ids[:, :, input_ids.shape[1] - 1].unsqueeze(-1) + else: + position_ids = position_ids[:, :, :input_ids.shape[1]] + return {"position_ids": position_ids} + return {} + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + image_embeddings=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + + model_inputs = super().prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + position_ids=position_ids, + pixel_values=pixel_values, + pixel_values_videos=pixel_values_videos, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + second_per_grid_ts=second_per_grid_ts, + use_cache=use_cache, + **kwargs, + ) + + # Qwen2-5-VL position_ids are prepared with rope_deltas in forward + model_inputs["position_ids"] = None + + if cache_position[0] != 0: + model_inputs["pixel_values"] = None + model_inputs["pixel_values_videos"] = None + return model_inputs + + def _get_image_nums_and_video_nums( + self, + input_ids: Optional[torch.LongTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get the number of images and videos for each sample to calculate the separation length of the sample tensor. + These parameters are not passed through the processor to avoid unpredictable impacts from interface modifications. 
+ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. + + Returns: + image_nums (`torch.LongTensor` of shape `(batch_size, num_images_sample)`) + video_nums (`torch.LongTensor` of shape `(batch_size, num_videos_sample)`) + """ + image_token_id = self.config.image_token_id + video_token_id = self.config.video_token_id + vision_start_token_id = self.config.vision_start_token_id + + vision_start_mask = input_ids == vision_start_token_id + vision_first_mask = torch.roll(vision_start_mask, shifts=1, dims=1) + image_mask = input_ids == image_token_id + video_mask = input_ids == video_token_id + image_nums = torch.sum(vision_first_mask & image_mask, dim=1) + video_nums = torch.sum(vision_first_mask & video_mask, dim=1) + + return image_nums, video_nums + + def _expand_inputs_for_generation( + self, + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[torch.LongTensor] = None, + **model_kwargs, + ) -> Tuple[torch.LongTensor, Dict[str, Any]]: + # Overwritten -- Support for expanding tensors without a batch size dimension + # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t + # pixel_values.shape[0] is sum(seqlen_images for samples) + # image_grid_thw.shape[0] is sum(num_images for samples) + + if expand_size == 1: + return input_ids, model_kwargs + + visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"] + + def _expand_dict_for_generation_visual(dict_to_expand): + image_grid_thw = model_kwargs.get("image_grid_thw", None) + video_grid_thw = model_kwargs.get("video_grid_thw", None) + image_nums, video_nums = self._get_image_nums_and_video_nums(input_ids) + + def _repeat_interleave_samples(x, lengths, repeat_times): + samples = torch.split(x, lengths) + repeat_args = [repeat_times] + [1] * (x.dim() - 1) + result = torch.cat([sample.repeat(*repeat_args) for sample in 
samples], dim=0) + return result + + for key in dict_to_expand: + if key == "pixel_values": + # split images into samples + samples = torch.split(image_grid_thw, list(image_nums)) + # compute the sequence length of images for each sample + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "image_grid_thw": + # get the num of images for each sample + lengths = list(image_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "pixel_values_videos": + samples = torch.split(video_grid_thw, list(video_nums)) + lengths = [torch.prod(sample, dim=1).sum() for sample in samples] + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "video_grid_thw": + lengths = list(video_nums) + dict_to_expand[key] = _repeat_interleave_samples( + dict_to_expand[key], lengths=lengths, repeat_times=expand_size + ) + elif key == "second_per_grid_ts": + if not isinstance(dict_to_expand[key], list): + raise TypeError( + f"Expected value for key '{key}' to be a list, but got {type(dict_to_expand[key])} instead." 
+ ) + tensor = torch.tensor(dict_to_expand[key]) + lengths = list(video_nums) + tensor = _repeat_interleave_samples(tensor, lengths=lengths, repeat_times=expand_size) + dict_to_expand[key] = tensor.tolist() + return dict_to_expand + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if ( + key != "cache_position" + and dict_to_expand[key] is not None + and isinstance(dict_to_expand[key], torch.Tensor) + and key not in visual_keys + ): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + # input_ids is required for expanding visual inputs + # If input_ids is unavailable, visual inputs will not be used; therefore, there is no need to expand visual inputs. + if input_ids is not None and input_ids.numel() != 0: + model_kwargs = _expand_dict_for_generation_visual(model_kwargs) + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + if is_encoder_decoder: + if model_kwargs.get("encoder_outputs") is None: + raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") + model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"]) + + return input_ids, model_kwargs + + +__all__ = ["Qwen2_5_VLForConditionalGeneration", "Qwen2_5_VLModel", "Qwen2_5_VLPreTrainedModel"] + + + +class Qwen2_5_VLVideosProcessorKwargs(VideosKwargs, total=False): + fps: Union[List[float], float] + + +class Qwen2_5_VLProcessorKwargs(ProcessingKwargs, total=False): + videos_kwargs: Qwen2_5_VLVideosProcessorKwargs + _defaults = { + "text_kwargs": { + "padding": False, + }, + "videos_kwargs": {"fps": 2.0}, + } + + +class Qwen2_5_VLProcessor(ProcessorMixin): + r""" + Constructs a Qwen2.5-VL processor which wraps a Qwen2.5-VL image processor and a Qwen2 tokenizer into a single processor. 
+ [`Qwen2_5_VLProcessor`] offers all the functionalities of [`Qwen2VLImageProcessor`] and [`Qwen2TokenizerFast`]. See the + [`~Qwen2_5_VLProcessor.__call__`] and [`~Qwen2_5_VLProcessor.decode`] for more information. + Args: + image_processor ([`Qwen2VLImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`Qwen2TokenizerFast`], *optional*): + The tokenizer is a required input. + chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages + in a chat into a tokenizable string. + """ + + attributes = ["image_processor", "tokenizer"] + valid_kwargs = ["chat_template"] + + image_processor_class = "AutoImageProcessor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs): + self.image_token = "<|image_pad|>" if not hasattr(tokenizer, "image_token") else tokenizer.image_token + self.video_token = "<|video_pad|>" if not hasattr(tokenizer, "video_token") else tokenizer.video_token + super().__init__(image_processor, tokenizer, chat_template=chat_template) + + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + videos: VideoInput = None, + **kwargs: Unpack[Qwen2_5_VLProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text` + and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode + the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to + Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`. 
+ + Args: + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + videos (`np.ndarray`, `torch.Tensor`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of videos to be prepared. Each video can be a 4D NumPy array or PyTorch + tensor, or a nested list of 3D frames. Both channels-first and channels-last formats are supported. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when + `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not + `None`). + - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`. + - **pixel_values_videos** -- Pixel values of videos to be fed to a model. Returned when `videos` is not `None`. + - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`. 
+ - **video_grid_thw** -- List of video 3D grid in LLM. Returned when `videos` is not `None`. + - **second_per_grid_ts** -- List of video seconds per time grid. Returned when `videos` is not `None`. + """ + output_kwargs = self._merge_kwargs( + Qwen2_5_VLProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + if images is not None: + image_inputs = self.image_processor(images=images, videos=None, **output_kwargs["images_kwargs"]) + image_grid_thw = image_inputs["image_grid_thw"] + else: + image_inputs = {} + image_grid_thw = None + + if videos is not None: + videos_inputs = self.image_processor(images=None, videos=videos, **output_kwargs["images_kwargs"]) + video_grid_thw = videos_inputs["video_grid_thw"] + + fps = output_kwargs["videos_kwargs"].pop("fps", 2.0) + if isinstance(fps, (int, float)): + second_per_grid_ts = [self.image_processor.temporal_patch_size / fps] * len(video_grid_thw) + elif hasattr(fps, "__len__") and len(fps) == len(video_grid_thw): + second_per_grid_ts = [self.image_processor.temporal_patch_size / tmp for tmp in fps] + else: + raise ValueError( + f"The length of fps ({len(fps) if hasattr(fps, '__len__') else fps}) must be equal to the length of video_grid_thw ({len(video_grid_thw)}) or fps should be a single number." 
+ ) + videos_inputs.update({"second_per_grid_ts": second_per_grid_ts}) + + else: + videos_inputs = {} + video_grid_thw = None + + if not isinstance(text, list): + text = [text] + + if image_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.image_token in text[i]: + text[i] = text[i].replace( + self.image_token, + "<|placeholder|>" * (image_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.image_token) + + if video_grid_thw is not None: + merge_length = self.image_processor.merge_size**2 + index = 0 + for i in range(len(text)): + while self.video_token in text[i]: + text[i] = text[i].replace( + self.video_token, + "<|placeholder|>" * (video_grid_thw[index].prod() // merge_length), + 1, + ) + index += 1 + text[i] = text[i].replace("<|placeholder|>", self.video_token) + + text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"]) + + return BatchFeature(data={**text_inputs, **image_inputs, **videos_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def batch_decode_all2all(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. 
+ """ + decoded = self.tokenizer.batch_decode(*args, **kwargs) + pattern = r'<\|vision_start\|>.*?<\|vision_end\|>' + decoded_with_image_tag = [re.sub(pattern, '', d, flags=re.DOTALL) for d in decoded] + decoded_with_image_tag = [re.sub(r'<\|im_end\|>', '', d) for d in decoded_with_image_tag] + return decoded_with_image_tag + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. + """ + return self.tokenizer.decode(*args, **kwargs) + + def post_process_image_text_to_text( + self, generated_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False, **kwargs + ): + """ + Post-process the output of the model to decode the text. + + Args: + generated_outputs (`torch.Tensor` or `np.ndarray`): + The output of the model `generate` function. The output is expected to be a tensor of shape `(batch_size, sequence_length)` + or `(sequence_length,)`. + skip_special_tokens (`bool`, *optional*, defaults to `True`): + Whether or not to remove special tokens in the output. Argument passed to the tokenizer's `batch_decode` method. + Clean_up_tokenization_spaces (`bool`, *optional*, defaults to `False`): + Whether or not to clean up the tokenization spaces. Argument passed to the tokenizer's `batch_decode` method. + **kwargs: + Additional arguments to be passed to the tokenizer's `batch_decode method`. + + Returns: + `List[str]`: The decoded text. 
+ """ + return self.tokenizer.batch_decode( + generated_outputs, + skip_special_tokens=skip_special_tokens, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + @property + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + names_from_processor = list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + return names_from_processor + ["second_per_grid_ts"] + + +__all__ = ["Qwen2_5_VLProcessor"] diff --git a/diffsynth/models/nexus_gen_projector.py b/diffsynth/models/nexus_gen_projector.py new file mode 100644 index 0000000..d69b3e1 --- /dev/null +++ b/diffsynth/models/nexus_gen_projector.py @@ -0,0 +1,417 @@ +import math +import torch +import torch.nn as nn +from typing import Optional, Tuple + + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1): + mrope_section = mrope_section * 2 + cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze( + unsqueeze_dim + ) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class Qwen2_5_VLRotaryEmbedding(nn.Module): + def __init__(self, config, device=None): + super().__init__() + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = 
config.max_position_embeddings + + self.config = config + from transformers.modeling_rope_utils import _compute_default_rope_parameters + self.rope_init_fn = _compute_default_rope_parameters + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + + def _dynamic_frequency_update(self, position_ids, device): + """ + dynamic RoPE layers should recompute `inv_freq` in the following situations: + 1 - growing beyond the cached sequence length (allow scaling) + 2 - the current sequence length is in the original scale (avoid losing precision with small sequences) + """ + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_seq_len_cached: # growth + inv_freq, self.attention_scaling = self.rope_init_fn( + self.config, device, seq_len=seq_len, **self.rope_kwargs + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation + self.max_seq_len_cached = seq_len + + if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset + self.register_buffer("inv_freq", self.original_inv_freq, persistent=False) + self.max_seq_len_cached = self.original_max_seq_len + + + @torch.no_grad() + def forward(self, x, position_ids): + if "dynamic" in self.rope_type: + self._dynamic_frequency_update(position_ids, device=x.device) + + # Core RoPE block. In contrast to other models, Qwen2_5_VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) 
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep): expand
    (batch, num_key_value_heads, seqlen, head_dim) into
    (batch, num_key_value_heads * n_rep, seqlen, head_dim) so grouped KV
    heads line up with the attention heads.
    """
    if n_rep == 1:
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
+ ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + # Fix precision issues in Qwen2-VL float16 inference + # Replace inf values with zeros in attention weights to prevent NaN propagation + if query_states.dtype == torch.float16: + attn_weights = torch.where(torch.isinf(attn_weights), torch.zeros_like(attn_weights), attn_weights) + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, 
class Qwen2RMSNorm(nn.Module):
    """
    RMS normalization (equivalent to T5LayerNorm): scale each vector by the
    inverse root-mean-square of its last dimension (computed in float32),
    then apply a learned per-channel weight.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        # Upcast so the mean/rsqrt are numerically stable in fp16/bf16.
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
self.mlp = Qwen2MLP(config) + self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states = self.self_attn( + hidden_states=hidden_states, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class NexusGenImageEmbeddingMerger(nn.Module): + def __init__(self, num_layers=1, out_channel=4096, expand_ratio=4, device='cpu'): + super().__init__() + from transformers import Qwen2_5_VLConfig + from transformers.activations import ACT2FN + config = Qwen2_5_VLConfig(**{ + "_name_or_path": "DiffSynth-Studio/Nexus-GenV2", + "architectures": [ + "Qwen2_5_VLForConditionalGeneration" + ], + "attention_dropout": 0.0, + "auto_map": { + "AutoConfig": "configuration_qwen2_5_vl.Qwen2_5_VLConfig", + "AutoModel": "modeling_qwen2_5_vl.Qwen2_5_VLModel", + "AutoModelForCausalLM": "modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration" + }, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 3584, + "image_token_id": 151655, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 128000, + "max_window_layers": 28, + "model_type": "qwen2_5_vl", + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "pad_token_id": 151643, + "rms_norm_eps": 1e-06, + 
"rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "sliding_window": 32768, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.49.0", + "use_cache": False, + "use_sliding_window": False, + "video_token_id": 151656, + "vision_config": { + "hidden_size": 1280, + "in_chans": 3, + "model_type": "qwen2_5_vl", + "spatial_patch_size": 14, + "tokens_per_second": 2, + "torch_dtype": "bfloat16" + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654, + "vocab_size": 152064 + }) + self.config = config + self.num_layers = num_layers + self.layers = nn.ModuleList([Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]) + self.projector = nn.Sequential(Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps), + nn.Linear(config.hidden_size, out_channel * expand_ratio), + Qwen2RMSNorm(out_channel * expand_ratio, eps=config.rms_norm_eps), + ACT2FN[config.hidden_act], nn.Linear(out_channel * expand_ratio, out_channel), + Qwen2RMSNorm(out_channel, eps=config.rms_norm_eps)) + self.base_grid = torch.tensor([[1, 72, 72]], device=device) + self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config, device=device) + + def get_position_ids(self, image_grid_thw): + """ + Generates position ids for the input embeddings grid. + modified from the qwen2_vl mrope. 
+ """ + batch_size = image_grid_thw.shape[0] + spatial_merge_size = self.config.vision_config.spatial_merge_size + t, h, w = ( + image_grid_thw[0][0], + image_grid_thw[0][1], + image_grid_thw[0][2], + ) + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + scale_h = self.base_grid[0][1].item() / h.item() + scale_w = self.base_grid[0][2].item() / w.item() + + range_tensor = torch.arange(llm_grid_t).view(-1, 1) + expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w) + time_tensor = expanded_range * self.config.vision_config.tokens_per_second + t_index = time_tensor.long().flatten().to(image_grid_thw.device) + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten().to(image_grid_thw.device) * scale_h + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten().to(image_grid_thw.device) * scale_w + # 3, B, L + position_ids = torch.stack([t_index, h_index, w_index]).unsqueeze(0).repeat(batch_size, 1, 1).permute(1, 0, 2) + return position_ids + + def forward(self, embeds, embeds_grid, ref_embeds=None, ref_embeds_grid=None): + position_ids = self.get_position_ids(embeds_grid) + hidden_states = embeds + if ref_embeds is not None: + position_ids_ref_embeds = self.get_position_ids(ref_embeds_grid) + position_ids = torch.cat((position_ids, position_ids_ref_embeds), dim=-1) + hidden_states = torch.cat((embeds, ref_embeds), dim=1) + + position_embeddings = self.rotary_emb(hidden_states, position_ids) + for layer in self.layers: + hidden_states = layer(hidden_states, position_embeddings) + + hidden_states = self.projector(hidden_states) + return hidden_states + + @staticmethod + def state_dict_converter(): + return NexusGenMergerStateDictConverter() + + +class NexusGenMergerStateDictConverter: + def __init__(self): + pass + + def from_diffusers(self, state_dict): + return state_dict + + def from_civitai(self, state_dict): 
class NexusGenAdapter(nn.Module):
    """
    Adapter for the Nexus-Gen generation decoder: a two-layer MLP
    (Linear -> LayerNorm -> ReLU -> Linear -> LayerNorm) mapping
    autoregressive hidden states of `input_dim` into the decoder's
    `output_dim` embedding space.
    """

    def __init__(self, input_dim=3584, output_dim=4096):
        super().__init__()
        self.adapter = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.LayerNorm(output_dim),
            nn.ReLU(),
            nn.Linear(output_dim, output_dim),
            nn.LayerNorm(output_dim),
        )

    def forward(self, x):
        return self.adapter(x)

    @staticmethod
    def state_dict_converter():
        return NexusGenAdapterStateDictConverter()
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    """
    Flatten a (possibly nested) state dict into a canonical comma-joined
    string of its keys, used to fingerprint checkpoints.

    For a tensor value the key is emitted once plainly and, when `with_shape`
    is True, additionally as "key:d0_d1_...". A nested dict is emitted as
    "key|<recursive string>". Non-string keys and values that are neither
    tensors nor dicts are ignored. Entries are sorted, so the result does not
    depend on dict ordering. The exact format (including the duplicate
    plain-key entry) defines published model hashes and must stay stable.
    """
    entries = []
    for key in sorted(state_dict):
        if not isinstance(key, str):
            continue
        value = state_dict[key]
        if isinstance(value, torch.Tensor):
            if with_shape:
                shape = "_".join(str(dim) for dim in value.shape)
                entries.append(f"{key}:{shape}")
            entries.append(key)
        elif isinstance(value, dict):
            entries.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    entries.sort()
    return ",".join(entries)


def hash_state_dict_keys(state_dict, with_shape=True):
    """Return the MD5 hex digest of the state dict's canonical key string."""
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    return hashlib.md5(keys_str.encode("UTF-8")).hexdigest()
def SiglipStateDictConverter(state_dict):
    """
    Keep only the vision tower: filter the state dict down to keys under the
    `vision_model.` prefix, leaving keys and values untouched.
    """
    return {key: value for key, value in state_dict.items() if key.startswith("vision_model.")}
def NexusGenMergerStateDictConverter(state_dict):
    """
    Extract the embedding-merger weights: keep only keys under the
    `embedding_merger.` prefix and strip that prefix from each kept key.
    """
    prefix = "embedding_merger."
    return {
        key.replace(prefix, ""): value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def NexusGenAdapterStateDictConverter(state_dict):
    """
    Keep only the adapter weights: filter to keys under the `adapter.`
    prefix, leaving the keys unchanged.
    """
    return {key: value for key, value in state_dict.items() if key.startswith("adapter.")}