mirror of https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
update
@@ -413,6 +413,21 @@ flux_series = [
        "model_class": "diffsynth.models.flux_ipadapter.SiglipVisionModelSO400M",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_ipadapter.SiglipStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "step1x_connector",
        "model_class": "diffsynth.models.step1x_connector.Qwen2Connector",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.step1x_connector.Qwen2ConnectorStateDictConverter",
    },
    {
        # Example: ModelConfig(model_id="stepfun-ai/Step1X-Edit", origin_file_pattern="step1x-edit-i1258.safetensors"),
        "model_hash": "d30fb9e02b1dbf4e509142f05cf7dd50",
        "model_name": "flux_dit",
        "model_class": "diffsynth.models.flux_dit.FluxDiT",
        "state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
        "extra_kwargs": {"disable_guidance_embedder": True},
    },
]

MODEL_CONFIGS = qwen_image_series + wan_series + flux_series
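For orientation: each entry above is keyed by the MD5 of the checkpoint's sorted key/shape string (see hash_state_dict_keys further down in this commit), and one hash may map to several entries; step1x_connector and flux_dit share "d30fb9e02b1dbf4e509142f05cf7dd50" because both are loaded from the same Step1X-Edit checkpoint. A rough sketch of the lookup (illustrative, not the exact ModelPool code):

def detect_model_configs(state_dict):
    # hash_state_dict_keys is defined later in this commit
    model_hash = hash_state_dict_keys(state_dict)
    return [cfg for cfg in MODEL_CONFIGS if cfg["model_hash"] == model_hash]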
@@ -111,4 +111,8 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
        "diffsynth.models.longcat_video_dit.RMSNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
        "diffsynth.models.longcat_video_dit.LayerNorm_FP32": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
    "diffsynth.models.flux_dit.FluxDiT": {
        "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
        "diffsynth.models.flux_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    },
}
@@ -2,6 +2,7 @@ import torch, glob, os
from typing import Optional, Union
from dataclasses import dataclass
from modelscope import snapshot_download
from huggingface_hub import snapshot_download as hf_snapshot_download
from typing import Optional
@@ -10,7 +11,7 @@ class ModelConfig:
    path: Union[str, list[str]] = None
    model_id: str = None
    origin_file_pattern: Union[str, list[str]] = None
    download_resource: str = "ModelScope"
    download_resource: str = None
    local_model_path: str = None
    skip_download: bool = None
    offload_device: Optional[Union[str, torch.device]] = None
@@ -30,13 +31,29 @@ class ModelConfig:
    def download(self):
        origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
        downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
        snapshot_download(
            self.model_id,
            local_dir=os.path.join(self.local_model_path, self.model_id),
            allow_file_pattern=self.origin_file_pattern,
            ignore_file_pattern=downloaded_files,
            local_files_only=False
        )
        if self.download_resource is None:
            if os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE') is not None:
                self.download_resource = os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE')
            else:
                self.download_resource = "modelscope"
        if self.download_resource.lower() == "modelscope":
            snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_file_pattern=self.origin_file_pattern,
                ignore_file_pattern=downloaded_files,
                local_files_only=False
            )
        elif self.download_resource.lower() == "huggingface":
            hf_snapshot_download(
                self.model_id,
                local_dir=os.path.join(self.local_model_path, self.model_id),
                allow_patterns=self.origin_file_pattern,
                ignore_patterns=downloaded_files,
                local_files_only=False
            )
        else:
            raise ValueError("`download_resource` should be `modelscope` or `huggingface`.")

    def require_downloading(self):
        if self.path is not None:
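In short, the backend now resolves in this order: an explicit download_resource value, then the DIFFSYNTH_DOWNLOAD_RESOURCE environment variable, then the "modelscope" default. A hedged usage sketch (the local_model_path value and direct ModelConfig call are illustrative assumptions, not taken from this diff):

import os

os.environ["DIFFSYNTH_DOWNLOAD_RESOURCE"] = "huggingface"

config = ModelConfig(
    model_id="stepfun-ai/Step1X-Edit",
    origin_file_pattern="step1x-edit-i1258.safetensors",
    local_model_path="./models",  # assumed; any writable directory
)
config.download()  # download_resource is None, so the env var routes this to hf_snapshot_download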
@@ -77,7 +77,7 @@ class ModelPool:
            print(f"Loaded model: {json.dumps(model_info, indent=4)}")
            loaded = True
        if not loaded:
            raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}.")
            raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")

    def fetch_model(self, model_name, index=None):
        fetched_models = []
@@ -143,6 +143,7 @@ class QwenImageTextEncoder(torch.nn.Module):
        })
        self.model = Qwen2_5_VLModel(config)
        self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        self.config = config

    def forward(
        self,
diffsynth/models/step1x_connector.py (new file, 663 lines)
@@ -0,0 +1,663 @@
from typing import Optional

import torch, math
import torch.nn
from torch import nn
from functools import partial
from einops import rearrange


def attention(q, k, v, attn_mask, mode="torch"):
    # q, k, v arrive as (batch, seq_len, heads, head_dim); SDPA expects heads before seq_len.
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # Fold the heads back into the channel dimension: (b, h, s, d) -> (b, s, h*d).
    x = rearrange(x, "b n s d -> b s (n d)")
    return x
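As a shape sanity check, a minimal sketch of calling the helper above (sizes are illustrative):

q = k = v = torch.randn(2, 16, 8, 64)    # (batch, seq_len, heads, head_dim)
out = attention(q, k, v, attn_mask=None)
print(out.shape)                          # torch.Size([2, 16, 512]): heads folded into channels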
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = (bias, bias)
        drop_probs = (drop, drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, device=device, dtype=dtype)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
            ),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore
        nn.init.normal_(self.mlp[2].weight, std=0.02)  # type: ignore

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
            dim (int): the dimension of the output.
            max_period (int): controls the minimum frequency of the embeddings.

        Returns:
            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(
            t, self.frequency_embedding_size, self.max_period
        ).type(t.dtype)  # type: ignore
        t_emb = self.mlp(t_freq)
        return t_emb
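A quick numeric sketch of the sinusoidal embedding above (values illustrative):

t = torch.tensor([0.0, 500.0, 999.0])
emb = TimestepEmbedder.timestep_embedding(t, dim=256)
print(emb.shape)  # torch.Size([3, 256]): 128 cosine terms followed by 128 sine terms
# For t = 0 the cosine half is all ones and the sine half all zeros.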
def apply_gate(x, gate=None, tanh=False):
    """Scale x by a per-sample gate, optionally squashing the gate through tanh.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor. Defaults to None.
        tanh (bool, optional): whether to apply tanh to the gate. Defaults to False.

    Returns:
        torch.Tensor: the gated output tensor.
    """
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)
class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output
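A minimal numeric check of the normalization above (illustrative):

norm = RMSNorm(dim=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
y = norm(x)
print(y.pow(2).mean(-1))  # ~1.0: the output has unit root-mean-square per row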
def get_norm_layer(norm_layer):
    """
    Get the normalization layer.

    Args:
        norm_layer (str): The type of normalization layer.

    Returns:
        norm_layer (nn.Module): The normalization layer.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")


def get_activation_layer(act_type):
    """Get the activation layer.

    Args:
        act_type (str): the activation type

    Returns:
        torch.nn.functional: the activation layer
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")
class IndividualTokenRefinerBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )

        if self.need_CA:
            self.cross_attnblock = CrossAttnBlock(
                hidden_size=hidden_size,
                heads_num=heads_num,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                act_type=act_type,
                qk_norm=qk_norm,
                qk_norm_type=qk_norm_type,
                qkv_bias=qkv_bias,
                **factory_kwargs,
            )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        if self.need_CA:
            x = self.cross_attnblock(x, c, attn_mask, y)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x
class CrossAttnBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.norm1_2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_q = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )
        self.self_attn_kv = nn.Linear(
            hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
        y: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        norm_y = self.norm1_2(y)
        q = self.self_attn_q(norm_x)
        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
        kv = self.self_attn_kv(norm_y)
        k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Cross-Attention: queries from x, keys/values from y
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        return x
class IndividualTokenRefiner(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    need_CA=self.need_CA,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
        y: torch.Tensor = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask, y)

        return x
class SingleTokenRefiner(torch.nn.Module):
    """
    A single token refiner block for llm text embedding refine.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.need_CA = need_CA
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(
            in_channels, hidden_size, bias=True, **factory_kwargs
        )
        if self.need_CA:
            self.input_embedder_CA = nn.Linear(
                in_channels, hidden_size, bias=True, **factory_kwargs
            )

        act_layer = get_activation_layer(act_type)
        # Build timestep embedding layer
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        # Build context embedding layer
        self.c_embedder = TextProjection(
            in_channels, hidden_size, act_layer, **factory_kwargs
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            need_CA=need_CA,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
        y: torch.LongTensor = None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(
                dim=1
            ) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)
        if self.need_CA:
            y = self.input_embedder_CA(y)
            x = self.individual_token_refiner(x, c, mask, y)
        else:
            x = self.individual_token_refiner(x, c, mask)

        return x
class Qwen2Connector(torch.nn.Module):
    def __init__(
        self,
        # biclip_dim=1024,
        in_channels=3584,
        hidden_size=4096,
        heads_num=32,
        depth=2,
        need_CA=False,
        device=None,
        dtype=torch.bfloat16,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.S = SingleTokenRefiner(
            in_channels=in_channels,
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            need_CA=need_CA,
            **factory_kwargs,
        )
        self.global_proj_out = nn.Linear(in_channels, 768)

        self.scale_factor = nn.Parameter(torch.zeros(1))
        with torch.no_grad():
            self.scale_factor.data += -(1 - 0.09)

    def forward(self, x, t, mask):
        mask_float = mask.unsqueeze(-1)  # [b, s1, 1]
        x_mean = (x * mask_float).sum(
            dim=1
        ) / mask_float.sum(dim=1) * (1 + self.scale_factor.to(dtype=x.dtype, device=x.device))

        global_out = self.global_proj_out(x_mean)
        encoder_hidden_states = self.S(x, t, mask)
        return encoder_hidden_states, global_out
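A forward-pass sketch with dummy tensors, assuming the default sizes above (batch size, sequence length, and the float32/CPU setup are illustrative assumptions):

connector = Qwen2Connector(device="cpu", dtype=torch.float32)
x = torch.randn(1, 64, 3584)                  # LLM hidden states
t = torch.zeros(1)                            # timestep, one per batch element
mask = torch.ones(1, 64, dtype=torch.long)    # attention mask
tokens, global_vec = connector(x, t, mask)
print(tokens.shape, global_vec.shape)         # (1, 64, 4096) and (1, 768)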
diffsynth/models/step1x_text_encoder.py (new file, 194 lines)
@@ -0,0 +1,194 @@
import torch
from typing import Optional, Union
from .qwen_image_text_encoder import QwenImageTextEncoder


class Step1xEditEmbedder(torch.nn.Module):
    def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
        super().__init__()
        self.max_length = max_length
        self.dtype = dtype
        self.device = device

        Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''

        self.prefix = Qwen25VL_7b_PREFIX
        self.model = model
        self.processor = processor
    def model_forward(
        self,
        model: QwenImageTextEncoder,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs,
    ):
        output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
        )

        outputs = model.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            pixel_values_videos=pixel_values_videos,
            image_grid_thw=image_grid_thw,
            video_grid_thw=video_grid_thw,
            second_per_grid_ts=second_per_grid_ts,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            cache_position=cache_position,
            **kwargs,
        )
        return outputs.hidden_states
    def forward(self, caption, ref_images):
        text_list = caption
        embs = torch.zeros(
            len(text_list),
            self.max_length,
            self.model.config.hidden_size,
            dtype=torch.bfloat16,
            device=torch.cuda.current_device(),
        )
        masks = torch.zeros(
            len(text_list),
            self.max_length,
            dtype=torch.long,
            device=torch.cuda.current_device(),
        )

        def split_string(s):
            s = s.replace("“", '"').replace("”", '"').replace("'", '"')  # use english quotes
            result = []
            in_quotes = False
            temp = ""

            for idx, char in enumerate(s):
                if char == '"' and idx > 155:
                    temp += char
                    if not in_quotes:
                        result.append(temp)
                        temp = ""

                    in_quotes = not in_quotes
                    continue
                if in_quotes:
                    if char.isspace():
                        pass  # keep the space as its own wrapped token
                    result.append("“" + char + "”")
                else:
                    temp += char

            if temp:
                result.append(temp)

            return result
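For intuition: quotes are only treated specially past index 155, i.e. beyond the fixed instruction prefix, and every character inside a quoted span is re-wrapped in CJK quotes so it is later tokenized one character at a time. A toy trace (first element abridged):

s = "x" * 156 + '"AB" sign'
print(split_string(s))
# ['x...x"', '“A”', '“B”', '" sign']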
        for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):

            messages = [{"role": "user", "content": []}]

            messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})

            messages[0]["content"].append({"type": "image", "image": imgs})

            # then append the text
            messages[0]["content"].append({"type": "text", "text": f"{txt}"})

            # Preparation for inference
            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
            )

            image_inputs = [imgs]

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                padding=True,
                return_tensors="pt",
            )

            old_inputs_ids = inputs.input_ids
            text_split_list = split_string(text)

            token_list = []
            for text_each in text_split_list:
                txt_inputs = self.processor(
                    text=text_each,
                    images=None,
                    videos=None,
                    padding=True,
                    return_tensors="pt",
                )
                token_each = txt_inputs.input_ids
                if token_each[0][0] == 2073 and token_each[0][-1] == 854:
                    # strip leading/trailing wrapper tokens (ids 2073 / 854)
                    token_each = token_each[:, 1:-1]
                    token_list.append(token_each)
                else:
                    token_list.append(token_each)

            new_txt_ids = torch.cat(token_list, dim=1).to("cuda")

            new_txt_ids = new_txt_ids.to(old_inputs_ids.device)

            # splice the re-tokenized text in after the vision tokens (151653 is <|vision_end|>)
            idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
            idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
            inputs.input_ids = (
                torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
                .unsqueeze(0)
                .to("cuda")
            )
            inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
            outputs = self.model_forward(
                self.model,
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                pixel_values=inputs.pixel_values.to("cuda"),
                image_grid_thw=inputs.image_grid_thw.to("cuda"),
                output_hidden_states=True,
            )

            emb = outputs[-1]

            # skip the first 217 tokens (the fixed instruction prefix), crop to max_length
            embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
                : self.max_length
            ]

            masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
                (min(self.max_length, emb.shape[1] - 217)),
                dtype=torch.long,
                device=torch.cuda.current_device(),
            )

        return embs, masks
@@ -16,6 +16,7 @@ from ..models.flux_text_encoder_clip import FluxTextEncoderClip
from ..models.flux_text_encoder_t5 import FluxTextEncoderT5
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..models.flux_value_control import MultiValueEncoder
from ..models.step1x_text_encoder import Step1xEditEmbedder
from ..core.vram.layers import AutoWrappedLinear

class MultiControlNet(torch.nn.Module):
@@ -103,7 +104,7 @@ class FluxImagePipeline(BasePipeline):
        self.model_fn = model_fn_flux_image
        self.lora_loader = FluxLoRALoader

    def enable_lora_magic(self):
    def enable_lora_merger(self):
        if self.lora_patcher is not None:
            for name, module in self.dit.named_modules():
                if isinstance(module, AutoWrappedLinear):
@@ -118,7 +119,8 @@ class FluxImagePipeline(BasePipeline):
        model_configs: list[ModelConfig] = [],
        tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
        tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
        nexus_gen_processor_config: ModelConfig = None,
        nexus_gen_processor_config: ModelConfig = ModelConfig(model_id="DiffSynth-Studio/Nexus-GenV2", origin_file_pattern="processor/"),
        step1x_processor_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen2.5-VL-7B-Instruct", origin_file_pattern=""),
        vram_limit: float = None,
    ):
        # Initialize pipeline
@@ -144,7 +146,12 @@ class FluxImagePipeline(BasePipeline):
        if controlnets is not None: pipe.controlnet = MultiControlNet(controlnets)
        pipe.ipadapter = model_pool.fetch_model("flux_ipadapter")
        pipe.ipadapter_image_encoder = model_pool.fetch_model("siglip_vision_model")
        pipe.qwenvl = model_pool.fetch_model("qwenvl")
        qwenvl = model_pool.fetch_model("qwen_image_text_encoder")
        if qwenvl is not None:
            from transformers import AutoProcessor
            step1x_processor_config.download_if_necessary()
            processor = AutoProcessor.from_pretrained(step1x_processor_config.path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28)
            pipe.qwenvl = Step1xEditEmbedder(qwenvl, processor)
        pipe.step1x_connector = model_pool.fetch_model("step1x_connector")
        pipe.image_proj_model = model_pool.fetch_model("infiniteyou_image_projector")
        if pipe.image_proj_model is not None:
@@ -154,7 +161,7 @@ class FluxImagePipeline(BasePipeline):
        pipe.nexus_gen = model_pool.fetch_model("nexus_gen_llm")
        pipe.nexus_gen_generation_adapter = model_pool.fetch_model("nexus_gen_generation_adapter")
        pipe.nexus_gen_editing_adapter = model_pool.fetch_model("nexus_gen_editing_adapter")
        if nexus_gen_processor_config is not None and pipe.nexus_gen is not None:
        if pipe.nexus_gen is not None:
            nexus_gen_processor_config.download_if_necessary()
            pipe.nexus_gen.load_processor(nexus_gen_processor_config.path)
@@ -1,6 +1,5 @@
import torch
import hashlib
import json


def FluxControlNetStateDictConverter(state_dict):
    global_rename_dict = {
@@ -1,39 +1,17 @@
import torch
import hashlib

def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    all_keys = sorted(list(state_dict))

    for key in all_keys:
        value = state_dict[key]
        if isinstance(key, str):
            if isinstance(value, torch.Tensor):
                if with_shape:
                    shape = "_".join(map(str, list(value.shape)))
                    keys.append(key + ":" + shape)
                keys.append(key)
            elif isinstance(value, dict):
                keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
    keys.sort()
    keys_str = ",".join(keys)
    return keys_str

def hash_state_dict_keys(state_dict, with_shape=True):
    keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
    keys_str = keys_str.encode(encoding="UTF-8")
    return hashlib.md5(keys_str).hexdigest()

def FluxDiTStateDictConverter(state_dict):
    model_hash = hash_state_dict_keys(state_dict, with_shape=True)

    if model_hash in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
        is_nexus_gen = sum([key.startswith("pipe.dit.") for key in state_dict]) > 0
        if is_nexus_gen:
            dit_state_dict = {}
            for key in state_dict:
                if key.startswith('pipe.dit.'):
                    value = state_dict[key]
                    param = state_dict[key]
                    new_key = key.replace("pipe.dit.", "")
                    dit_state_dict[new_key] = value
                    if new_key.startswith("final_norm_out.linear."):
                        param = torch.concat([param[3072:], param[:3072]], dim=0)
                    dit_state_dict[new_key] = param
            return dit_state_dict

    rename_dict = {
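For reference, a toy run of the key hashing above, whose MD5 digest is exactly the "model_hash" used in MODEL_CONFIGS (keys and shapes illustrative):

sd = {"to_q.weight": torch.zeros(32, 32), "to_q.bias": torch.zeros(32)}
print(convert_state_dict_keys_to_single_str(sd))
# to_q.bias,to_q.bias:32,to_q.weight,to_q.weight:32_32
print(hash_state_dict_keys(sd))  # md5 hex digest of the string above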
@@ -1,4 +1,2 @@
import torch

def FluxInfiniteYouImageProjectorStateDictConverter(state_dict):
    return state_dict['image_proj']
@@ -1,5 +1,3 @@
import torch

def FluxIpAdapterStateDictConverter(state_dict):
    state_dict_ = {}
@@ -1,5 +1,3 @@
import torch

def NexusGenAutoregressiveModelStateDictConverter(state_dict):
    new_state_dict = {}
    for key in state_dict:
@@ -1,5 +1,3 @@
import torch

def NexusGenMergerStateDictConverter(state_dict):
    merger_state_dict = {}
    for key in state_dict:
@@ -0,0 +1,7 @@
def Qwen2ConnectorStateDictConverter(state_dict):
    state_dict_ = {}
    for name in state_dict:
        if name.startswith("connector."):
            name_ = name[len("connector."):]
            state_dict_[name_] = state_dict[name]
    return state_dict_
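A quick illustration of the prefix stripping above (toy keys):

raw = {"connector.S.input_embedder.weight": 0, "llm.lm_head.weight": 0}
print(Qwen2ConnectorStateDictConverter(raw))
# {'S.input_embedder.weight': 0} -- keys outside the "connector." namespace are dropped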