support step1x

This commit is contained in:
Artiprocher
2025-04-27 20:37:43 +08:00
parent 4f01b37a2a
commit 109a0a0d49
5 changed files with 947 additions and 10 deletions

168
diffsynth/models/qwenvl.py Normal file
View File

@@ -0,0 +1,168 @@
import torch
class Qwen25VL_7b_Embedder(torch.nn.Module):
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
super(Qwen25VL_7b_Embedder, self).__init__()
self.max_length = max_length
self.dtype = dtype
self.device = device
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path,
torch_dtype=dtype,
).to(torch.cuda.current_device())
self.model.requires_grad_(False)
self.processor = AutoProcessor.from_pretrained(
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
)
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
Here are examples of how to transform or refine prompts:
- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.
- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.\n
Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:
User Prompt:'''
self.prefix = Qwen25VL_7b_PREFIX
@staticmethod
def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)
def forward(self, caption, ref_images):
text_list = caption
embs = torch.zeros(
len(text_list),
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
)
hidden_states = torch.zeros(
len(text_list),
self.max_length,
self.model.config.hidden_size,
dtype=torch.bfloat16,
device=torch.cuda.current_device(),
)
masks = torch.zeros(
len(text_list),
self.max_length,
dtype=torch.long,
device=torch.cuda.current_device(),
)
input_ids_list = []
attention_mask_list = []
emb_list = []
def split_string(s):
s = s.replace("", '"').replace("", '"').replace("'", '''"''') # use english quotes
result = []
in_quotes = False
temp = ""
for idx,char in enumerate(s):
if char == '"' and idx>155:
temp += char
if not in_quotes:
result.append(temp)
temp = ""
in_quotes = not in_quotes
continue
if in_quotes:
if char.isspace():
pass # have space token
result.append("" + char + "")
else:
temp += char
if temp:
result.append(temp)
return result
for idx, (txt, imgs) in enumerate(zip(text_list, ref_images)):
messages = [{"role": "user", "content": []}]
messages[0]["content"].append({"type": "text", "text": f"{self.prefix}"})
messages[0]["content"].append({"type": "image", "image": imgs})
# 再添加 text
messages[0]["content"].append({"type": "text", "text": f"{txt}"})
# Preparation for inference
text = self.processor.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True, add_vision_id=True
)
image_inputs = [imgs]
inputs = self.processor(
text=[text],
images=image_inputs,
padding=True,
return_tensors="pt",
)
old_inputs_ids = inputs.input_ids
text_split_list = split_string(text)
token_list = []
for text_each in text_split_list:
txt_inputs = self.processor(
text=text_each,
images=None,
videos=None,
padding=True,
return_tensors="pt",
)
token_each = txt_inputs.input_ids
if token_each[0][0] == 2073 and token_each[0][-1] == 854:
token_each = token_each[:, 1:-1]
token_list.append(token_each)
else:
token_list.append(token_each)
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
idx1 = (old_inputs_ids == 151653).nonzero(as_tuple=True)[1][0]
idx2 = (new_txt_ids == 151653).nonzero(as_tuple=True)[1][0]
inputs.input_ids = (
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
.unsqueeze(0)
.to("cuda")
)
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
outputs = self.model(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
pixel_values=inputs.pixel_values.to("cuda"),
image_grid_thw=inputs.image_grid_thw.to("cuda"),
output_hidden_states=True,
)
emb = outputs["hidden_states"][-1]
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
: self.max_length
]
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
(min(self.max_length, emb.shape[1] - 217)),
dtype=torch.long,
device=torch.cuda.current_device(),
)
return embs, masks

View File

@@ -0,0 +1,683 @@
from typing import Optional
import torch, math
import torch.nn
from einops import rearrange
from torch import nn
from functools import partial
from einops import rearrange
def attention(q, k, v, attn_mask, mode="torch"):
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
x = rearrange(x, "b n s d -> b s (n d)")
return x
class MLP(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
device=None,
dtype=None,
):
super().__init__()
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = (bias, bias)
drop_probs = (drop, drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(
in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (
norm_layer(hidden_channels, device=device, dtype=dtype)
if norm_layer is not None
else nn.Identity()
)
self.fc2 = linear_layer(
hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class TextProjection(nn.Module):
"""
Projects text embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.linear_1 = nn.Linear(
in_features=in_channels,
out_features=hidden_size,
bias=True,
**factory_kwargs,
)
self.act_1 = act_layer()
self.linear_2 = nn.Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=True,
**factory_kwargs,
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer,
frequency_embedding_size=256,
max_period=10000,
out_size=None,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(
frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
nn.init.normal_(self.mlp[0].weight, std=0.02) # type: ignore
nn.init.normal_(self.mlp[2].weight, std=0.02) # type: ignore
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(
t, self.frequency_embedding_size, self.max_period
).type(self.mlp[0].weight.dtype) # type: ignore
t_emb = self.mlp(t_freq)
return t_emb
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
torch.Tensor: the output tensor after apply gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
torch.nn.functional: the activation layer
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
class IndividualTokenRefinerBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.need_CA = need_CA
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_qkv = nn.Linear(
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.mlp = MLP(
in_channels=hidden_size,
hidden_channels=mlp_hidden_dim,
act_layer=act_layer,
drop=mlp_drop_rate,
**factory_kwargs,
)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
if self.need_CA:
self.cross_attnblock=CrossAttnBlock(hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor = None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
qkv = self.self_attn_qkv(norm_x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
if self.need_CA:
x = self.cross_attnblock(x, c, attn_mask, y)
# FFN Layer
x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
return x
class CrossAttnBlock(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
mlp_width_ratio: str = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.heads_num = heads_num
head_dim = hidden_size // heads_num
self.norm1 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.norm1_2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
self.self_attn_q = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.self_attn_kv = nn.Linear(
hidden_size, hidden_size*2, bias=qkv_bias, **factory_kwargs
)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.self_attn_q_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_k_norm = (
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm
else nn.Identity()
)
self.self_attn_proj = nn.Linear(
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
)
self.norm2 = nn.LayerNorm(
hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(
self,
x: torch.Tensor,
c: torch.Tensor, # timestep_aware_representations + context_aware_representations
attn_mask: torch.Tensor = None,
y: torch.Tensor=None,
):
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
norm_x = self.norm1(x)
norm_y = self.norm1_2(y)
q = self.self_attn_q(norm_x)
q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
kv = self.self_attn_kv(norm_y)
k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
# Apply QK-Norm if needed
q = self.self_attn_q_norm(q).to(v)
k = self.self_attn_k_norm(k).to(v)
# Self-Attention
attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)
x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
return x
class IndividualTokenRefiner(torch.nn.Module):
def __init__(
self,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA:bool=False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.need_CA = need_CA
self.blocks = nn.ModuleList(
[
IndividualTokenRefinerBlock(
hidden_size=hidden_size,
heads_num=heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
need_CA=self.need_CA,
**factory_kwargs,
)
for _ in range(depth)
]
)
def forward(
self,
x: torch.Tensor,
c: torch.LongTensor,
mask: Optional[torch.Tensor] = None,
y:torch.Tensor=None,
):
self_attn_mask = None
if mask is not None:
batch_size = mask.shape[0]
seq_len = mask.shape[1]
mask = mask.to(x.device)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
1, 1, seq_len, 1
)
# batch_size x 1 x seq_len x seq_len
self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
# batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
# avoids self-attention weight being NaN for padding tokens
self_attn_mask[:, :, :, 0] = True
for block in self.blocks:
x = block(x, c, self_attn_mask,y)
return x
class SingleTokenRefiner(torch.nn.Module):
"""
A single token refiner block for llm text embedding refine.
"""
def __init__(
self,
in_channels,
hidden_size,
heads_num,
depth,
mlp_width_ratio: float = 4.0,
mlp_drop_rate: float = 0.0,
act_type: str = "silu",
qk_norm: bool = False,
qk_norm_type: str = "layer",
qkv_bias: bool = True,
need_CA:bool=False,
attn_mode: str = "torch",
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.attn_mode = attn_mode
self.need_CA = need_CA
assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."
self.input_embedder = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
if self.need_CA:
self.input_embedder_CA = nn.Linear(
in_channels, hidden_size, bias=True, **factory_kwargs
)
act_layer = get_activation_layer(act_type)
# Build timestep embedding layer
self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
# Build context embedding layer
self.c_embedder = TextProjection(
in_channels, hidden_size, act_layer, **factory_kwargs
)
self.individual_token_refiner = IndividualTokenRefiner(
hidden_size=hidden_size,
heads_num=heads_num,
depth=depth,
mlp_width_ratio=mlp_width_ratio,
mlp_drop_rate=mlp_drop_rate,
act_type=act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
need_CA=need_CA,
**factory_kwargs,
)
def forward(
self,
x: torch.Tensor,
t: torch.LongTensor,
mask: Optional[torch.LongTensor] = None,
y: torch.LongTensor=None,
):
timestep_aware_representations = self.t_embedder(t)
if mask is None:
context_aware_representations = x.mean(dim=1)
else:
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
context_aware_representations = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1)
context_aware_representations = self.c_embedder(context_aware_representations)
c = timestep_aware_representations + context_aware_representations
x = self.input_embedder(x)
if self.need_CA:
y = self.input_embedder_CA(y)
x = self.individual_token_refiner(x, c, mask, y)
else:
x = self.individual_token_refiner(x, c, mask)
return x
class Qwen2Connector(torch.nn.Module):
def __init__(
self,
# biclip_dim=1024,
in_channels=3584,
hidden_size=4096,
heads_num=32,
depth=2,
need_CA=False,
device=None,
dtype=torch.bfloat16,
):
super().__init__()
factory_kwargs = {"device": device, "dtype":dtype}
self.S =SingleTokenRefiner(in_channels=in_channels,hidden_size=hidden_size,heads_num=heads_num,depth=depth,need_CA=need_CA,**factory_kwargs)
self.global_proj_out=nn.Linear(in_channels,768)
self.scale_factor = nn.Parameter(torch.zeros(1))
with torch.no_grad():
self.scale_factor.data += -(1 - 0.09)
def forward(self, x,t,mask):
mask_float = mask.unsqueeze(-1) # [b, s1, 1]
x_mean = (x * mask_float).sum(
dim=1
) / mask_float.sum(dim=1) * (1 + self.scale_factor)
global_out=self.global_proj_out(x_mean)
encoder_hidden_states = self.S(x,t,mask)
return encoder_hidden_states,global_out
@staticmethod
def state_dict_converter():
return Qwen2ConnectorStateDictConverter()
class Qwen2ConnectorStateDictConverter:
def __init__(self):
pass
def from_diffusers(self, state_dict):
return state_dict
def from_civitai(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
if name.startswith("connector."):
name_ = name[len("connector."):]
state_dict_[name_] = param
return state_dict_