mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
@@ -1 +0,0 @@
|
||||
from .model_manager import *
|
||||
@@ -1,89 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def low_version_attention(query, key, value, attn_bias=None):
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ value
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size = q.shape[0]
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
|
||||
if attn_mask is not None:
|
||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||
else:
|
||||
import xformers.ops as xops
|
||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
@@ -1,408 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .sd3_dit import TimestepEmbeddings
|
||||
from .attention import Attention
|
||||
from .utils import load_state_dict_from_folder
|
||||
from .tiler import TileWorker2Dto3D
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class CogPatchify(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, patch_size) -> None:
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(dim_in, dim_out, kernel_size=(1, patch_size, patch_size), stride=(1, patch_size, patch_size))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C T H W -> B (T H W) C")
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class CogAdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, dim_cond, single=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.linear = torch.nn.Linear(dim_cond, dim * (2 if single else 6))
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=True, eps=1e-5)
|
||||
|
||||
|
||||
def forward(self, hidden_states, prompt_emb, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
shift, scale = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale) + shift
|
||||
return hidden_states
|
||||
else:
|
||||
shift_a, scale_a, gate_a, shift_b, scale_b, gate_b = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale_a) + shift_a
|
||||
prompt_emb = self.norm(prompt_emb) * (1 + scale_b) + shift_b
|
||||
return hidden_states, prompt_emb, gate_a, gate_b
|
||||
|
||||
|
||||
|
||||
class CogDiTBlock(torch.nn.Module):
|
||||
def __init__(self, dim, dim_cond, num_heads):
|
||||
super().__init__()
|
||||
self.norm1 = CogAdaLayerNorm(dim, dim_cond)
|
||||
self.attn1 = Attention(q_dim=dim, num_heads=48, head_dim=dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm_q = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||
self.norm_k = torch.nn.LayerNorm((dim//num_heads,), eps=1e-06, elementwise_affine=True)
|
||||
|
||||
self.norm2 = CogAdaLayerNorm(dim, dim_cond)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def apply_rotary_emb(self, x, freqs_cis):
|
||||
cos, sin = freqs_cis # [S, D]
|
||||
cos = cos[None, None]
|
||||
sin = sin[None, None]
|
||||
cos, sin = cos.to(x.device), sin.to(x.device)
|
||||
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
def process_qkv(self, q, k, v, image_rotary_emb, text_seq_length):
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
q[:, :, text_seq_length:] = self.apply_rotary_emb(q[:, :, text_seq_length:], image_rotary_emb)
|
||||
k[:, :, text_seq_length:] = self.apply_rotary_emb(k[:, :, text_seq_length:], image_rotary_emb)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states, prompt_emb, time_emb, image_rotary_emb):
|
||||
# Attention
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm1(
|
||||
hidden_states, prompt_emb, time_emb
|
||||
)
|
||||
attention_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
attention_io = self.attn1(
|
||||
attention_io,
|
||||
qkv_preprocessor=lambda q, k, v: self.process_qkv(q, k, v, image_rotary_emb, prompt_emb.shape[1])
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + gate_a * attention_io[:, prompt_emb.shape[1]:]
|
||||
prompt_emb = prompt_emb + gate_b * attention_io[:, :prompt_emb.shape[1]]
|
||||
|
||||
# Feed forward
|
||||
norm_hidden_states, norm_encoder_hidden_states, gate_a, gate_b = self.norm2(
|
||||
hidden_states, prompt_emb, time_emb
|
||||
)
|
||||
ff_io = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
|
||||
ff_io = self.ff(ff_io)
|
||||
|
||||
hidden_states = hidden_states + gate_a * ff_io[:, prompt_emb.shape[1]:]
|
||||
prompt_emb = prompt_emb + gate_b * ff_io[:, :prompt_emb.shape[1]]
|
||||
|
||||
return hidden_states, prompt_emb
|
||||
|
||||
|
||||
|
||||
class CogDiT(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.patchify = CogPatchify(16, 3072, 2)
|
||||
self.time_embedder = TimestepEmbeddings(3072, 512)
|
||||
self.context_embedder = torch.nn.Linear(4096, 3072)
|
||||
self.blocks = torch.nn.ModuleList([CogDiTBlock(3072, 512, 48) for _ in range(42)])
|
||||
self.norm_final = torch.nn.LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
|
||||
self.norm_out = CogAdaLayerNorm(3072, 512, single=True)
|
||||
self.proj_out = torch.nn.Linear(3072, 64, bias=True)
|
||||
|
||||
|
||||
def get_resize_crop_region_for_grid(self, src, tgt_width, tgt_height):
|
||||
tw = tgt_width
|
||||
th = tgt_height
|
||||
h, w = src
|
||||
r = h / w
|
||||
if r > (th / tw):
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h))
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_3d_rotary_pos_embed(
|
||||
self, embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
|
||||
):
|
||||
start, stop = crops_coords
|
||||
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
|
||||
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
|
||||
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
|
||||
|
||||
# Compute dimensions for each axis
|
||||
dim_t = embed_dim // 4
|
||||
dim_h = embed_dim // 8 * 3
|
||||
dim_w = embed_dim // 8 * 3
|
||||
|
||||
# Temporal frequencies
|
||||
freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
|
||||
grid_t = torch.from_numpy(grid_t).float()
|
||||
freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
|
||||
freqs_t = freqs_t.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Spatial frequencies for height and width
|
||||
freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
|
||||
freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
|
||||
grid_h = torch.from_numpy(grid_h).float()
|
||||
grid_w = torch.from_numpy(grid_w).float()
|
||||
freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
|
||||
freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
|
||||
freqs_h = freqs_h.repeat_interleave(2, dim=-1)
|
||||
freqs_w = freqs_w.repeat_interleave(2, dim=-1)
|
||||
|
||||
# Broadcast and concatenate tensors along specified dimension
|
||||
def broadcast(tensors, dim=-1):
|
||||
num_tensors = len(tensors)
|
||||
shape_lens = {len(t.shape) for t in tensors}
|
||||
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
||||
shape_len = list(shape_lens)[0]
|
||||
dim = (dim + shape_len) if dim < 0 else dim
|
||||
dims = list(zip(*(list(t.shape) for t in tensors)))
|
||||
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
||||
assert all(
|
||||
[*(len(set(t[1])) <= 2 for t in expandable_dims)]
|
||||
), "invalid dimensions for broadcastable concatenation"
|
||||
max_dims = [(t[0], max(t[1])) for t in expandable_dims]
|
||||
expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
|
||||
expanded_dims.insert(dim, (dim, dims[dim]))
|
||||
expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
|
||||
tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
|
||||
return torch.cat(tensors, dim=dim)
|
||||
|
||||
freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
|
||||
|
||||
t, h, w, d = freqs.shape
|
||||
freqs = freqs.view(t * h * w, d)
|
||||
|
||||
# Generate sine and cosine components
|
||||
sin = freqs.sin()
|
||||
cos = freqs.cos()
|
||||
|
||||
if use_real:
|
||||
return cos, sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def prepare_rotary_positional_embeddings(
|
||||
self,
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
device: torch.device,
|
||||
):
|
||||
grid_height = height // 2
|
||||
grid_width = width // 2
|
||||
base_size_width = 720 // (8 * 2)
|
||||
base_size_height = 480 // (8 * 2)
|
||||
|
||||
grid_crops_coords = self.get_resize_crop_region_for_grid(
|
||||
(grid_height, grid_width), base_size_width, base_size_height
|
||||
)
|
||||
freqs_cos, freqs_sin = self.get_3d_rotary_pos_embed(
|
||||
embed_dim=64,
|
||||
crops_coords=grid_crops_coords,
|
||||
grid_size=(grid_height, grid_width),
|
||||
temporal_size=num_frames,
|
||||
use_real=True,
|
||||
)
|
||||
|
||||
freqs_cos = freqs_cos.to(device=device)
|
||||
freqs_sin = freqs_sin.to(device=device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
def unpatchify(self, hidden_states, height, width):
|
||||
hidden_states = rearrange(hidden_states, "B (T H W) (C P Q) -> B C T (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def build_mask(self, T, H, W, dtype, device, is_bound):
|
||||
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else t + 1,
|
||||
pad if is_bound[1] else T - t,
|
||||
pad if is_bound[2] else h + 1,
|
||||
pad if is_bound[3] else H - h,
|
||||
pad if is_bound[4] else w + 1,
|
||||
pad if is_bound[5] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=dtype, device=device)
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, tile_size=(60, 90), tile_stride=(30, 45)):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
value = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
weight = torch.zeros((B, C, T, H, W), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = max(H - tile_size, 0), H
|
||||
if w_ > W: w, w_ = max(W - tile_size, 0), W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
mask = self.build_mask(
|
||||
value.shape[2], (hr-hl), (wr-wl),
|
||||
hidden_states.dtype, hidden_states.device,
|
||||
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W)
|
||||
)
|
||||
model_output = self.forward(hidden_states[:, :, :, hl:hr, wl:wr], timestep, prompt_emb)
|
||||
value[:, :, :, hl:hr, wl:wr] += model_output * mask
|
||||
weight[:, :, :, hl:hr, wl:wr] += mask
|
||||
value = value / weight
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, image_rotary_emb=None, tiled=False, tile_size=90, tile_stride=30, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.forward(x, timestep, prompt_emb),
|
||||
model_input=hidden_states,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
tile_device=hidden_states.device, tile_dtype=hidden_states.dtype,
|
||||
computation_device=self.context_embedder.weight.device, computation_dtype=self.context_embedder.weight.dtype
|
||||
)
|
||||
num_frames, height, width = hidden_states.shape[-3:]
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.prepare_rotary_positional_embeddings(height, width, num_frames, device=self.context_embedder.weight.device)
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
time_emb = self.time_embedder(timestep, dtype=hidden_states.dtype)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, time_emb, image_rotary_emb,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, time_emb, image_rotary_emb)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
hidden_states = self.norm_final(hidden_states)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
hidden_states = self.norm_out(hidden_states, prompt_emb, time_emb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogDiTStateDictConverter()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(file_path, torch_dtype=torch.bfloat16):
|
||||
model = CogDiT().to(torch_dtype)
|
||||
state_dict = load_state_dict_from_folder(file_path, torch_dtype=torch_dtype)
|
||||
state_dict = CogDiT.state_dict_converter().from_diffusers(state_dict)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class CogDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"patch_embed.proj.weight": "patchify.proj.weight",
|
||||
"patch_embed.proj.bias": "patchify.proj.bias",
|
||||
"patch_embed.text_proj.weight": "context_embedder.weight",
|
||||
"patch_embed.text_proj.bias": "context_embedder.bias",
|
||||
"time_embedding.linear_1.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_embedding.linear_1.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_embedding.linear_2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"time_embedding.linear_2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
|
||||
"norm_final.weight": "norm_final.weight",
|
||||
"norm_final.bias": "norm_final.bias",
|
||||
"norm_out.linear.weight": "norm_out.linear.weight",
|
||||
"norm_out.linear.bias": "norm_out.linear.bias",
|
||||
"norm_out.norm.weight": "norm_out.norm.weight",
|
||||
"norm_out.norm.bias": "norm_out.norm.bias",
|
||||
"proj_out.weight": "proj_out.weight",
|
||||
"proj_out.bias": "proj_out.bias",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.linear.weight": "norm1.linear.weight",
|
||||
"norm1.linear.bias": "norm1.linear.bias",
|
||||
"norm1.norm.weight": "norm1.norm.weight",
|
||||
"norm1.norm.bias": "norm1.norm.bias",
|
||||
"attn1.norm_q.weight": "norm_q.weight",
|
||||
"attn1.norm_q.bias": "norm_q.bias",
|
||||
"attn1.norm_k.weight": "norm_k.weight",
|
||||
"attn1.norm_k.bias": "norm_k.bias",
|
||||
"attn1.to_q.weight": "attn1.to_q.weight",
|
||||
"attn1.to_q.bias": "attn1.to_q.bias",
|
||||
"attn1.to_k.weight": "attn1.to_k.weight",
|
||||
"attn1.to_k.bias": "attn1.to_k.bias",
|
||||
"attn1.to_v.weight": "attn1.to_v.weight",
|
||||
"attn1.to_v.bias": "attn1.to_v.bias",
|
||||
"attn1.to_out.0.weight": "attn1.to_out.weight",
|
||||
"attn1.to_out.0.bias": "attn1.to_out.bias",
|
||||
"norm2.linear.weight": "norm2.linear.weight",
|
||||
"norm2.linear.bias": "norm2.linear.bias",
|
||||
"norm2.norm.weight": "norm2.norm.weight",
|
||||
"norm2.norm.bias": "norm2.norm.bias",
|
||||
"ff.net.0.proj.weight": "ff.0.weight",
|
||||
"ff.net.0.proj.bias": "ff.0.bias",
|
||||
"ff.net.2.weight": "ff.2.weight",
|
||||
"ff.net.2.bias": "ff.2.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "patch_embed.proj.weight":
|
||||
param = param.unsqueeze(2)
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
names = name.split(".")
|
||||
if names[0] == "transformer_blocks":
|
||||
suffix = ".".join(names[2:])
|
||||
state_dict_[f"blocks.{names[1]}." + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,518 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .tiler import TileWorker2Dto3D
|
||||
|
||||
|
||||
|
||||
class Downsample3D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 2,
|
||||
padding: int = 0,
|
||||
compress_time: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, x: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
|
||||
x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
|
||||
|
||||
if x.shape[-1] % 2 == 1:
|
||||
x_first, x_rest = x[..., 0], x[..., 1:]
|
||||
if x_rest.shape[-1] > 0:
|
||||
# (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
|
||||
x_rest = torch.nn.functional.avg_pool1d(x_rest, kernel_size=2, stride=2)
|
||||
|
||||
x = torch.cat([x_first[..., None], x_rest], dim=-1)
|
||||
# (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
else:
|
||||
# (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
|
||||
x = torch.nn.functional.avg_pool1d(x, kernel_size=2, stride=2)
|
||||
# (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
|
||||
x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
|
||||
|
||||
# Pad the tensor
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
batch_size, channels, frames, height, width = x.shape
|
||||
# (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
|
||||
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
|
||||
x = self.conv(x)
|
||||
# (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
|
||||
x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class Upsample3D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
padding: int = 1,
|
||||
compress_time: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.compress_time = compress_time
|
||||
|
||||
def forward(self, inputs: torch.Tensor, xq: torch.Tensor) -> torch.Tensor:
|
||||
if self.compress_time:
|
||||
if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
|
||||
# split first frame
|
||||
x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
|
||||
|
||||
x_first = torch.nn.functional.interpolate(x_first, scale_factor=2.0)
|
||||
x_rest = torch.nn.functional.interpolate(x_rest, scale_factor=2.0)
|
||||
x_first = x_first[:, :, None, :, :]
|
||||
inputs = torch.cat([x_first, x_rest], dim=2)
|
||||
elif inputs.shape[2] > 1:
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
else:
|
||||
inputs = inputs.squeeze(2)
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs[:, :, None, :, :]
|
||||
else:
|
||||
# only interpolate 2D
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = torch.nn.functional.interpolate(inputs, scale_factor=2.0)
|
||||
inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
b, c, t, h, w = inputs.shape
|
||||
inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
||||
inputs = self.conv(inputs)
|
||||
inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
|
||||
class CogVideoXSpatialNorm3D(torch.nn.Module):
|
||||
def __init__(self, f_channels, zq_channels, groups):
|
||||
super().__init__()
|
||||
self.norm_layer = torch.nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
|
||||
self.conv_y = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
self.conv_b = torch.nn.Conv3d(zq_channels, f_channels, kernel_size=1, stride=1)
|
||||
|
||||
|
||||
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
|
||||
if f.shape[2] > 1 and f.shape[2] % 2 == 1:
|
||||
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
|
||||
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
|
||||
z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
|
||||
z_first = torch.nn.functional.interpolate(z_first, size=f_first_size)
|
||||
z_rest = torch.nn.functional.interpolate(z_rest, size=f_rest_size)
|
||||
zq = torch.cat([z_first, z_rest], dim=2)
|
||||
else:
|
||||
zq = torch.nn.functional.interpolate(zq, size=f.shape[-3:])
|
||||
|
||||
norm_f = self.norm_layer(f)
|
||||
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
|
||||
return new_f
|
||||
|
||||
|
||||
|
||||
class Resnet3DBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, spatial_norm_dim, groups, eps=1e-6, use_conv_shortcut=False):
|
||||
super().__init__()
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
if spatial_norm_dim is None:
|
||||
self.norm1 = torch.nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
|
||||
self.norm2 = torch.nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
|
||||
else:
|
||||
self.norm1 = CogVideoXSpatialNorm3D(in_channels, spatial_norm_dim, groups)
|
||||
self.norm2 = CogVideoXSpatialNorm3D(out_channels, spatial_norm_dim, groups)
|
||||
|
||||
self.conv1 = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
|
||||
self.conv2 = CachedConv3d(out_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
|
||||
if in_channels != out_channels:
|
||||
if use_conv_shortcut:
|
||||
self.conv_shortcut = CachedConv3d(in_channels, out_channels, kernel_size=3, padding=(0, 1, 1))
|
||||
else:
|
||||
self.conv_shortcut = torch.nn.Conv3d(in_channels, out_channels, kernel_size=1)
|
||||
else:
|
||||
self.conv_shortcut = lambda x: x
|
||||
|
||||
|
||||
def forward(self, hidden_states, zq):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states, zq) if isinstance(self.norm1, CogVideoXSpatialNorm3D) else self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
hidden_states = self.norm2(hidden_states, zq) if isinstance(self.norm2, CogVideoXSpatialNorm3D) else self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.conv_shortcut(residual)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class CachedConv3d(torch.nn.Conv3d):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
super().__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
|
||||
self.cached_tensor = None
|
||||
|
||||
|
||||
def clear_cache(self):
|
||||
self.cached_tensor = None
|
||||
|
||||
|
||||
def forward(self, input: torch.Tensor, use_cache = True) -> torch.Tensor:
|
||||
if use_cache:
|
||||
if self.cached_tensor is None:
|
||||
self.cached_tensor = torch.concat([input[:, :, :1]] * 2, dim=2)
|
||||
input = torch.concat([self.cached_tensor, input], dim=2)
|
||||
self.cached_tensor = input[:, :, -2:]
|
||||
return super().forward(input)
|
||||
|
||||
|
||||
|
||||
class CogVAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.7
|
||||
self.conv_in = CachedConv3d(16, 512, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Resnet3DBlock(512, 512, 16, 32),
|
||||
Upsample3D(512, 512, compress_time=True),
|
||||
Resnet3DBlock(512, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Upsample3D(256, 256, compress_time=True),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Resnet3DBlock(256, 256, 16, 32),
|
||||
Upsample3D(256, 256, compress_time=False),
|
||||
Resnet3DBlock(256, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
Resnet3DBlock(128, 128, 16, 32),
|
||||
])
|
||||
|
||||
self.norm_out = CogVideoXSpatialNorm3D(128, 16, 32)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = CachedConv3d(128, 3, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
sample = sample / self.scaling_factor
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, sample)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, sample)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def decode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||
if tiled:
|
||||
B, C, T, H, W = sample.shape
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.decode_small_video(x),
|
||||
model_input=sample,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||
scales=(3/16, (T//2*8+T%2)/T, 8, 8),
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
else:
|
||||
return self.decode_small_video(sample)
|
||||
|
||||
|
||||
def decode_small_video(self, sample):
|
||||
B, C, T, H, W = sample.shape
|
||||
computation_device = self.conv_in.weight.device
|
||||
computation_dtype = self.conv_in.weight.dtype
|
||||
value = []
|
||||
for i in range(T//2):
|
||||
tl = i*2 + T%2 - (T%2 and i==0)
|
||||
tr = i*2 + 2 + T%2
|
||||
model_input = sample[:, :, tl: tr, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||
value.append(model_output)
|
||||
value = torch.concat(value, dim=2)
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CachedConv3d):
|
||||
module.clear_cache()
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class CogVAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.7
|
||||
self.conv_in = CachedConv3d(3, 128, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Resnet3DBlock(128, 128, None, 32),
|
||||
Downsample3D(128, 128, compress_time=True),
|
||||
Resnet3DBlock(128, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Downsample3D(256, 256, compress_time=True),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Resnet3DBlock(256, 256, None, 32),
|
||||
Downsample3D(256, 256, compress_time=False),
|
||||
Resnet3DBlock(256, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
Resnet3DBlock(512, 512, None, 32),
|
||||
])
|
||||
|
||||
self.norm_out = torch.nn.GroupNorm(32, 512, eps=1e-06, affine=True)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = CachedConv3d(512, 32, kernel_size=3, stride=1, padding=(0, 1, 1))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
hidden_states = self.conv_in(sample)
|
||||
|
||||
for block in self.blocks:
|
||||
hidden_states = block(hidden_states, sample)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)[:, :16]
|
||||
hidden_states = hidden_states * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def encode_video(self, sample, tiled=True, tile_size=(60, 90), tile_stride=(30, 45), progress_bar=lambda x:x):
|
||||
if tiled:
|
||||
B, C, T, H, W = sample.shape
|
||||
return TileWorker2Dto3D().tiled_forward(
|
||||
forward_fn=lambda x: self.encode_small_video(x),
|
||||
model_input=sample,
|
||||
tile_size=(i * 8 for i in tile_size), tile_stride=(i * 8 for i in tile_stride),
|
||||
tile_device=sample.device, tile_dtype=sample.dtype,
|
||||
computation_device=sample.device, computation_dtype=sample.dtype,
|
||||
scales=(16/3, (T//4+T%2)/T, 1/8, 1/8),
|
||||
progress_bar=progress_bar
|
||||
)
|
||||
else:
|
||||
return self.encode_small_video(sample)
|
||||
|
||||
|
||||
def encode_small_video(self, sample):
|
||||
B, C, T, H, W = sample.shape
|
||||
computation_device = self.conv_in.weight.device
|
||||
computation_dtype = self.conv_in.weight.dtype
|
||||
value = []
|
||||
for i in range(T//8):
|
||||
t = i*8 + T%2 - (T%2 and i==0)
|
||||
t_ = i*8 + 8 + T%2
|
||||
model_input = sample[:, :, t: t_, :, :].to(dtype=computation_dtype, device=computation_device)
|
||||
model_output = self.forward(model_input).to(dtype=sample.dtype, device=sample.device)
|
||||
value.append(model_output)
|
||||
value = torch.concat(value, dim=2)
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, CachedConv3d):
|
||||
module.clear_cache()
|
||||
return value
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return CogVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class CogVAEEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.conv.weight": "conv_in.weight",
|
||||
"encoder.conv_in.conv.bias": "conv_in.bias",
|
||||
"encoder.down_blocks.0.downsamplers.0.conv.weight": "blocks.3.conv.weight",
|
||||
"encoder.down_blocks.0.downsamplers.0.conv.bias": "blocks.3.conv.bias",
|
||||
"encoder.down_blocks.1.downsamplers.0.conv.weight": "blocks.7.conv.weight",
|
||||
"encoder.down_blocks.1.downsamplers.0.conv.bias": "blocks.7.conv.bias",
|
||||
"encoder.down_blocks.2.downsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||
"encoder.down_blocks.2.downsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||
"encoder.norm_out.weight": "norm_out.weight",
|
||||
"encoder.norm_out.bias": "norm_out.bias",
|
||||
"encoder.conv_out.conv.weight": "conv_out.weight",
|
||||
"encoder.conv_out.conv.bias": "conv_out.bias",
|
||||
}
|
||||
prefix_dict = {
|
||||
"encoder.down_blocks.0.resnets.0.": "blocks.0.",
|
||||
"encoder.down_blocks.0.resnets.1.": "blocks.1.",
|
||||
"encoder.down_blocks.0.resnets.2.": "blocks.2.",
|
||||
"encoder.down_blocks.1.resnets.0.": "blocks.4.",
|
||||
"encoder.down_blocks.1.resnets.1.": "blocks.5.",
|
||||
"encoder.down_blocks.1.resnets.2.": "blocks.6.",
|
||||
"encoder.down_blocks.2.resnets.0.": "blocks.8.",
|
||||
"encoder.down_blocks.2.resnets.1.": "blocks.9.",
|
||||
"encoder.down_blocks.2.resnets.2.": "blocks.10.",
|
||||
"encoder.down_blocks.3.resnets.0.": "blocks.12.",
|
||||
"encoder.down_blocks.3.resnets.1.": "blocks.13.",
|
||||
"encoder.down_blocks.3.resnets.2.": "blocks.14.",
|
||||
"encoder.mid_block.resnets.0.": "blocks.15.",
|
||||
"encoder.mid_block.resnets.1.": "blocks.16.",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||
"conv1.conv.weight": "conv1.weight",
|
||||
"conv1.conv.bias": "conv1.bias",
|
||||
"conv2.conv.weight": "conv2.weight",
|
||||
"conv2.conv.bias": "conv2.bias",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
"norm1.weight": "norm1.weight",
|
||||
"norm1.bias": "norm1.bias",
|
||||
"norm2.weight": "norm2.weight",
|
||||
"norm2.bias": "norm2.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
for prefix in prefix_dict:
|
||||
if name.startswith(prefix):
|
||||
suffix = name[len(prefix):]
|
||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
|
||||
class CogVAEDecoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.conv.weight": "conv_in.weight",
|
||||
"decoder.conv_in.conv.bias": "conv_in.bias",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv.weight": "blocks.6.conv.weight",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv.weight": "blocks.11.conv.weight",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv.bias": "blocks.11.conv.bias",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv.weight": "blocks.16.conv.weight",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv.bias": "blocks.16.conv.bias",
|
||||
"decoder.norm_out.norm_layer.weight": "norm_out.norm_layer.weight",
|
||||
"decoder.norm_out.norm_layer.bias": "norm_out.norm_layer.bias",
|
||||
"decoder.norm_out.conv_y.conv.weight": "norm_out.conv_y.weight",
|
||||
"decoder.norm_out.conv_y.conv.bias": "norm_out.conv_y.bias",
|
||||
"decoder.norm_out.conv_b.conv.weight": "norm_out.conv_b.weight",
|
||||
"decoder.norm_out.conv_b.conv.bias": "norm_out.conv_b.bias",
|
||||
"decoder.conv_out.conv.weight": "conv_out.weight",
|
||||
"decoder.conv_out.conv.bias": "conv_out.bias"
|
||||
}
|
||||
prefix_dict = {
|
||||
"decoder.mid_block.resnets.0.": "blocks.0.",
|
||||
"decoder.mid_block.resnets.1.": "blocks.1.",
|
||||
"decoder.up_blocks.0.resnets.0.": "blocks.2.",
|
||||
"decoder.up_blocks.0.resnets.1.": "blocks.3.",
|
||||
"decoder.up_blocks.0.resnets.2.": "blocks.4.",
|
||||
"decoder.up_blocks.0.resnets.3.": "blocks.5.",
|
||||
"decoder.up_blocks.1.resnets.0.": "blocks.7.",
|
||||
"decoder.up_blocks.1.resnets.1.": "blocks.8.",
|
||||
"decoder.up_blocks.1.resnets.2.": "blocks.9.",
|
||||
"decoder.up_blocks.1.resnets.3.": "blocks.10.",
|
||||
"decoder.up_blocks.2.resnets.0.": "blocks.12.",
|
||||
"decoder.up_blocks.2.resnets.1.": "blocks.13.",
|
||||
"decoder.up_blocks.2.resnets.2.": "blocks.14.",
|
||||
"decoder.up_blocks.2.resnets.3.": "blocks.15.",
|
||||
"decoder.up_blocks.3.resnets.0.": "blocks.17.",
|
||||
"decoder.up_blocks.3.resnets.1.": "blocks.18.",
|
||||
"decoder.up_blocks.3.resnets.2.": "blocks.19.",
|
||||
"decoder.up_blocks.3.resnets.3.": "blocks.20.",
|
||||
}
|
||||
suffix_dict = {
|
||||
"norm1.norm_layer.weight": "norm1.norm_layer.weight",
|
||||
"norm1.norm_layer.bias": "norm1.norm_layer.bias",
|
||||
"norm1.conv_y.conv.weight": "norm1.conv_y.weight",
|
||||
"norm1.conv_y.conv.bias": "norm1.conv_y.bias",
|
||||
"norm1.conv_b.conv.weight": "norm1.conv_b.weight",
|
||||
"norm1.conv_b.conv.bias": "norm1.conv_b.bias",
|
||||
"norm2.norm_layer.weight": "norm2.norm_layer.weight",
|
||||
"norm2.norm_layer.bias": "norm2.norm_layer.bias",
|
||||
"norm2.conv_y.conv.weight": "norm2.conv_y.weight",
|
||||
"norm2.conv_y.conv.bias": "norm2.conv_y.bias",
|
||||
"norm2.conv_b.conv.weight": "norm2.conv_b.weight",
|
||||
"norm2.conv_b.conv.bias": "norm2.conv_b.bias",
|
||||
"conv1.conv.weight": "conv1.weight",
|
||||
"conv1.conv.bias": "conv1.bias",
|
||||
"conv2.conv.weight": "conv2.weight",
|
||||
"conv2.conv.bias": "conv2.bias",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
for prefix in prefix_dict:
|
||||
if name.startswith(prefix):
|
||||
suffix = name[len(prefix):]
|
||||
state_dict_[prefix_dict[prefix] + suffix_dict[suffix]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
from ..configs.model_config import preset_models_on_huggingface, preset_models_on_modelscope, Preset_model_id
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
os.makedirs(local_dir, exist_ok=True)
|
||||
file_name = os.path.basename(origin_file_path)
|
||||
if file_name in os.listdir(local_dir):
|
||||
print(f" {file_name} has been already in {local_dir}.")
|
||||
else:
|
||||
print(f" Start downloading {os.path.join(local_dir, file_name)}")
|
||||
hf_hub_download(model_id, origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, file_name)
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
"ModelScope",
|
||||
]
|
||||
website_to_preset_models = {
|
||||
"HuggingFace": preset_models_on_huggingface,
|
||||
"ModelScope": preset_models_on_modelscope,
|
||||
}
|
||||
website_to_download_fn = {
|
||||
"HuggingFace": download_from_huggingface,
|
||||
"ModelScope": download_from_modelscope,
|
||||
}
|
||||
|
||||
|
||||
def download_customized_models(
|
||||
model_id,
|
||||
origin_file_path,
|
||||
local_dir,
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
downloaded_files = []
|
||||
for website in downloading_priority:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
return downloaded_files
|
||||
|
||||
|
||||
def download_models(
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
):
|
||||
print(f"Downloading models: {model_id_list}")
|
||||
downloaded_files = []
|
||||
load_files = []
|
||||
|
||||
for model_id in model_id_list:
|
||||
for website in downloading_priority:
|
||||
if model_id in website_to_preset_models[website]:
|
||||
|
||||
# Parse model metadata
|
||||
model_metadata = website_to_preset_models[website][model_id]
|
||||
if isinstance(model_metadata, list):
|
||||
file_data = model_metadata
|
||||
else:
|
||||
file_data = model_metadata.get("file_list", [])
|
||||
|
||||
# Try downloading the model from this website.
|
||||
model_files = []
|
||||
for model_id, origin_file_path, local_dir in file_data:
|
||||
# Check if the file is downloaded.
|
||||
file_to_download = os.path.join(local_dir, os.path.basename(origin_file_path))
|
||||
if file_to_download in downloaded_files:
|
||||
continue
|
||||
# Download
|
||||
website_to_download_fn[website](model_id, origin_file_path, local_dir)
|
||||
if os.path.basename(origin_file_path) in os.listdir(local_dir):
|
||||
downloaded_files.append(file_to_download)
|
||||
model_files.append(file_to_download)
|
||||
|
||||
# If the model is successfully downloaded, break.
|
||||
if len(model_files) > 0:
|
||||
if isinstance(model_metadata, dict) and "load_path" in model_metadata:
|
||||
model_files = model_metadata["load_path"]
|
||||
load_files.extend(model_files)
|
||||
break
|
||||
|
||||
return load_files
|
||||
1057
diffsynth/models/flux2_dit.py
Normal file
1057
diffsynth/models/flux2_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
58
diffsynth/models/flux2_text_encoder.py
Normal file
58
diffsynth/models/flux2_text_encoder.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from transformers import Mistral3ForConditionalGeneration, Mistral3Config
|
||||
|
||||
|
||||
class Flux2TextEncoder(Mistral3ForConditionalGeneration):
|
||||
def __init__(self):
|
||||
config = Mistral3Config(**{
|
||||
"architectures": [
|
||||
"Mistral3ForConditionalGeneration"
|
||||
],
|
||||
"dtype": "bfloat16",
|
||||
"image_token_index": 10,
|
||||
"model_type": "mistral3",
|
||||
"multimodal_projector_bias": False,
|
||||
"projector_hidden_act": "gelu",
|
||||
"spatial_merge_size": 2,
|
||||
"text_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 5120,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 32768,
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "mistral",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 40,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_theta": 1000000000.0,
|
||||
"sliding_window": None,
|
||||
"use_cache": True,
|
||||
"vocab_size": 131072
|
||||
},
|
||||
"transformers_version": "4.57.1",
|
||||
"vision_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"head_dim": 64,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 1540,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"model_type": "pixtral",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"patch_size": 14,
|
||||
"rope_theta": 10000.0
|
||||
},
|
||||
"vision_feature_layer": -1
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids = None, pixel_values = None, attention_mask = None, position_ids = None, past_key_values = None, inputs_embeds = None, labels = None, use_cache = None, output_attentions = None, output_hidden_states = None, return_dict = None, cache_position = None, logits_to_keep = 0, image_sizes = None, **kwargs):
|
||||
return super().forward(input_ids, pixel_values, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, logits_to_keep, image_sizes, **kwargs)
|
||||
|
||||
2322
diffsynth/models/flux2_vae.py
Normal file
2322
diffsynth/models/flux2_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,62 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
from .flux_dit import RoPEEmbedding, TimestepEmbeddings, FluxJointTransformerBlock, FluxSingleTransformerBlock, RMSNorm
|
||||
from .utils import hash_state_dict_keys, init_weights_on_device
|
||||
# from .utils import hash_state_dict_keys, init_weights_on_device
|
||||
from contextlib import contextmanager
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
|
||||
@contextmanager
|
||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||
|
||||
old_register_parameter = torch.nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = torch.nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(device)
|
||||
|
||||
def patch_tensor_constructor(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["device"] = device
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if include_buffers:
|
||||
tensor_constructors_to_patch = {
|
||||
torch_function_name: getattr(torch, torch_function_name)
|
||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||
}
|
||||
else:
|
||||
tensor_constructors_to_patch = {}
|
||||
|
||||
try:
|
||||
torch.nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = register_empty_buffer
|
||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = old_register_buffer
|
||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||
setattr(torch, torch_function_name, old_torch_function)
|
||||
|
||||
class FluxControlNet(torch.nn.Module):
|
||||
def __init__(self, disable_guidance_embedder=False, num_joint_blocks=5, num_single_blocks=10, num_mode=0, mode_dict={}, additional_input_dim=0):
|
||||
@@ -102,9 +155,9 @@ class FluxControlNet(torch.nn.Module):
|
||||
return controlnet_res_stack, controlnet_single_res_stack
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxControlNetStateDictConverter()
|
||||
# @staticmethod
|
||||
# def state_dict_converter():
|
||||
# return FluxControlNetStateDictConverter()
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from .general_modules import TimestepEmbeddings, AdaLayerNorm, RMSNorm
|
||||
from einops import rearrange
|
||||
from .tiler import TileWorker
|
||||
from .utils import init_weights_on_device, hash_state_dict_keys
|
||||
|
||||
|
||||
def interact_with_ipadapter(hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size, num_tokens = hidden_states.shape[0:2]
|
||||
@@ -269,7 +268,7 @@ class AdaLayerNormContinuous(torch.nn.Module):
|
||||
|
||||
def forward(self, x, conditioning):
|
||||
emb = self.linear(self.silu(conditioning))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
shift, scale = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None] + shift[:, None]
|
||||
return x
|
||||
|
||||
@@ -321,25 +320,6 @@ class FluxDiT(torch.nn.Module):
|
||||
return latent_image_ids
|
||||
|
||||
|
||||
def tiled_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=128, tile_stride=64,
|
||||
**kwargs
|
||||
):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def construct_mask(self, entity_masks, prompt_seq_len, image_seq_len):
|
||||
N = len(entity_masks)
|
||||
batch_size = entity_masks[0].shape[0]
|
||||
@@ -411,338 +391,5 @@ class FluxDiT(torch.nn.Module):
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
if tiled:
|
||||
return self.tiled_forward(
|
||||
hidden_states,
|
||||
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids,
|
||||
tile_size=tile_size, tile_stride=tile_stride,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if image_ids is None:
|
||||
image_ids = self.prepare_image_ids(hidden_states)
|
||||
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
if self.guidance_embedder is not None:
|
||||
guidance = guidance * 1000
|
||||
conditioning = conditioning + self.guidance_embedder(guidance, hidden_states.dtype)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.patchify(hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
|
||||
if entity_prompt_emb is not None and entity_masks is not None:
|
||||
prompt_emb, image_rotary_emb, attention_mask = self.process_entity_masks(hidden_states, prompt_emb, entity_prompt_emb, entity_masks, text_ids, image_ids)
|
||||
else:
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
image_rotary_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1))
|
||||
attention_mask = None
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
|
||||
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
|
||||
for block in self.single_blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb, attention_mask)
|
||||
hidden_states = hidden_states[:, prompt_emb.shape[1]:]
|
||||
|
||||
hidden_states = self.final_norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.final_proj_out(hidden_states)
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
bias = None
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device)
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
|
||||
new_layer.weight = module.weight
|
||||
if module.bias is not None:
|
||||
new_layer.bias = module.bias
|
||||
# del module
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
if hasattr(module,"quantized"):
|
||||
continue
|
||||
module.quantized= True
|
||||
new_layer = quantized_layer.RMSNorm(module)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module)
|
||||
|
||||
replace_layer(self)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxDiTStateDictConverter()
|
||||
|
||||
|
||||
class FluxDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
mlp = torch.zeros(4 * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if hash_state_dict_keys(state_dict, with_shape=True) in ["3e6c61b0f9471135fc9c6d6a98e98b6d", "63c969fd37cce769a90aa781fbff5f81"]:
|
||||
dit_state_dict = {key.replace("pipe.dit.", ""): value for key, value in state_dict.items() if key.startswith('pipe.dit.')}
|
||||
return dit_state_dict
|
||||
rename_dict = {
|
||||
"time_in.in_layer.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"time_in.in_layer.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"time_in.out_layer.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"time_in.out_layer.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"txt_in.bias": "context_embedder.bias",
|
||||
"txt_in.weight": "context_embedder.weight",
|
||||
"vector_in.in_layer.bias": "pooled_text_embedder.0.bias",
|
||||
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
|
||||
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
|
||||
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
|
||||
"final_layer.linear.bias": "final_proj_out.bias",
|
||||
"final_layer.linear.weight": "final_proj_out.weight",
|
||||
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
|
||||
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
|
||||
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
|
||||
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
|
||||
"img_in.bias": "x_embedder.bias",
|
||||
"img_in.weight": "x_embedder.weight",
|
||||
"final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
|
||||
"final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
|
||||
"img_attn.norm.query_norm.scale": "attn.norm_q_a.weight",
|
||||
"img_attn.proj.bias": "attn.a_to_out.bias",
|
||||
"img_attn.proj.weight": "attn.a_to_out.weight",
|
||||
"img_attn.qkv.bias": "attn.a_to_qkv.bias",
|
||||
"img_attn.qkv.weight": "attn.a_to_qkv.weight",
|
||||
"img_mlp.0.bias": "ff_a.0.bias",
|
||||
"img_mlp.0.weight": "ff_a.0.weight",
|
||||
"img_mlp.2.bias": "ff_a.2.bias",
|
||||
"img_mlp.2.weight": "ff_a.2.weight",
|
||||
"img_mod.lin.bias": "norm1_a.linear.bias",
|
||||
"img_mod.lin.weight": "norm1_a.linear.weight",
|
||||
"txt_attn.norm.key_norm.scale": "attn.norm_k_b.weight",
|
||||
"txt_attn.norm.query_norm.scale": "attn.norm_q_b.weight",
|
||||
"txt_attn.proj.bias": "attn.b_to_out.bias",
|
||||
"txt_attn.proj.weight": "attn.b_to_out.weight",
|
||||
"txt_attn.qkv.bias": "attn.b_to_qkv.bias",
|
||||
"txt_attn.qkv.weight": "attn.b_to_qkv.weight",
|
||||
"txt_mlp.0.bias": "ff_b.0.bias",
|
||||
"txt_mlp.0.weight": "ff_b.0.weight",
|
||||
"txt_mlp.2.bias": "ff_b.2.bias",
|
||||
"txt_mlp.2.weight": "ff_b.2.weight",
|
||||
"txt_mod.lin.bias": "norm1_b.linear.bias",
|
||||
"txt_mod.lin.weight": "norm1_b.linear.weight",
|
||||
|
||||
"linear1.bias": "to_qkv_mlp.bias",
|
||||
"linear1.weight": "to_qkv_mlp.weight",
|
||||
"linear2.bias": "proj_out.bias",
|
||||
"linear2.weight": "proj_out.weight",
|
||||
"modulation.lin.bias": "norm.linear.bias",
|
||||
"modulation.lin.weight": "norm.linear.weight",
|
||||
"norm.key_norm.scale": "norm_k_a.weight",
|
||||
"norm.query_norm.scale": "norm_q_a.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("model.diffusion_model."):
|
||||
name = name[len("model.diffusion_model."):]
|
||||
names = name.split(".")
|
||||
if name in rename_dict:
|
||||
rename = rename_dict[name]
|
||||
if name.startswith("final_layer.adaLN_modulation.1."):
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "double_blocks":
|
||||
rename = f"blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
elif names[0] == "single_blocks":
|
||||
if ".".join(names[2:]) in suffix_rename_dict:
|
||||
rename = f"single_blocks.{names[1]}." + suffix_rename_dict[".".join(names[2:])]
|
||||
state_dict_[rename] = param
|
||||
else:
|
||||
pass
|
||||
if "guidance_embedder.timestep_embedder.0.weight" not in state_dict_:
|
||||
return state_dict_, {"disable_guidance_embedder": True}
|
||||
elif "blocks.8.attn.norm_k_a.weight" not in state_dict_:
|
||||
return state_dict_, {"input_dim": 196, "num_blocks": 8}
|
||||
else:
|
||||
return state_dict_
|
||||
# (Deprecated) The real forward is in `pipelines.flux_image`.
|
||||
return None
|
||||
|
||||
@@ -1,9 +1,25 @@
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .sd3_dit import RMSNorm
|
||||
from transformers import CLIPImageProcessor
|
||||
from .general_modules import RMSNorm
|
||||
from transformers import SiglipVisionModel, SiglipVisionConfig
|
||||
import torch
|
||||
|
||||
|
||||
class SiglipVisionModelSO400M(SiglipVisionModel):
|
||||
def __init__(self):
|
||||
config = SiglipVisionConfig(
|
||||
hidden_size=1152,
|
||||
image_size=384,
|
||||
intermediate_size=4304,
|
||||
model_type="siglip_vision_model",
|
||||
num_attention_heads=16,
|
||||
num_hidden_layers=27,
|
||||
patch_size=14,
|
||||
architectures=["SiglipModel"],
|
||||
initializer_factor=1.0,
|
||||
torch_dtype="float32",
|
||||
transformers_version="4.37.0.dev0"
|
||||
)
|
||||
super().__init__(config)
|
||||
|
||||
class MLPProjModel(torch.nn.Module):
|
||||
def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, num_tokens=4):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,5 +1,415 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def low_version_attention(query, key, value, attn_bias=None):
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ value
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size = q.shape[0]
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
|
||||
if attn_mask is not None:
|
||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||
else:
|
||||
import xformers.ops as xops
|
||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||
super().__init__()
|
||||
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||
|
||||
self.use_quick_gelu = use_quick_gelu
|
||||
|
||||
def quickGELU(self, x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward(self, hidden_states, attn_mask=None):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
if self.use_quick_gelu:
|
||||
hidden_states = self.quickGELU(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SDTextEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=1):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
break
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDTextEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("text_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"cond_stage_model.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.bias": "encoders.11.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm1.weight": "encoders.11.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.bias": "encoders.11.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.layer_norm2.weight": "encoders.11.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.bias": "encoders.11.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc1.weight": "encoders.11.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.bias": "encoders.11.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.mlp.fc2.weight": "encoders.11.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.bias": "encoders.11.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.k_proj.weight": "encoders.11.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.bias": "encoders.11.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.q_proj.weight": "encoders.11.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.bias": "encoders.11.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.11.self_attn.v_proj.weight": "encoders.11.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||
"cond_stage_model.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||
"cond_stage_model.transformer.text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||
"cond_stage_model.transformer.text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"cond_stage_model.transformer.text_model.embeddings.position_embedding.weight": "position_embeds"
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "cond_stage_model.transformer.text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
|
||||
class LoRALayerBlock(torch.nn.Module):
|
||||
@@ -63,8 +473,8 @@ class LoRAEmbedder(torch.nn.Module):
|
||||
lora_emb = []
|
||||
for lora_pattern in self.lora_patterns:
|
||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||
lora_A = lora[name + ".lora_A.default.weight"]
|
||||
lora_B = lora[name + ".lora_B.default.weight"]
|
||||
lora_A = lora[name + ".lora_A.weight"]
|
||||
lora_B = lora[name + ".lora_B.weight"]
|
||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||
lora_emb.append(lora_out)
|
||||
|
||||
306
diffsynth/models/flux_lora_patcher.py
Normal file
306
diffsynth/models/flux_lora_patcher.py
Normal file
@@ -0,0 +1,306 @@
|
||||
import torch, math
|
||||
from ..core.loader import load_state_dict
|
||||
from typing import Union
|
||||
|
||||
class GeneralLoRALoader:
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
keys.pop(-1)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
updated_num = 0
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name, module in model.named_modules():
|
||||
if name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=self.device, dtype=self.torch_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=self.device, dtype=self.torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
state_dict = module.state_dict()
|
||||
state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
|
||||
module.load_state_dict(state_dict)
|
||||
updated_num += 1
|
||||
print(f"{updated_num} tensors are updated by LoRA.")
|
||||
|
||||
class FluxLoRALoader(GeneralLoRALoader):
|
||||
def __init__(self, device="cpu", torch_dtype=torch.float32):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
|
||||
self.diffusers_rename_dict = {
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_A.weight":"single_blocks.blockid.a_to_k.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_k.lora_B.weight":"single_blocks.blockid.a_to_k.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_A.weight":"single_blocks.blockid.a_to_q.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_q.lora_B.weight":"single_blocks.blockid.a_to_q.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_A.weight":"single_blocks.blockid.a_to_v.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.attn.to_v.lora_B.weight":"single_blocks.blockid.a_to_v.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_A.weight":"single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.norm.linear.lora_B.weight":"single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_A.weight":"single_blocks.blockid.proj_in_besides_attn.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_mlp.lora_B.weight":"single_blocks.blockid.proj_in_besides_attn.lora_B.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_A.weight":"single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"transformer.single_transformer_blocks.blockid.proj_out.lora_B.weight":"single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_A.weight":"blocks.blockid.attn.b_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_k_proj.lora_B.weight":"blocks.blockid.attn.b_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_A.weight":"blocks.blockid.attn.b_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_q_proj.lora_B.weight":"blocks.blockid.attn.b_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_A.weight":"blocks.blockid.attn.b_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.add_v_proj.lora_B.weight":"blocks.blockid.attn.b_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_A.weight":"blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_add_out.lora_B.weight":"blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_A.weight":"blocks.blockid.attn.a_to_k.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_k.lora_B.weight":"blocks.blockid.attn.a_to_k.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_A.weight":"blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_out.0.lora_B.weight":"blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_A.weight":"blocks.blockid.attn.a_to_q.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_q.lora_B.weight":"blocks.blockid.attn.a_to_q.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_A.weight":"blocks.blockid.attn.a_to_v.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.attn.to_v.lora_B.weight":"blocks.blockid.attn.a_to_v.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_A.weight":"blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.0.proj.lora_B.weight":"blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_A.weight":"blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff.net.2.lora_B.weight":"blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_A.weight":"blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.0.proj.lora_B.weight":"blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_A.weight":"blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.ff_context.net.2.lora_B.weight":"blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_A.weight":"blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1.linear.lora_B.weight":"blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_A.weight":"blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"transformer.transformer_blocks.blockid.norm1_context.linear.lora_B.weight":"blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
}
|
||||
|
||||
self.civitai_rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
|
||||
def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
|
||||
super().load(model, state_dict_lora, alpha)
|
||||
|
||||
|
||||
def convert_state_dict(self,state_dict):
|
||||
|
||||
def guess_block_id(name,model_resource):
|
||||
if model_resource == 'civitai':
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
if model_resource == 'diffusers':
|
||||
names = name.split(".")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"transformer_blocks.{i}.", "transformer_blocks.blockid.")
|
||||
return None, None
|
||||
|
||||
def guess_resource(state_dict):
|
||||
for k in state_dict:
|
||||
if "lora_unet_" in k:
|
||||
return 'civitai'
|
||||
elif k.startswith("transformer."):
|
||||
return 'diffusers'
|
||||
else:
|
||||
None
|
||||
|
||||
model_resource = guess_resource(state_dict)
|
||||
if model_resource is None:
|
||||
return state_dict
|
||||
|
||||
rename_dict = self.diffusers_rename_dict if model_resource == 'diffusers' else self.civitai_rename_dict
|
||||
def guess_alpha(state_dict):
|
||||
for name, param in state_dict.items():
|
||||
if ".alpha" in name:
|
||||
for suffix in [".lora_down.weight", ".lora_A.weight"]:
|
||||
name_ = name.replace(".alpha", suffix)
|
||||
if name_ in state_dict:
|
||||
lora_alpha = param.item() / state_dict[name_].shape[0]
|
||||
lora_alpha = math.sqrt(lora_alpha)
|
||||
return lora_alpha
|
||||
|
||||
return 1
|
||||
|
||||
alpha = guess_alpha(state_dict)
|
||||
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name,model_resource)
|
||||
if alpha != 1:
|
||||
param *= alpha
|
||||
if source_name in rename_dict:
|
||||
target_name = rename_dict[source_name]
|
||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
|
||||
if model_resource == 'diffusers':
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
dim = 4
|
||||
if 'lora_A' in name:
|
||||
dim = 1
|
||||
mlp = torch.zeros(dim * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
d, r = state_dict_[name].shape
|
||||
param = torch.zeros((3*d+mlp.shape[0], 3*r+mlp.shape[1]), dtype=state_dict_[name].dtype, device=state_dict_[name].device)
|
||||
param[:d, :r] = state_dict_.pop(name)
|
||||
param[d:2*d, r:2*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_k."))
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_.pop(name.replace(".a_to_q.", ".a_to_v."))
|
||||
param[3*d:, 3*r:] = mlp
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
concat_dim = 0
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
elif 'lora_B' in name:
|
||||
origin = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
d, r = origin.shape
|
||||
# print(d, r)
|
||||
param = torch.zeros((3*d, 3*r), dtype=origin.dtype, device=origin.device)
|
||||
param[:d, :r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")]
|
||||
param[d:2*d, r:2*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")]
|
||||
param[2*d:3*d, 2*r:3*r] = state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")]
|
||||
else:
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
|
||||
|
||||
class LoraMerger(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
|
||||
self.bias = torch.nn.Parameter(torch.randn((dim,)))
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
|
||||
|
||||
def forward(self, base_output, lora_outputs):
|
||||
norm_base_output = self.norm_base(base_output)
|
||||
norm_lora_outputs = self.norm_lora(lora_outputs)
|
||||
gate = self.activation(
|
||||
norm_base_output * self.weight_base \
|
||||
+ norm_lora_outputs * self.weight_lora \
|
||||
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
|
||||
)
|
||||
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
|
||||
return output
|
||||
|
||||
class FluxLoraPatcher(torch.nn.Module):
|
||||
def __init__(self, lora_patterns=None):
|
||||
super().__init__()
|
||||
if lora_patterns is None:
|
||||
lora_patterns = self.default_lora_patterns()
|
||||
model_dict = {}
|
||||
for lora_pattern in lora_patterns:
|
||||
name, dim = lora_pattern["name"], lora_pattern["dim"]
|
||||
model_dict[name.replace(".", "___")] = LoraMerger(dim)
|
||||
self.model_dict = torch.nn.ModuleDict(model_dict)
|
||||
|
||||
def default_lora_patterns(self):
|
||||
lora_patterns = []
|
||||
lora_dict = {
|
||||
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
|
||||
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
|
||||
}
|
||||
for i in range(19):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
|
||||
for i in range(38):
|
||||
for suffix in lora_dict:
|
||||
lora_patterns.append({
|
||||
"name": f"single_blocks.{i}.{suffix}",
|
||||
"dim": lora_dict[suffix]
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def forward(self, base_output, lora_outputs, name):
|
||||
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
|
||||
@@ -1,32 +0,0 @@
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder2(T5EncoderModel):
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids):
|
||||
outputs = super().forward(input_ids=input_ids)
|
||||
prompt_emb = outputs.last_hidden_state
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
|
||||
class FluxTextEncoder2StateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = state_dict
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
112
diffsynth/models/flux_text_encoder_clip.py
Normal file
112
diffsynth/models/flux_text_encoder_clip.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
|
||||
def __init__(self, embed_dim, intermediate_size, num_heads=12, head_dim=64, use_quick_gelu=True):
|
||||
super().__init__()
|
||||
self.attn = Attention(q_dim=embed_dim, num_heads=num_heads, head_dim=head_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.layer_norm1 = torch.nn.LayerNorm(embed_dim)
|
||||
self.layer_norm2 = torch.nn.LayerNorm(embed_dim)
|
||||
self.fc1 = torch.nn.Linear(embed_dim, intermediate_size)
|
||||
self.fc2 = torch.nn.Linear(intermediate_size, embed_dim)
|
||||
|
||||
self.use_quick_gelu = use_quick_gelu
|
||||
|
||||
def quickGELU(self, x):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
def forward(self, hidden_states, attn_mask=None):
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
if self.use_quick_gelu:
|
||||
hidden_states = self.quickGELU(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.nn.functional.gelu(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxTextEncoderClip(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids)
|
||||
embeds = embeds + self.position_embeds.to(dtype=embeds.dtype, device=input_ids.device)
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
hidden_states = embeds
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
||||
return pooled_embeds, hidden_states
|
||||
43
diffsynth/models/flux_text_encoder_t5.py
Normal file
43
diffsynth/models/flux_text_encoder_t5.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
|
||||
|
||||
class FluxTextEncoderT5(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(**{
|
||||
"architectures": [
|
||||
"T5EncoderModel"
|
||||
],
|
||||
"classifier_dropout": 0.0,
|
||||
"d_ff": 10240,
|
||||
"d_kv": 64,
|
||||
"d_model": 4096,
|
||||
"decoder_start_token_id": 0,
|
||||
"dense_act_fn": "gelu_new",
|
||||
"dropout_rate": 0.1,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": True,
|
||||
"is_gated_act": True,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "t5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 64,
|
||||
"num_layers": 24,
|
||||
"output_past": True,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_max_distance": 128,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": False,
|
||||
"transformers_version": "4.57.1",
|
||||
"use_cache": True,
|
||||
"vocab_size": 32128
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
def forward(self, input_ids):
|
||||
outputs = super().forward(input_ids=input_ids)
|
||||
prompt_emb = outputs.last_hidden_state
|
||||
return prompt_emb
|
||||
@@ -1,303 +1,451 @@
|
||||
from .sd3_vae_encoder import SD3VAEEncoder, SDVAEEncoderStateDictConverter
|
||||
from .sd3_vae_decoder import SD3VAEDecoder, SDVAEDecoderStateDictConverter
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class FluxVAEEncoder(SD3VAEEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEDecoder(SD3VAEDecoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return FluxVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class FluxVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
class TileWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"encoder.conv_in.bias": "conv_in.bias",
|
||||
"encoder.conv_in.weight": "conv_in.weight",
|
||||
"encoder.conv_out.bias": "conv_out.bias",
|
||||
"encoder.conv_out.weight": "conv_out.weight",
|
||||
"encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||
"encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||
"encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||
"encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||
"encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||
"encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||
"encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||
"encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||
"encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||
"encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||
"encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||
"encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||
"encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||
"encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||
"encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||
"encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||
"encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||
"encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||
"encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||
"encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||
"encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||
"encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||
"encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||
"encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||
"encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||
"encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||
"encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||
"encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||
"encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||
"encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||
"encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||
"encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||
"encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||
"encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||
"encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||
"encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||
"encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||
"encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||
"encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||
"encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||
"encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||
"encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||
"encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||
"encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||
"encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||
"encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||
"encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||
"encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||
"encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||
"encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||
"encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||
"encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||
"encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||
"encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||
"encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||
"encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||
"encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||
"encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||
"encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||
"encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||
"encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||
"encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||
"encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||
"encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||
"encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||
"encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||
"encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||
"encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||
"encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"encoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"encoder.norm_out.weight": "conv_norm_out.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def mask(self, height, width, border_width):
|
||||
# Create a mask with shape (height, width).
|
||||
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
||||
x = torch.arange(height).repeat(width, 1).T
|
||||
y = torch.arange(width).repeat(height, 1)
|
||||
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
||||
mask = (mask / border_width).clip(0, 1)
|
||||
return mask
|
||||
|
||||
|
||||
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
||||
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
||||
batch_size, channel, _, _ = model_input.shape
|
||||
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
||||
unfold_operator = torch.nn.Unfold(
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
model_input = unfold_operator(model_input)
|
||||
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
||||
|
||||
return model_input
|
||||
|
||||
|
||||
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
||||
# Call y=forward_fn(x) for each tile
|
||||
tile_num = model_input.shape[-1]
|
||||
model_output_stack = []
|
||||
|
||||
for tile_id in range(0, tile_num, tile_batch_size):
|
||||
|
||||
# process input
|
||||
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
||||
x = model_input[:, :, :, :, tile_id: tile_id_]
|
||||
x = x.to(device=inference_device, dtype=inference_dtype)
|
||||
x = rearrange(x, "b c h w n -> (n b) c h w")
|
||||
|
||||
# process output
|
||||
y = forward_fn(x)
|
||||
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
||||
y = y.to(device=tile_device, dtype=tile_dtype)
|
||||
model_output_stack.append(y)
|
||||
|
||||
model_output = torch.concat(model_output_stack, dim=-1)
|
||||
return model_output
|
||||
|
||||
|
||||
def io_scale(self, model_output, tile_size):
|
||||
# Determine the size modification happened in forward_fn
|
||||
# We only consider the same scale on height and width.
|
||||
io_scale = model_output.shape[2] / tile_size
|
||||
return io_scale
|
||||
|
||||
|
||||
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
||||
# The reversed function of tile
|
||||
mask = self.mask(tile_size, tile_size, border_width)
|
||||
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
||||
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
||||
model_output = model_output * mask
|
||||
|
||||
class FluxVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
fold_operator = torch.nn.Fold(
|
||||
output_size=(height, width),
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
||||
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
||||
model_output = fold_operator(model_output) / fold_operator(mask)
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"decoder.conv_in.bias": "conv_in.bias",
|
||||
"decoder.conv_in.weight": "conv_in.weight",
|
||||
"decoder.conv_out.bias": "conv_out.bias",
|
||||
"decoder.conv_out.weight": "conv_out.weight",
|
||||
"decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
||||
"decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
||||
"decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
||||
"decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
||||
"decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
||||
"decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
||||
"decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
||||
"decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
||||
"decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
||||
"decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
||||
"decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||
"decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||
"decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||
"decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||
"decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||
"decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||
"decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||
"decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||
"decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
||||
"decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
||||
"decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
||||
"decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
||||
"decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
||||
"decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
||||
"decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
||||
"decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
||||
"decoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"decoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
||||
"decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
||||
"decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
||||
"decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
||||
"decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
||||
"decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
||||
"decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
||||
"decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
||||
"decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
||||
"decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
||||
"decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
||||
"decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
||||
"decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
||||
"decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
||||
"decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
||||
"decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
||||
"decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
||||
"decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
||||
"decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
||||
"decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
||||
"decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
||||
"decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
||||
"decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
||||
"decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
||||
"decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
||||
"decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
||||
"decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
||||
"decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
||||
"decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
||||
"decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
||||
"decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
||||
"decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
||||
"decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
||||
"decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
||||
"decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
||||
"decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
||||
"decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
||||
"decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
||||
"decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
||||
"decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
||||
"decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
||||
"decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
||||
"decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
||||
"decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
||||
"decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
||||
"decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
||||
"decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
||||
"decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
||||
"decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
||||
"decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
||||
"decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
||||
"decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
||||
"decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
||||
"decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
||||
"decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
||||
"decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
||||
"decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
||||
"decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
||||
"decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
||||
"decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
||||
"decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
||||
"decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
||||
"decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||
"decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||
"decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||
"decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||
"decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||
"decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||
"decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||
"decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||
"decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
||||
"decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
||||
"decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
||||
"decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
||||
"decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
||||
"decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
||||
"decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
||||
"decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
||||
"decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
||||
"decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
||||
"decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
||||
"decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
return model_output
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
inference_device, inference_dtype = model_input.device, model_input.dtype
|
||||
height, width = model_input.shape[2], model_input.shape[3]
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
|
||||
# tile
|
||||
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
||||
|
||||
# inference
|
||||
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
||||
|
||||
# resize
|
||||
io_scale = self.io_scale(model_output, tile_size)
|
||||
height, width = int(height*io_scale), int(width*io_scale)
|
||||
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
||||
border_width = int(border_width*io_scale)
|
||||
|
||||
# untile
|
||||
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
||||
|
||||
# Done!
|
||||
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||
return model_output
|
||||
|
||||
|
||||
class ConvAttention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Conv2d(q_dim, dim_inner, kernel_size=(1, 1), bias=bias_q)
|
||||
self.to_k = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
|
||||
self.to_v = torch.nn.Conv2d(kv_dim, dim_inner, kernel_size=(1, 1), bias=bias_kv)
|
||||
self.to_out = torch.nn.Conv2d(dim_inner, q_dim, kernel_size=(1, 1), bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
|
||||
q = self.to_q(conv_input)
|
||||
q = rearrange(q[:, :, :, 0], "B C L -> B L C")
|
||||
conv_input = rearrange(encoder_hidden_states, "B L C -> B C L 1")
|
||||
k = self.to_k(conv_input)
|
||||
v = self.to_v(conv_input)
|
||||
k = rearrange(k[:, :, :, 0], "B C L -> B L C")
|
||||
v = rearrange(v[:, :, :, 0], "B C L -> B L C")
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
conv_input = rearrange(hidden_states, "B L C -> B C L 1")
|
||||
hidden_states = self.to_out(conv_input)
|
||||
hidden_states = rearrange(hidden_states[:, :, :, 0], "B C L -> B L C")
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5, use_conv_attention=True):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
if use_conv_attention:
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
ConvAttention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
else:
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
Attention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class ResnetBlock(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, temb_channels=None, groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=True)
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
x = hidden_states
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv1(x)
|
||||
if time_emb is not None:
|
||||
emb = self.nonlinearity(time_emb)
|
||||
emb = self.time_emb_proj(emb)[:, :, None, None]
|
||||
x = x + emb
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv2(x)
|
||||
if self.conv_shortcut is not None:
|
||||
hidden_states = self.conv_shortcut(hidden_states)
|
||||
hidden_states = hidden_states + x
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class UpSampler(torch.nn.Module):
|
||||
def __init__(self, channels):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(channels, channels, 3, padding=1)
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
hidden_states = torch.nn.functional.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class DownSampler(torch.nn.Module):
|
||||
def __init__(self, channels, padding=1, extra_padding=False):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(channels, channels, 3, stride=2, padding=padding)
|
||||
self.extra_padding = extra_padding
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
if self.extra_padding:
|
||||
hidden_states = torch.nn.functional.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class FluxVAEDecoder(torch.nn.Module):
|
||||
def __init__(self, use_conv_attention=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FluxVAEEncoder(torch.nn.Module):
|
||||
def __init__(self, use_conv_attention=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.3611
|
||||
self.shift_factor = 0.1159
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6, use_conv_attention=use_conv_attention),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states[:, :16]
|
||||
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import torch
|
||||
from diffsynth.models.svd_unet import TemporalTimesteps
|
||||
from .general_modules import TemporalTimesteps
|
||||
|
||||
|
||||
class MultiValueEncoder(torch.nn.Module):
|
||||
def __init__(self, encoders=()):
|
||||
super().__init__()
|
||||
if not isinstance(encoders, list):
|
||||
encoders = [encoders]
|
||||
self.encoders = torch.nn.ModuleList(encoders)
|
||||
|
||||
def __call__(self, values, dtype):
|
||||
@@ -28,12 +30,6 @@ class SingleValueEncoder(torch.nn.Module):
|
||||
self.positional_embedding = torch.nn.Parameter(
|
||||
torch.randn(self.prefer_len, dim_out)
|
||||
)
|
||||
self._initialize_weights()
|
||||
|
||||
def _initialize_weights(self):
|
||||
last_linear = self.prefer_value_embedder[-1]
|
||||
torch.nn.init.zeros_(last_linear.weight)
|
||||
torch.nn.init.zeros_(last_linear.bias)
|
||||
|
||||
def forward(self, value, dtype):
|
||||
value = value * 1000
|
||||
|
||||
139
diffsynth/models/general_modules.py
Normal file
139
diffsynth/models/general_modules.py
Normal file
@@ -0,0 +1,139 @@
|
||||
import torch, math
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
computation_device = None,
|
||||
align_dtype_to_timestep = False,
|
||||
):
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent).to(timesteps.device)
|
||||
if align_dtype_to_timestep:
|
||||
emb = emb.to(timesteps.dtype)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class TemporalTimesteps(torch.nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None, scale=1, align_dtype_to_timestep=False):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.computation_device = computation_device
|
||||
self.scale = scale
|
||||
self.align_dtype_to_timestep = align_dtype_to_timestep
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
computation_device=self.computation_device,
|
||||
scale=self.scale,
|
||||
align_dtype_to_timestep=self.align_dtype_to_timestep,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||
if diffusers_compatible_format:
|
||||
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
|
||||
else:
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
time_emb = self.timestep_embedder(time_emb)
|
||||
return time_emb
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.weight is not None:
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, single=False, dual=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.dual = dual
|
||||
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
elif self.dual:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||
norm_x = self.norm(x)
|
||||
x = norm_x * (1 + scale_msa) + shift_msa
|
||||
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
@@ -1,451 +0,0 @@
|
||||
from .attention import Attention
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
|
||||
super().__init__()
|
||||
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.rotary_emb_on_k = rotary_emb_on_k
|
||||
self.k_cache, self.v_cache = [], []
|
||||
|
||||
def reshape_for_broadcast(self, freqs_cis, x):
|
||||
ndim = x.ndim
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
def rotate_half(self, x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
def apply_rotary_emb(self, xq, xk, freqs_cis):
|
||||
xk_out = None
|
||||
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
|
||||
# norm
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# RoPE
|
||||
if self.rotary_emb_on_k:
|
||||
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
|
||||
else:
|
||||
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
|
||||
|
||||
if to_cache:
|
||||
self.k_cache.append(k)
|
||||
self.v_cache.append(v)
|
||||
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
|
||||
k = torch.concat([k] + self.k_cache, dim=2)
|
||||
v = torch.concat([v] + self.v_cache, dim=2)
|
||||
self.k_cache, self.v_cache = [], []
|
||||
return q, k, v
|
||||
|
||||
|
||||
class FP32_Layernorm(torch.nn.LayerNorm):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
|
||||
|
||||
|
||||
class FP32_SiLU(torch.nn.SiLU):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
|
||||
|
||||
|
||||
class HunyuanDiTFinalLayer(torch.nn.Module):
|
||||
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
|
||||
super().__init__()
|
||||
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def modulate(self, x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def forward(self, hidden_states, condition_emb):
|
||||
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
|
||||
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanDiTBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim=1408,
|
||||
condition_dim=1408,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.3637,
|
||||
text_dim=1024,
|
||||
skip_connection=False
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
|
||||
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
|
||||
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
|
||||
)
|
||||
if skip_connection:
|
||||
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
|
||||
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
|
||||
else:
|
||||
self.skip_norm, self.skip_linear = None, None
|
||||
|
||||
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
|
||||
# Long Skip Connection
|
||||
if self.skip_norm is not None and self.skip_linear is not None:
|
||||
hidden_states = torch.cat([hidden_states, residual], dim=-1)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
hidden_states = self.skip_linear(hidden_states)
|
||||
|
||||
# Self-Attention
|
||||
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
|
||||
attn_input = self.norm1(hidden_states) + shift_msa
|
||||
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
|
||||
|
||||
# Cross-Attention
|
||||
attn_input = self.norm3(hidden_states)
|
||||
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
|
||||
|
||||
# FFN Layer
|
||||
mlp_input = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionPool(torch.nn.Module):
|
||||
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
||||
x, _ = torch.nn.functional.multi_head_attention_forward(
|
||||
query=x[:1], key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
return x.squeeze(0)
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 2),
|
||||
in_chans=4,
|
||||
embed_dim=1408,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
else:
|
||||
embedding = repeat(t, "b -> b d", d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
class TimestepEmbedder(torch.nn.Module):
|
||||
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class HunyuanDiT(torch.nn.Module):
|
||||
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
|
||||
super().__init__()
|
||||
|
||||
# Embedders
|
||||
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
|
||||
self.t5_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
|
||||
)
|
||||
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
|
||||
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
|
||||
self.patch_embedder = PatchEmbed(in_chans=in_channels)
|
||||
self.timestep_embedder = TimestepEmbedder()
|
||||
self.extra_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(hidden_dim * 4, hidden_dim),
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.num_layers_down = num_layers_down
|
||||
self.num_layers_up = num_layers_up
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
|
||||
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
|
||||
)
|
||||
|
||||
# Output layers
|
||||
self.final_layer = HunyuanDiTFinalLayer()
|
||||
self.out_channels = out_channels
|
||||
|
||||
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
|
||||
text_emb_mask = text_emb_mask.bool()
|
||||
text_emb_mask_t5 = text_emb_mask_t5.bool()
|
||||
text_emb_t5 = self.t5_embedder(text_emb_t5)
|
||||
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
|
||||
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
|
||||
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
|
||||
return text_emb
|
||||
|
||||
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
|
||||
# Text embedding
|
||||
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
|
||||
|
||||
# Timestep embedding
|
||||
timestep_emb = self.timestep_embedder(timestep)
|
||||
|
||||
# Size embedding
|
||||
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
|
||||
size_emb = size_emb.view(-1, 6 * 256)
|
||||
|
||||
# Style embedding
|
||||
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
|
||||
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
|
||||
|
||||
return condition_emb
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, _, H, W = data.shape
|
||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else h + 1,
|
||||
pad if is_bound[1] else H - h,
|
||||
pad if is_bound[2] else w + 1,
|
||||
pad if is_bound[3] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "H W -> 1 H W")
|
||||
return mask
|
||||
|
||||
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
|
||||
B, C, H, W = hidden_states.shape
|
||||
|
||||
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = H - tile_size, H
|
||||
if w_ > W: w, w_ = W - tile_size, W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
|
||||
if residual is not None:
|
||||
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residual_batch = None
|
||||
|
||||
# Forward
|
||||
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
|
||||
|
||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
||||
weight[:, :, hl:hr, wl:wr] += mask
|
||||
values /= weight
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
to_cache=False,
|
||||
use_gradient_checkpointing=False,
|
||||
):
|
||||
# Embeddings
|
||||
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
|
||||
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
|
||||
|
||||
# Input
|
||||
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
|
||||
hidden_states = self.patch_embedder(hidden_states)
|
||||
|
||||
# Blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
if tiled:
|
||||
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
hidden_states = self.tiled_block_forward(
|
||||
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
|
||||
tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
|
||||
# Output
|
||||
hidden_states = self.final_layer(hidden_states, condition_emb)
|
||||
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
|
||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
name_ = name
|
||||
name_ = name_.replace(".default_modulation.", ".modulation.")
|
||||
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
|
||||
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
|
||||
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
|
||||
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
|
||||
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
|
||||
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
|
||||
name_ = name_.replace(".q_proj.", ".to_q.")
|
||||
name_ = name_.replace(".out_proj.", ".to_out.")
|
||||
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
|
||||
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
|
||||
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
|
||||
name_ = name_.replace("pooler.", "t5_pooler.")
|
||||
name_ = name_.replace("x_embedder.", "patch_embedder.")
|
||||
name_ = name_.replace("t_embedder.", "timestep_embedder.")
|
||||
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
|
||||
name_ = name_.replace("style_embedder.weight", "style_embedder")
|
||||
if ".kv_proj." in name_:
|
||||
param_k = param[:param.shape[0]//2]
|
||||
param_v = param[param.shape[0]//2:]
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
|
||||
elif ".Wqkv." in name_:
|
||||
param_q = param[:param.shape[0]//3]
|
||||
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
|
||||
param_v = param[param.shape[0]//3*2:]
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
|
||||
elif "style_embedder" in name_:
|
||||
state_dict_[name_] = param.squeeze()
|
||||
else:
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,163 +0,0 @@
|
||||
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoder(BertModel):
|
||||
def __init__(self):
|
||||
config = BertConfig(
|
||||
_name_or_path = "",
|
||||
architectures = ["BertModel"],
|
||||
attention_probs_dropout_prob = 0.1,
|
||||
bos_token_id = 0,
|
||||
classifier_dropout = None,
|
||||
directionality = "bidi",
|
||||
eos_token_id = 2,
|
||||
hidden_act = "gelu",
|
||||
hidden_dropout_prob = 0.1,
|
||||
hidden_size = 1024,
|
||||
initializer_range = 0.02,
|
||||
intermediate_size = 4096,
|
||||
layer_norm_eps = 1e-12,
|
||||
max_position_embeddings = 512,
|
||||
model_type = "bert",
|
||||
num_attention_heads = 16,
|
||||
num_hidden_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
pooler_fc_size = 768,
|
||||
pooler_num_attention_heads = 12,
|
||||
pooler_num_fc_layers = 3,
|
||||
pooler_size_per_head = 128,
|
||||
pooler_type = "first_token_transform",
|
||||
position_embedding_type = "absolute",
|
||||
torch_dtype = "float32",
|
||||
transformers_version = "4.37.2",
|
||||
type_vocab_size = 2,
|
||||
use_cache = True,
|
||||
vocab_size = 47020
|
||||
)
|
||||
super().__init__(config, add_pooling_layer=False)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
input_shape = input_ids.size()
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
past_key_values_length=0,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
all_hidden_states = encoder_outputs.hidden_states
|
||||
prompt_emb = all_hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(
|
||||
_name_or_path = "../HunyuanDiT/t2i/mt5",
|
||||
architectures = ["MT5ForConditionalGeneration"],
|
||||
classifier_dropout = 0.0,
|
||||
d_ff = 5120,
|
||||
d_kv = 64,
|
||||
d_model = 2048,
|
||||
decoder_start_token_id = 0,
|
||||
dense_act_fn = "gelu_new",
|
||||
dropout_rate = 0.1,
|
||||
eos_token_id = 1,
|
||||
feed_forward_proj = "gated-gelu",
|
||||
initializer_factor = 1.0,
|
||||
is_encoder_decoder = True,
|
||||
is_gated_act = True,
|
||||
layer_norm_epsilon = 1e-06,
|
||||
model_type = "t5",
|
||||
num_decoder_layers = 24,
|
||||
num_heads = 32,
|
||||
num_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
relative_attention_max_distance = 128,
|
||||
relative_attention_num_buckets = 32,
|
||||
tie_word_embeddings = False,
|
||||
tokenizer_class = "T5Tokenizer",
|
||||
transformers_version = "4.37.2",
|
||||
use_cache = True,
|
||||
vocab_size = 250112
|
||||
)
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_emb = outputs.hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
|
||||
state_dict_["shared.weight"] = state_dict["shared.weight"]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,920 +0,0 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
"""
|
||||
Get n-D meshgrid with start, stop and num.
|
||||
|
||||
Args:
|
||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||
n-tuples.
|
||||
*args: See above.
|
||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
grid (np.ndarray): [dim, ...]
|
||||
"""
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[torch.FloatTensor, int],
|
||||
theta: float = 10000.0,
|
||||
use_real: bool = False,
|
||||
theta_rescale_factor: float = 1.0,
|
||||
interpolation_factor: float = 1.0,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||
|
||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||
) # [D/2]
|
||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(
|
||||
torch.ones_like(freqs), freqs
|
||||
) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
|
||||
Args:
|
||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
*args: See above.
|
||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
||||
part and an imaginary part separately.
|
||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
assert len(theta_rescale_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
assert len(interpolation_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
# use 1/ndim of dimensions to encode grid_axis
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||
return emb
|
||||
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
[16, 56, 56],
|
||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||
theta=256,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1,
|
||||
)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, num_heads=24):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * 4),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size * 4, hidden_size)
|
||||
)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn = rearrange(attn, "B H L D -> B L (H D)")
|
||||
|
||||
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||
super().__init__()
|
||||
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.c_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_channels, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
|
||||
|
||||
def forward(self, x, t, mask=None):
|
||||
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
|
||||
|
||||
mask_float = mask.float().unsqueeze(-1)
|
||||
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
|
||||
x = self.input_embedder(x)
|
||||
|
||||
mask = mask.to(device=x.device, dtype=torch.bool)
|
||||
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
|
||||
mask = mask & mask.transpose(2, 3)
|
||||
mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ModulateDiT(torch.nn.Module):
|
||||
def __init__(self, hidden_size, factor=6):
|
||||
super().__init__()
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None, tr_shift=None, tr_scale=None, tr_token=None):
|
||||
if tr_shift is not None:
|
||||
x_zero = x[:, :tr_token] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
x = torch.concat((x_zero, x_orig), dim=1)
|
||||
return x
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
elif scale is None:
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis,
|
||||
x: torch.Tensor,
|
||||
head_first=False,
|
||||
):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
if isinstance(freqs_cis, tuple):
|
||||
# freqs_cis: (cos, sin) in real space
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
# freqs_cis: values in complex space
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = (
|
||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis,
|
||||
head_first: bool = False,
|
||||
):
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
# real * cos - imag * sin
|
||||
# imag * cos + real * sin
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
else:
|
||||
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
||||
xq_ = torch.view_as_complex(
|
||||
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
||||
xq.device
|
||||
) # [S, D//2] --> [1, S, 1, D//2]
|
||||
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
||||
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
xk_ = torch.view_as_complex(
|
||||
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = x.transpose(1, 2).flatten(2, 3)
|
||||
return x
|
||||
|
||||
|
||||
def apply_gate(x, gate, tr_gate=None, tr_token=None):
|
||||
if tr_gate is not None:
|
||||
x_zero = x[:, :tr_token] * tr_gate.unsqueeze(1)
|
||||
x_orig = x[:, tr_token:] * gate.unsqueeze(1)
|
||||
return torch.concat((x_zero, x_orig), dim=1)
|
||||
else:
|
||||
return x * gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size)
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, token_replace_vec=None, tr_token=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = self.mod(token_replace_vec).chunk(6, dim=-1)
|
||||
else:
|
||||
tr_mod1_shift, tr_mod1_scale, tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None, None, None
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale,
|
||||
tr_shift=tr_mod1_shift, tr_scale=tr_mod1_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate), (tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod, mod_tr=None, tr_token=None):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
if mod_tr is not None:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = mod_tr
|
||||
else:
|
||||
tr_mod1_gate, tr_mod2_shift, tr_mod2_scale, tr_mod2_gate = None, None, None, None
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod1_gate, tr_mod1_gate, tr_token)
|
||||
x = self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale, tr_shift=tr_mod2_shift, tr_scale=tr_mod2_scale, tr_token=tr_token))
|
||||
hidden_states = hidden_states + apply_gate(x, mod2_gate, tr_mod2_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
(q_a, k_a, v_a), mod_a, mod_tr = self.component_a(hidden_states_a, conditioning, freqs_cis, token_replace_vec, tr_token)
|
||||
(q_b, k_b, v_b), mod_b, _ = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :split_token]], dim=1), v_b[:, split_token:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-split_token].contiguous(), torch.concat([attn_output_a[:, -split_token:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a, mod_tr, tr_token)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
|
||||
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = torch.nn.GELU(approximate="tanh")
|
||||
self.modulation = ModulateDiT(hidden_size, factor=3)
|
||||
|
||||
def forward(self, x, vec, freqs_cis=None, txt_len=256):
|
||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
q = torch.cat((q_a, q_b), dim=1)
|
||||
k = torch.cat((k_a, k_b), dim=1)
|
||||
|
||||
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
|
||||
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size, factor=3)
|
||||
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256, token_replace_vec=None, tr_token=None, split_token=71):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
if token_replace_vec is not None:
|
||||
assert tr_token is not None
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = self.mod(token_replace_vec).chunk(3, dim=-1)
|
||||
else:
|
||||
tr_mod_shift, tr_mod_scale, tr_mod_gate = None, None, None
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale,
|
||||
tr_shift=tr_mod_shift, tr_scale=tr_mod_scale, tr_token=tr_token)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
v_len = txt_len - split_token
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :split_token]], dim=1), q_b[:, split_token:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :split_token]], dim=1), k_b[:, split_token:].contiguous()
|
||||
v_a, v_b = v[:, :-v_len].contiguous(), v[:, -v_len:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + apply_gate(self.to_out(attn_output), mod_gate, tr_mod_gate, tr_token)
|
||||
hidden_states = hidden_states + apply_gate(self.ff(norm_hidden_states), mod_gate, tr_mod_gate, tr_token)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FinalLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
|
||||
super().__init__()
|
||||
|
||||
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
|
||||
|
||||
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.vector_in = torch.nn.Sequential(
|
||||
torch.nn.Linear(768, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
|
||||
# TODO: remove these parameters
|
||||
self.dtype = torch.bfloat16
|
||||
self.patch_size = [1, 2, 2]
|
||||
self.hidden_size = 3072
|
||||
self.heads_num = 24
|
||||
self.rope_dim_list = [16, 56, 56]
|
||||
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
self.to(self.cold_device)
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def prepare_freqs(self, latents):
|
||||
return HunyuanVideoRope(latents)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
|
||||
if self.guidance_in is not None:
|
||||
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img = x[:, :-256]
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def block_forward_(self, x, i, j, dtype, device):
|
||||
weight_ = cast_to(
|
||||
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
if self.bias is None or i > 0:
|
||||
bias_ = None
|
||||
else:
|
||||
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
|
||||
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
|
||||
y_ = torch.nn.functional.linear(x_, weight_, bias_)
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
for i in range((self.in_features + self.block_size - 1) // self.block_size):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.module.weight is not None:
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Linear(
|
||||
module.in_features, module.out_features, bias=module.bias is not None,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.Conv3d):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Conv3d(
|
||||
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
new_layer = quantized_layer.RMSNorm(
|
||||
module,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.LayerNorm):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.LayerNorm(
|
||||
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module, dtype=dtype, device=device)
|
||||
|
||||
replace_layer(self, dtype=dtype, device=device)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
"img_in.proj": "img_in.proj",
|
||||
"time_in.mlp.0": "time_in.timestep_embedder.0",
|
||||
"time_in.mlp.2": "time_in.timestep_embedder.2",
|
||||
"vector_in.in_layer": "vector_in.0",
|
||||
"vector_in.out_layer": "vector_in.2",
|
||||
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
|
||||
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
|
||||
"txt_in.input_embedder": "txt_in.input_embedder",
|
||||
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
|
||||
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
|
||||
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
|
||||
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
|
||||
"final_layer.linear": "final_layer.linear",
|
||||
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
|
||||
}
|
||||
txt_suffix_dict = {
|
||||
"norm1": "norm1",
|
||||
"self_attn_qkv": "self_attn_qkv",
|
||||
"self_attn_proj": "self_attn_proj",
|
||||
"norm2": "norm2",
|
||||
"mlp.fc1": "mlp.0",
|
||||
"mlp.fc2": "mlp.2",
|
||||
"adaLN_modulation.1": "adaLN_modulation.1",
|
||||
}
|
||||
double_suffix_dict = {
|
||||
"img_mod.linear": "component_a.mod.linear",
|
||||
"img_attn_qkv": "component_a.to_qkv",
|
||||
"img_attn_q_norm": "component_a.norm_q",
|
||||
"img_attn_k_norm": "component_a.norm_k",
|
||||
"img_attn_proj": "component_a.to_out",
|
||||
"img_mlp.fc1": "component_a.ff.0",
|
||||
"img_mlp.fc2": "component_a.ff.2",
|
||||
"txt_mod.linear": "component_b.mod.linear",
|
||||
"txt_attn_qkv": "component_b.to_qkv",
|
||||
"txt_attn_q_norm": "component_b.norm_q",
|
||||
"txt_attn_k_norm": "component_b.norm_k",
|
||||
"txt_attn_proj": "component_b.to_out",
|
||||
"txt_mlp.fc1": "component_b.ff.0",
|
||||
"txt_mlp.fc2": "component_b.ff.2",
|
||||
}
|
||||
single_suffix_dict = {
|
||||
"linear1": ["to_qkv", "ff.0"],
|
||||
"linear2": ["to_out", "ff.2"],
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"modulation.linear": "mod.linear",
|
||||
}
|
||||
# single_suffix_dict = {
|
||||
# "linear1": "linear1",
|
||||
# "linear2": "linear2",
|
||||
# "q_norm": "q_norm",
|
||||
# "k_norm": "k_norm",
|
||||
# "modulation.linear": "modulation.linear",
|
||||
# }
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
direct_name = ".".join(names[:-1])
|
||||
if direct_name in direct_dict:
|
||||
name_ = direct_dict[direct_name] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "double_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "single_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
if isinstance(single_suffix_dict[suffix], list):
|
||||
if suffix == "linear1":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
elif suffix == "linear2":
|
||||
if names[-1] == "weight":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
else:
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "txt_in":
|
||||
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
|
||||
suffix = ".".join(names[4:-1])
|
||||
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
|
||||
return state_dict_
|
||||
@@ -1,68 +0,0 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
|
||||
position_embeddings = rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
if self.auto_offload:
|
||||
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
|
||||
break
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
# TODO: implement the low VRAM inference for MLLM.
|
||||
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
|
||||
outputs = super().forward(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
pixel_values=pixel_values)
|
||||
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
return hidden_state
|
||||
@@ -1,507 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
|
||||
) # W, H, T
|
||||
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class UpsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.upsample_factor = upsample_factor
|
||||
self.conv = None
|
||||
if use_conv:
|
||||
kernel_size = 3 if kernel_size is None else kernel_size
|
||||
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# interpolate
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
||||
if T > 1:
|
||||
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
||||
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
|
||||
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
if self.conv:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResnetBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
|
||||
super().__init__()
|
||||
self.pre_norm = True
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
# conv1
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
# conv2
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
# shortcut
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = (self.conv_shortcut(input_tensor))
|
||||
# shortcut and scale
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
|
||||
seq_len = n_frame * n_hw
|
||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||
for i in range(seq_len):
|
||||
i_frame = i // n_hw
|
||||
mask[i, :(i_frame + 1) * n_hw] = 0
|
||||
if batch_size is not None:
|
||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_groups=32,
|
||||
dropout=0.0,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
residual_connection=True):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.residual_connection = residual_connection
|
||||
dim_inner = head_dim * num_heads
|
||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, input_tensor, attn_mask=None):
|
||||
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(hidden_states)
|
||||
v = self.to_v(hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
|
||||
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
if self.residual_connection:
|
||||
output_tensor = input_tensor + hidden_states
|
||||
return output_tensor
|
||||
|
||||
|
||||
class UNetMidBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
|
||||
super().__init__()
|
||||
resnets = [
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
attention_head_dim = attention_head_dim or in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Attention(
|
||||
in_channels,
|
||||
num_heads=in_channels // attention_head_dim,
|
||||
head_dim=attention_head_dim,
|
||||
num_groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
residual_connection=True,
|
||||
))
|
||||
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
))
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
||||
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
|
||||
hidden_states = attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpDecoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_upsample=True,
|
||||
upsample_scale_factor=(2, 2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.upsamplers = None
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
UpsampleCausal3D(
|
||||
out_channels,
|
||||
use_conv=True,
|
||||
out_channels=out_channels,
|
||||
upsample_factor=upsample_scale_factor,
|
||||
)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DecoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
||||
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
|
||||
|
||||
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
||||
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
||||
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
||||
|
||||
up_block = UpDecoderBlockCausal3D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
||||
upsample_scale_factor=upsample_scale_factor,
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = DecoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, latents):
|
||||
latents = latents / self.scaling_factor
|
||||
latents = self.post_quant_conv(latents)
|
||||
dec = self.decoder(latents)
|
||||
return dec
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.post_quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.post_quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t * 4 + 1
|
||||
target_h = h * 8
|
||||
target_w = w * 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
||||
latents = latents.to(self.post_quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -1,307 +0,0 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
||||
|
||||
|
||||
class DownsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
|
||||
super().__init__()
|
||||
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_downsample=True,
|
||||
downsample_stride=2,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
|
||||
out_channels,
|
||||
out_channels,
|
||||
stride=downsample_stride,
|
||||
)])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
||||
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
||||
|
||||
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
||||
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
||||
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
||||
down_block = DownEncoderBlockCausal3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
||||
downsample_stride=downsample_stride,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = EncoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, images):
|
||||
latents = self.encoder(images)
|
||||
latents = self.quant_conv(latents)
|
||||
latents = latents[:, :16]
|
||||
latents = latents * self.scaling_factor
|
||||
return latents
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 16, (T - 1) // 4 + 1, H // 8, W // 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) // 4, (size_h - stride_h) // 8, (size_w - stride_w) // 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t // 4 + 1
|
||||
target_h = h // 8
|
||||
target_w = w // 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def encode_video(self, latents, tile_size=(65, 256, 256), tile_stride=(48, 192, 192)):
|
||||
latents = latents.to(self.quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('encoder.') or name.startswith('quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
File diff suppressed because one or more lines are too long
@@ -9,7 +9,7 @@ import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..vram_management import gradient_checkpoint_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
class RMSNorm_FP32(torch.nn.Module):
|
||||
|
||||
@@ -1,402 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNet
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .cog_dit import CogDiT
|
||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||
from .wan_video_dit import WanModel
|
||||
|
||||
|
||||
|
||||
class LoRAFromCivitai:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = []
|
||||
self.lora_prefix = []
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
for key in state_dict:
|
||||
if ".lora_up" in key:
|
||||
return self.convert_state_dict_up_down(state_dict, lora_prefix, alpha)
|
||||
return self.convert_state_dict_AB(state_dict, lora_prefix, alpha)
|
||||
|
||||
|
||||
def convert_state_dict_up_down(self, state_dict, lora_prefix="lora_unet_", alpha=1.0):
|
||||
renamed_lora_prefix = self.renamed_lora_prefix.get(lora_prefix, "")
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace(lora_prefix, renamed_lora_prefix).replace("_", ".") + ".weight"
|
||||
for special_key in self.special_keys:
|
||||
target_name = target_name.replace(special_key, self.special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def convert_state_dict_AB(self, state_dict, lora_prefix="", alpha=1.0, device="cuda", torch_dtype=torch.float16):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
keys = key.split(".")
|
||||
keys.pop(keys.index("lora_B"))
|
||||
target_name = ".".join(keys)
|
||||
target_name = target_name[len(lora_prefix):]
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix, alpha=1.0, model_resource=None):
|
||||
state_dict_model = model.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=alpha)
|
||||
if model_resource == "diffusers":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_diffusers(state_dict_lora)
|
||||
elif model_resource == "civitai":
|
||||
state_dict_lora = model.__class__.state_dict_converter().from_civitai(state_dict_lora)
|
||||
if isinstance(state_dict_lora, tuple):
|
||||
state_dict_lora = state_dict_lora[0]
|
||||
if len(state_dict_lora) > 0:
|
||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
||||
for name in state_dict_lora:
|
||||
fp8=False
|
||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
||||
state_dict_model[name]= state_dict_model[name].to(state_dict_lora[name].dtype)
|
||||
fp8=True
|
||||
state_dict_model[name] += state_dict_lora[name].to(
|
||||
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device)
|
||||
if fp8:
|
||||
state_dict_model[name] = state_dict_model[name].to(torch.float8_e4m3fn)
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
def match(self, model, state_dict_lora):
|
||||
for lora_prefix, model_class in zip(self.lora_prefix, self.supported_model_classes):
|
||||
if not isinstance(model, model_class):
|
||||
continue
|
||||
state_dict_model = model.state_dict()
|
||||
for model_resource in ["diffusers", "civitai"]:
|
||||
try:
|
||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, lora_prefix=lora_prefix, alpha=1.0)
|
||||
converter_fn = model.__class__.state_dict_converter().from_diffusers if model_resource == "diffusers" \
|
||||
else model.__class__.state_dict_converter().from_civitai
|
||||
state_dict_lora_ = converter_fn(state_dict_lora_)
|
||||
if isinstance(state_dict_lora_, tuple):
|
||||
state_dict_lora_ = state_dict_lora_[0]
|
||||
if len(state_dict_lora_) == 0:
|
||||
continue
|
||||
for name in state_dict_lora_:
|
||||
if name not in state_dict_model:
|
||||
break
|
||||
else:
|
||||
return lora_prefix, model_resource
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
|
||||
class SDLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [SDUNet, SDTextEncoder]
|
||||
self.lora_prefix = ["lora_unet_", "lora_te_"]
|
||||
self.special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
"text.model": "text_model",
|
||||
"self.attn.q.proj": "self_attn.q_proj",
|
||||
"self.attn.k.proj": "self_attn.k_proj",
|
||||
"self.attn.v.proj": "self_attn.v_proj",
|
||||
"self.attn.out.proj": "self_attn.out_proj",
|
||||
"input.blocks": "model.diffusion_model.input_blocks",
|
||||
"middle.block": "model.diffusion_model.middle_block",
|
||||
"output.blocks": "model.diffusion_model.output_blocks",
|
||||
}
|
||||
|
||||
|
||||
class SDXLLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [SDXLUNet, SDXLTextEncoder, SDXLTextEncoder2]
|
||||
self.lora_prefix = ["lora_unet_", "lora_te1_", "lora_te2_"]
|
||||
self.renamed_lora_prefix = {"lora_te2_": "2"}
|
||||
self.special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
"text.model": "conditioner.embedders.0.transformer.text_model",
|
||||
"self.attn.q.proj": "self_attn.q_proj",
|
||||
"self.attn.k.proj": "self_attn.k_proj",
|
||||
"self.attn.v.proj": "self_attn.v_proj",
|
||||
"self.attn.out.proj": "self_attn.out_proj",
|
||||
"input.blocks": "model.diffusion_model.input_blocks",
|
||||
"middle.block": "model.diffusion_model.middle_block",
|
||||
"output.blocks": "model.diffusion_model.output_blocks",
|
||||
"2conditioner.embedders.0.transformer.text_model.encoder.layers": "text_model.encoder.layers"
|
||||
}
|
||||
|
||||
|
||||
class FluxLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [FluxDiT, FluxDiT]
|
||||
self.lora_prefix = ["lora_unet_", "transformer."]
|
||||
self.renamed_lora_prefix = {}
|
||||
self.special_keys = {
|
||||
"single.blocks": "single_blocks",
|
||||
"double.blocks": "double_blocks",
|
||||
"img.attn": "img_attn",
|
||||
"img.mlp": "img_mlp",
|
||||
"img.mod": "img_mod",
|
||||
"txt.attn": "txt_attn",
|
||||
"txt.mlp": "txt_mlp",
|
||||
"txt.mod": "txt_mod",
|
||||
}
|
||||
|
||||
|
||||
|
||||
class GeneralLoRAFromPeft:
|
||||
def __init__(self):
|
||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||
|
||||
|
||||
def get_name_dict(self, lora_state_dict):
|
||||
lora_name_dict = {}
|
||||
for key in lora_state_dict:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
keys = key.split(".")
|
||||
if len(keys) > keys.index("lora_B") + 2:
|
||||
keys.pop(keys.index("lora_B") + 1)
|
||||
keys.pop(keys.index("lora_B"))
|
||||
if keys[0] == "diffusion_model":
|
||||
keys.pop(0)
|
||||
target_name = ".".join(keys)
|
||||
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||
return lora_name_dict
|
||||
|
||||
|
||||
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||
if matched_num == len(lora_name_dict):
|
||||
return "", ""
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def fetch_device_and_dtype(self, state_dict):
|
||||
device, dtype = None, None
|
||||
for name, param in state_dict.items():
|
||||
device, dtype = param.device, param.dtype
|
||||
break
|
||||
computation_device = device
|
||||
computation_dtype = dtype
|
||||
if computation_device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
computation_device = torch.device("cuda")
|
||||
if computation_dtype == torch.float8_e4m3fn:
|
||||
computation_dtype = torch.float32
|
||||
return device, dtype, computation_device, computation_dtype
|
||||
|
||||
|
||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||
state_dict_model = model.state_dict()
|
||||
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||
for name in lora_name_dict:
|
||||
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||
weight_patched = weight_model + weight_lora
|
||||
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||
model.load_state_dict(state_dict_model)
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [HunyuanVideoDiT, HunyuanVideoDiT]
|
||||
self.lora_prefix = ["diffusion_model.", "transformer."]
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
class FluxLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, alpha=None):
|
||||
prefix_rename_dict = {
|
||||
"single_blocks": "lora_unet_single_blocks",
|
||||
"blocks": "lora_unet_double_blocks",
|
||||
}
|
||||
middle_rename_dict = {
|
||||
"norm.linear": "modulation_lin",
|
||||
"to_qkv_mlp": "linear1",
|
||||
"proj_out": "linear2",
|
||||
|
||||
"norm1_a.linear": "img_mod_lin",
|
||||
"norm1_b.linear": "txt_mod_lin",
|
||||
"attn.a_to_qkv": "img_attn_qkv",
|
||||
"attn.b_to_qkv": "txt_attn_qkv",
|
||||
"attn.a_to_out": "img_attn_proj",
|
||||
"attn.b_to_out": "txt_attn_proj",
|
||||
"ff_a.0": "img_mlp_0",
|
||||
"ff_a.2": "img_mlp_2",
|
||||
"ff_b.0": "txt_mlp_0",
|
||||
"ff_b.2": "txt_mlp_2",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"lora_B.weight": "lora_up.weight",
|
||||
"lora_A.weight": "lora_down.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
if names[-2] != "lora_A" and names[-2] != "lora_B":
|
||||
names.pop(-2)
|
||||
prefix = names[0]
|
||||
middle = ".".join(names[2:-2])
|
||||
suffix = ".".join(names[-2:])
|
||||
block_id = names[1]
|
||||
if middle not in middle_rename_dict:
|
||||
continue
|
||||
rename = prefix_rename_dict[prefix] + "_" + block_id + "_" + middle_rename_dict[middle] + "." + suffix_rename_dict[suffix]
|
||||
state_dict_[rename] = param
|
||||
if rename.endswith("lora_up.weight"):
|
||||
lora_alpha = alpha if alpha is not None else param.shape[-1]
|
||||
state_dict_[rename.replace("lora_up.weight", "alpha")] = torch.tensor((lora_alpha,))[0]
|
||||
return state_dict_
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict):
|
||||
rename_dict = {
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_down.weight": "blocks.blockid.norm1_a.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mod_lin.lora_up.weight": "blocks.blockid.norm1_a.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_down.weight": "blocks.blockid.norm1_b.linear.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mod_lin.lora_up.weight": "blocks.blockid.norm1_b.linear.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_down.weight": "blocks.blockid.attn.a_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_qkv.lora_up.weight": "blocks.blockid.attn.a_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_down.weight": "blocks.blockid.attn.b_to_qkv.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_qkv.lora_up.weight": "blocks.blockid.attn.b_to_qkv.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_down.weight": "blocks.blockid.attn.a_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_attn_proj.lora_up.weight": "blocks.blockid.attn.a_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_down.weight": "blocks.blockid.attn.b_to_out.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_attn_proj.lora_up.weight": "blocks.blockid.attn.b_to_out.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_down.weight": "blocks.blockid.ff_a.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_0.lora_up.weight": "blocks.blockid.ff_a.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_down.weight": "blocks.blockid.ff_a.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_img_mlp_2.lora_up.weight": "blocks.blockid.ff_a.2.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_down.weight": "blocks.blockid.ff_b.0.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_0.lora_up.weight": "blocks.blockid.ff_b.0.lora_B.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_down.weight": "blocks.blockid.ff_b.2.lora_A.default.weight",
|
||||
"lora_unet_double_blocks_blockid_txt_mlp_2.lora_up.weight": "blocks.blockid.ff_b.2.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_down.weight": "single_blocks.blockid.norm.linear.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_modulation_lin.lora_up.weight": "single_blocks.blockid.norm.linear.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_down.weight": "single_blocks.blockid.to_qkv_mlp.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear1.lora_up.weight": "single_blocks.blockid.to_qkv_mlp.lora_B.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_down.weight": "single_blocks.blockid.proj_out.lora_A.default.weight",
|
||||
"lora_unet_single_blocks_blockid_linear2.lora_up.weight": "single_blocks.blockid.proj_out.lora_B.default.weight",
|
||||
}
|
||||
def guess_block_id(name):
|
||||
names = name.split("_")
|
||||
for i in names:
|
||||
if i.isdigit():
|
||||
return i, name.replace(f"_{i}_", "_blockid_")
|
||||
return None, None
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
block_id, source_name = guess_block_id(name)
|
||||
if source_name in rename_dict:
|
||||
target_name = rename_dict[source_name]
|
||||
target_name = target_name.replace(".blockid.", f".{block_id}.")
|
||||
state_dict_[target_name] = param
|
||||
else:
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
class WanLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {"diffusion_model." + name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace("diffusion_model.", "").replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
class QwenImageLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def align_to_opensource_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
@staticmethod
|
||||
def align_to_diffsynth_format(state_dict, **kwargs):
|
||||
state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
111
diffsynth/models/model_loader.py
Normal file
111
diffsynth/models/model_loader.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from ..core.loader import load_model, hash_model_file
|
||||
from ..core.vram import AutoWrappedModule
|
||||
from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS
|
||||
import importlib, json, torch
|
||||
|
||||
|
||||
class ModelPool:
|
||||
def __init__(self):
|
||||
self.model = []
|
||||
self.model_name = []
|
||||
self.model_path = []
|
||||
|
||||
def import_model_class(self, model_class):
|
||||
split = model_class.rfind(".")
|
||||
model_resource, model_class = model_class[:split], model_class[split+1:]
|
||||
model_class = importlib.import_module(model_resource).__getattribute__(model_class)
|
||||
return model_class
|
||||
|
||||
def need_to_enable_vram_management(self, vram_config):
|
||||
return vram_config["offload_dtype"] is not None and vram_config["offload_device"] is not None
|
||||
|
||||
def fetch_module_map(self, model_class, vram_config):
|
||||
if self.need_to_enable_vram_management(vram_config):
|
||||
if model_class in VRAM_MANAGEMENT_MODULE_MAPS:
|
||||
module_map = {self.import_model_class(source): self.import_model_class(target) for source, target in VRAM_MANAGEMENT_MODULE_MAPS[model_class].items()}
|
||||
else:
|
||||
module_map = {self.import_model_class(model_class): AutoWrappedModule}
|
||||
else:
|
||||
module_map = None
|
||||
return module_map
|
||||
|
||||
def load_model_file(self, config, path, vram_config, vram_limit=None):
|
||||
model_class = self.import_model_class(config["model_class"])
|
||||
model_config = config.get("extra_kwargs", {})
|
||||
if "state_dict_converter" in config:
|
||||
state_dict_converter = self.import_model_class(config["state_dict_converter"])
|
||||
else:
|
||||
state_dict_converter = None
|
||||
module_map = self.fetch_module_map(config["model_class"], vram_config)
|
||||
model = load_model(
|
||||
model_class, path, model_config,
|
||||
vram_config["computation_dtype"], vram_config["computation_device"],
|
||||
state_dict_converter,
|
||||
use_disk_map=True,
|
||||
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||
)
|
||||
return model
|
||||
|
||||
def default_vram_config(self):
|
||||
vram_config = {
|
||||
"offload_dtype": None,
|
||||
"offload_device": None,
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cpu",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cpu",
|
||||
}
|
||||
return vram_config
|
||||
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
|
||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||
if vram_config is None:
|
||||
vram_config = self.default_vram_config()
|
||||
model_hash = hash_model_file(path)
|
||||
loaded = False
|
||||
for config in MODEL_CONFIGS:
|
||||
if config["model_hash"] == model_hash:
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
|
||||
if clear_parameters: self.clear_parameters(model)
|
||||
self.model.append(model)
|
||||
model_name = config["model_name"]
|
||||
self.model_name.append(model_name)
|
||||
self.model_path.append(path)
|
||||
model_info = {"model_name": model_name, "model_class": config["model_class"], "extra_kwargs": config.get("extra_kwargs")}
|
||||
print(f"Loaded model: {json.dumps(model_info, indent=4)}")
|
||||
loaded = True
|
||||
if not loaded:
|
||||
raise ValueError(f"Cannot detect the model type. File: {path}. Model hash: {model_hash}")
|
||||
|
||||
def fetch_model(self, model_name, index=None):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
if model_name == model_name_:
|
||||
fetched_models.append(model)
|
||||
fetched_model_paths.append(model_path)
|
||||
if len(fetched_models) == 0:
|
||||
print(f"No {model_name} models available. This is not an error.")
|
||||
model = None
|
||||
elif len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
|
||||
model = fetched_models[0]
|
||||
else:
|
||||
if index is None:
|
||||
model = fetched_models[0]
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[0], indent=4)}.")
|
||||
elif isinstance(index, int):
|
||||
model = fetched_models[:index]
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths[:index], indent=4)}.")
|
||||
else:
|
||||
model = fetched_models
|
||||
print(f"More than one {model_name} models are loaded: {fetched_model_paths}. Using {model_name} from {json.dumps(fetched_model_paths, indent=4)}.")
|
||||
return model
|
||||
|
||||
def clear_parameters(self, model: torch.nn.Module):
|
||||
for name, module in model.named_children():
|
||||
self.clear_parameters(module)
|
||||
for name, param in model.named_parameters(recurse=False):
|
||||
setattr(model, name, None)
|
||||
@@ -1,467 +0,0 @@
|
||||
import os, torch, json, importlib
|
||||
from typing import List
|
||||
|
||||
from .downloader import download_models, download_customized_models, Preset_model_id, Preset_model_website
|
||||
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sd_unet import SDUNet
|
||||
from .sd_vae_encoder import SDVAEEncoder
|
||||
from .sd_vae_decoder import SDVAEDecoder
|
||||
from .lora import get_lora_loaders
|
||||
|
||||
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from .sdxl_vae_encoder import SDXLVAEEncoder
|
||||
|
||||
from .sd3_text_encoder import SD3TextEncoder1, SD3TextEncoder2, SD3TextEncoder3
|
||||
from .sd3_dit import SD3DiT
|
||||
from .sd3_vae_decoder import SD3VAEDecoder
|
||||
from .sd3_vae_encoder import SD3VAEEncoder
|
||||
|
||||
from .sd_controlnet import SDControlNet
|
||||
from .sdxl_controlnet import SDXLControlNetUnion
|
||||
|
||||
from .sd_motion import SDMotionModel
|
||||
from .sdxl_motion import SDXLMotionModel
|
||||
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .svd_unet import SVDUNet
|
||||
from .svd_vae_decoder import SVDVAEDecoder
|
||||
from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sd_ipadapter import SDIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder2
|
||||
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from .flux_ipadapter import FluxIpAdapter
|
||||
|
||||
from .cog_vae import CogVAEEncoder, CogVAEDecoder
|
||||
from .cog_dit import CogDiT
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
|
||||
from .utils import load_state_dict, init_weights_on_device, hash_state_dict_keys, split_state_dict_with_prefix
|
||||
|
||||
|
||||
def load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
print(f" model_name: {model_name} model_class: {model_class.__name__}")
|
||||
state_dict_converter = model_class.state_dict_converter()
|
||||
if model_resource == "civitai":
|
||||
state_dict_results = state_dict_converter.from_civitai(state_dict)
|
||||
elif model_resource == "diffusers":
|
||||
state_dict_results = state_dict_converter.from_diffusers(state_dict)
|
||||
if isinstance(state_dict_results, tuple):
|
||||
model_state_dict, extra_kwargs = state_dict_results
|
||||
print(f" This model is initialized with extra kwargs: {extra_kwargs}")
|
||||
else:
|
||||
model_state_dict, extra_kwargs = state_dict_results, {}
|
||||
torch_dtype = torch.float32 if extra_kwargs.get("upcast_to_float32", False) else torch_dtype
|
||||
with init_weights_on_device():
|
||||
model = model_class(**extra_kwargs)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
model.load_state_dict(model_state_dict, assign=True)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(file_path, model_names, model_classes, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
if torch_dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
model = model_class.from_pretrained(file_path, torch_dtype=torch_dtype).eval()
|
||||
else:
|
||||
model = model_class.from_pretrained(file_path).eval().to(dtype=torch_dtype)
|
||||
if torch_dtype == torch.float16 and hasattr(model, "half"):
|
||||
model = model.half()
|
||||
try:
|
||||
model = model.to(device=device)
|
||||
except:
|
||||
pass
|
||||
loaded_model_names.append(model_name)
|
||||
loaded_models.append(model)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
def load_single_patch_model_from_single_file(state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device):
|
||||
print(f" model_name: {model_name} model_class: {model_class.__name__} extra_kwargs: {extra_kwargs}")
|
||||
base_state_dict = base_model.state_dict()
|
||||
base_model.to("cpu")
|
||||
del base_model
|
||||
model = model_class(**extra_kwargs)
|
||||
model.load_state_dict(base_state_dict, strict=False)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model.to(dtype=torch_dtype, device=device)
|
||||
return model
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device):
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for model_name, model_class in zip(model_names, model_classes):
|
||||
while True:
|
||||
for model_id in range(len(model_manager.model)):
|
||||
base_model_name = model_manager.model_name[model_id]
|
||||
if base_model_name == model_name:
|
||||
base_model_path = model_manager.model_path[model_id]
|
||||
base_model = model_manager.model[model_id]
|
||||
print(f" Adding patch model to {base_model_name} ({base_model_path})")
|
||||
patched_model = load_single_patch_model_from_single_file(
|
||||
state_dict, model_name, model_class, base_model, extra_kwargs, torch_dtype, device)
|
||||
loaded_model_names.append(base_model_name)
|
||||
loaded_models.append(patched_model)
|
||||
model_manager.model.pop(model_id)
|
||||
model_manager.model_path.pop(model_id)
|
||||
model_manager.model_name.pop(model_id)
|
||||
break
|
||||
else:
|
||||
break
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorTemplate:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
return False
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
return [], []
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
self.keys_hash_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash, keys_hash_with_shape, model_names, model_classes, model_resource):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_names, model_classes, model_resource)
|
||||
if keys_hash is not None:
|
||||
self.keys_hash_dict[keys_hash] = (model_names, model_classes, model_resource)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, model_resource = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
# Load models without strict matching
|
||||
# (the shape of parameters may be inconsistent, and the state_dict_converter will modify the model architecture)
|
||||
keys_hash = hash_state_dict_keys(state_dict, with_shape=False)
|
||||
if keys_hash in self.keys_hash_dict:
|
||||
model_names, model_classes, model_resource = self.keys_hash_dict[keys_hash]
|
||||
loaded_model_names, loaded_models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, torch_dtype, device)
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
super().__init__(model_loader_configs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if isinstance(file_path, str) and os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
# Split the state_dict and load from each component
|
||||
splited_state_dict = split_state_dict_with_prefix(state_dict)
|
||||
valid_state_dict = {}
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
valid_state_dict.update(sub_state_dict)
|
||||
if super().match(file_path, valid_state_dict):
|
||||
loaded_model_names, loaded_models = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||
else:
|
||||
loaded_model_names, loaded_models = [], []
|
||||
for sub_state_dict in splited_state_dict:
|
||||
if super().match(file_path, sub_state_dict):
|
||||
loaded_model_names_, loaded_models_ = super().load(file_path, valid_state_dict, device, torch_dtype)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromHuggingfaceFolder:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.architecture_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
||||
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
||||
return False
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" not in file_list:
|
||||
return False
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if "architectures" not in config and "_class_name" not in config:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
loaded_model_names, loaded_models = [], []
|
||||
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
||||
for architecture in architectures:
|
||||
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
||||
if redirected_architecture is not None:
|
||||
architecture = redirected_architecture
|
||||
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
||||
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelDetectorFromPatchedSingleFile:
|
||||
def __init__(self, model_loader_configs=[]):
|
||||
self.keys_hash_with_shape_dict = {}
|
||||
for metadata in model_loader_configs:
|
||||
self.add_model_metadata(*metadata)
|
||||
|
||||
|
||||
def add_model_metadata(self, keys_hash_with_shape, model_name, model_class, extra_kwargs):
|
||||
self.keys_hash_with_shape_dict[keys_hash_with_shape] = (model_name, model_class, extra_kwargs)
|
||||
|
||||
|
||||
def match(self, file_path="", state_dict={}):
|
||||
if not isinstance(file_path, str) or os.path.isdir(file_path):
|
||||
return False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, model_manager=None, **kwargs):
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
|
||||
# Load models with strict matching
|
||||
loaded_model_names, loaded_models = [], []
|
||||
keys_hash_with_shape = hash_state_dict_keys(state_dict, with_shape=True)
|
||||
if keys_hash_with_shape in self.keys_hash_with_shape_dict:
|
||||
model_names, model_classes, extra_kwargs = self.keys_hash_with_shape_dict[keys_hash_with_shape]
|
||||
loaded_model_names_, loaded_models_ = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, model_manager, torch_dtype, device)
|
||||
loaded_model_names += loaded_model_names_
|
||||
loaded_models += loaded_models_
|
||||
return loaded_model_names, loaded_models
|
||||
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(
|
||||
self,
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda",
|
||||
model_id_list: List[Preset_model_id] = [],
|
||||
downloading_priority: List[Preset_model_website] = ["ModelScope", "HuggingFace"],
|
||||
file_path_list: List[str] = [],
|
||||
):
|
||||
self.torch_dtype = torch_dtype
|
||||
self.device = device
|
||||
self.model = []
|
||||
self.model_path = []
|
||||
self.model_name = []
|
||||
downloaded_files = download_models(model_id_list, downloading_priority) if len(model_id_list) > 0 else []
|
||||
self.model_detector = [
|
||||
ModelDetectorFromSingleFile(model_loader_configs),
|
||||
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
||||
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
||||
ModelDetectorFromPatchedSingleFile(patch_model_loader_configs),
|
||||
]
|
||||
self.load_models(downloaded_files + file_path_list)
|
||||
|
||||
|
||||
def load_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], model_resource=None):
|
||||
print(f"Loading models from file: {file_path}")
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
model_names, models = load_model_from_single_file(state_dict, model_names, model_classes, model_resource, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_model_from_huggingface_folder(self, file_path="", model_names=[], model_classes=[]):
|
||||
print(f"Loading models from folder: {file_path}")
|
||||
model_names, models = load_model_from_huggingface_folder(file_path, model_names, model_classes, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_patch_model_from_single_file(self, file_path="", state_dict={}, model_names=[], model_classes=[], extra_kwargs={}):
|
||||
print(f"Loading patch models from file: {file_path}")
|
||||
model_names, models = load_patch_model_from_single_file(
|
||||
state_dict, model_names, model_classes, extra_kwargs, self, self.torch_dtype, self.device)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following patched models are loaded: {model_names}.")
|
||||
|
||||
|
||||
def load_lora(self, file_path="", state_dict={}, lora_alpha=1.0):
|
||||
if isinstance(file_path, list):
|
||||
for file_path_ in file_path:
|
||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||
else:
|
||||
print(f"Loading LoRA models from file: {file_path}")
|
||||
is_loaded = False
|
||||
if len(state_dict) == 0:
|
||||
state_dict = load_state_dict(file_path)
|
||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||
for lora in get_lora_loaders():
|
||||
match_results = lora.match(model, state_dict)
|
||||
if match_results is not None:
|
||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||
lora_prefix, model_resource = match_results
|
||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||
is_loaded = True
|
||||
break
|
||||
if not is_loaded:
|
||||
print(f" Cannot load LoRA: {file_path}")
|
||||
|
||||
|
||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||
print(f"Loading models from: {file_path}")
|
||||
if device is None: device = self.device
|
||||
if torch_dtype is None: torch_dtype = self.torch_dtype
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for path in file_path:
|
||||
state_dict.update(load_state_dict(path))
|
||||
elif os.path.isfile(file_path):
|
||||
state_dict = load_state_dict(file_path)
|
||||
else:
|
||||
state_dict = None
|
||||
for model_detector in self.model_detector:
|
||||
if model_detector.match(file_path, state_dict):
|
||||
model_names, models = model_detector.load(
|
||||
file_path, state_dict,
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
allowed_model_names=model_names, model_manager=self
|
||||
)
|
||||
for model_name, model in zip(model_names, models):
|
||||
self.model.append(model)
|
||||
self.model_path.append(file_path)
|
||||
self.model_name.append(model_name)
|
||||
print(f" The following models are loaded: {model_names}.")
|
||||
break
|
||||
else:
|
||||
print(f" We cannot detect the model type. No models are loaded.")
|
||||
|
||||
|
||||
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
|
||||
for file_path in file_path_list:
|
||||
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
|
||||
|
||||
|
||||
def fetch_model(self, model_name, file_path=None, require_model_path=False, index=None):
|
||||
fetched_models = []
|
||||
fetched_model_paths = []
|
||||
for model, model_path, model_name_ in zip(self.model, self.model_path, self.model_name):
|
||||
if file_path is not None and file_path != model_path:
|
||||
continue
|
||||
if model_name == model_name_:
|
||||
fetched_models.append(model)
|
||||
fetched_model_paths.append(model_path)
|
||||
if len(fetched_models) == 0:
|
||||
print(f"No {model_name} models available.")
|
||||
return None
|
||||
if len(fetched_models) == 1:
|
||||
print(f"Using {model_name} from {fetched_model_paths[0]}.")
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
else:
|
||||
if index is None:
|
||||
model = fetched_models[0]
|
||||
path = fetched_model_paths[0]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[0]}.")
|
||||
elif isinstance(index, int):
|
||||
model = fetched_models[:index]
|
||||
path = fetched_model_paths[:index]
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths[:index]}.")
|
||||
else:
|
||||
model = fetched_models
|
||||
path = fetched_model_paths
|
||||
print(f"More than one {model_name} models are loaded in model manager: {fetched_model_paths}. Using {model_name} from {fetched_model_paths}.")
|
||||
if require_model_path:
|
||||
return model, path
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def to(self, device):
|
||||
for model in self.model:
|
||||
model.to(device)
|
||||
|
||||
@@ -1,803 +0,0 @@
|
||||
# The code is revised from DiT
|
||||
import os
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import math
|
||||
from safetensors.torch import load_file
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch.utils.checkpoint
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers import Phi3Config, Phi3Model
|
||||
from transformers.cache_utils import Cache, DynamicCache
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Phi3Transformer(Phi3Model):
|
||||
"""
|
||||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`]
|
||||
We only modified the attention mask
|
||||
Args:
|
||||
config: Phi3Config
|
||||
"""
|
||||
def prefetch_layer(self, layer_idx: int, device: torch.device):
|
||||
"Starts prefetching the next layer cache"
|
||||
with torch.cuda.stream(self.prefetch_stream):
|
||||
# Prefetch next layer tensors to GPU
|
||||
for name, param in self.layers[layer_idx].named_parameters():
|
||||
param.data = param.data.to(device, non_blocking=True)
|
||||
|
||||
def evict_previous_layer(self, layer_idx: int):
|
||||
"Moves the previous layer cache to the CPU"
|
||||
prev_layer_idx = layer_idx - 1
|
||||
for name, param in self.layers[prev_layer_idx].named_parameters():
|
||||
param.data = param.data.to("cpu", non_blocking=True)
|
||||
|
||||
def get_offlaod_layer(self, layer_idx: int, device: torch.device):
|
||||
# init stream
|
||||
if not hasattr(self, "prefetch_stream"):
|
||||
self.prefetch_stream = torch.cuda.Stream()
|
||||
|
||||
# delete previous layer
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.evict_previous_layer(layer_idx)
|
||||
|
||||
# make sure the current layer is ready
|
||||
torch.cuda.synchronize(self.prefetch_stream)
|
||||
|
||||
# load next layer
|
||||
self.prefetch_layer((layer_idx + 1) % len(self.layers), device)
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
offload_model: Optional[bool] = False,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# kept for BC (non `Cache` `past_key_values` inputs)
|
||||
return_legacy_cache = False
|
||||
if use_cache and not isinstance(past_key_values, Cache):
|
||||
return_legacy_cache = True
|
||||
if past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
else:
|
||||
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||
logger.warning_once(
|
||||
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
|
||||
"will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
|
||||
"(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
|
||||
)
|
||||
|
||||
# if inputs_embeds is None:
|
||||
# inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
# if cache_position is None:
|
||||
# past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
# cache_position = torch.arange(
|
||||
# past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||
# )
|
||||
# if position_ids is None:
|
||||
# position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
if attention_mask is not None and attention_mask.dim() == 3:
|
||||
dtype = inputs_embeds.dtype
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
attention_mask = (1 - attention_mask) * min_dtype
|
||||
attention_mask = attention_mask.unsqueeze(1).to(inputs_embeds.dtype)
|
||||
else:
|
||||
raise Exception("attention_mask parameter was unavailable or invalid")
|
||||
# causal_mask = self._update_causal_mask(
|
||||
# attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
# )
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = None
|
||||
|
||||
layer_idx = -1
|
||||
for decoder_layer in self.layers:
|
||||
layer_idx += 1
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
)
|
||||
else:
|
||||
if offload_model and not self.training:
|
||||
self.get_offlaod_layer(layer_idx, device=inputs_embeds.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
print('************')
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if return_legacy_cache:
|
||||
next_cache = next_cache.to_legacy_cache()
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param t: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an (N, D) Tensor of positional embeddings.
|
||||
"""
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t, dtype=torch.float32):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of DiT.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(hidden_size, 2 * hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=1):
|
||||
"""
|
||||
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, int):
|
||||
grid_size = (grid_size, grid_size)
|
||||
|
||||
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
|
||||
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.
|
||||
omega = 1. / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
class PatchEmbedMR(nn.Module):
|
||||
""" 2D Image to Patch Embedding
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_chans: int = 4,
|
||||
embed_dim: int = 768,
|
||||
bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
|
||||
return x
|
||||
|
||||
|
||||
class OmniGenOriginalModel(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
transformer_config: Phi3Config,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
pe_interpolation: float = 1.0,
|
||||
pos_embed_max_size: int = 192,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
|
||||
hidden_size = transformer_config.hidden_size
|
||||
|
||||
self.x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||
self.input_x_embedder = PatchEmbedMR(patch_size, in_channels, hidden_size, bias=True)
|
||||
|
||||
self.time_token = TimestepEmbedder(hidden_size)
|
||||
self.t_embedder = TimestepEmbedder(hidden_size)
|
||||
|
||||
self.pe_interpolation = pe_interpolation
|
||||
pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, interpolation_scale=self.pe_interpolation, base_size=64)
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=True)
|
||||
|
||||
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
self.llm = Phi3Transformer(config=transformer_config)
|
||||
self.llm.config.use_cache = False
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_name):
|
||||
if not os.path.exists(model_name):
|
||||
cache_folder = os.getenv('HF_HUB_CACHE')
|
||||
model_name = snapshot_download(repo_id=model_name,
|
||||
cache_dir=cache_folder,
|
||||
ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
|
||||
config = Phi3Config.from_pretrained(model_name)
|
||||
model = cls(config)
|
||||
if os.path.exists(os.path.join(model_name, 'model.safetensors')):
|
||||
print("Loading safetensors")
|
||||
ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
|
||||
else:
|
||||
ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
|
||||
model.load_state_dict(ckpt)
|
||||
return model
|
||||
|
||||
def initialize_weights(self):
|
||||
assert not hasattr(self, "llama")
|
||||
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
self.apply(_basic_init)
|
||||
|
||||
# Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
|
||||
w = self.x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
w = self.input_x_embedder.proj.weight.data
|
||||
nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
||||
nn.init.constant_(self.x_embedder.proj.bias, 0)
|
||||
|
||||
|
||||
# Initialize timestep embedding MLP:
|
||||
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
|
||||
nn.init.normal_(self.time_token.mlp[0].weight, std=0.02)
|
||||
nn.init.normal_(self.time_token.mlp[2].weight, std=0.02)
|
||||
|
||||
# Zero-out output layers:
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
|
||||
nn.init.constant_(self.final_layer.linear.weight, 0)
|
||||
nn.init.constant_(self.final_layer.linear.bias, 0)
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h//self.patch_size, w//self.patch_size, self.patch_size, self.patch_size, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h, w))
|
||||
return imgs
|
||||
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
"""Crops positional embeddings for SD3 compatibility."""
|
||||
if self.pos_embed_max_size is None:
|
||||
raise ValueError("`pos_embed_max_size` must be set for cropping.")
|
||||
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
if height > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
if width > self.pos_embed_max_size:
|
||||
raise ValueError(
|
||||
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
|
||||
)
|
||||
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
|
||||
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
|
||||
# print(top, top + height, left, left + width, spatial_pos_embed.size())
|
||||
spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
|
||||
return spatial_pos_embed
|
||||
|
||||
|
||||
def patch_multiple_resolutions(self, latents, padding_latent=None, is_input_images:bool=False):
|
||||
if isinstance(latents, list):
|
||||
return_list = False
|
||||
if padding_latent is None:
|
||||
padding_latent = [None] * len(latents)
|
||||
return_list = True
|
||||
patched_latents, num_tokens, shapes = [], [], []
|
||||
for latent, padding in zip(latents, padding_latent):
|
||||
height, width = latent.shape[-2:]
|
||||
if is_input_images:
|
||||
latent = self.input_x_embedder(latent)
|
||||
else:
|
||||
latent = self.x_embedder(latent)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
latent = latent + pos_embed
|
||||
if padding is not None:
|
||||
latent = torch.cat([latent, padding], dim=-2)
|
||||
patched_latents.append(latent)
|
||||
|
||||
num_tokens.append(pos_embed.size(1))
|
||||
shapes.append([height, width])
|
||||
if not return_list:
|
||||
latents = torch.cat(patched_latents, dim=0)
|
||||
else:
|
||||
latents = patched_latents
|
||||
else:
|
||||
height, width = latents.shape[-2:]
|
||||
if is_input_images:
|
||||
latents = self.input_x_embedder(latents)
|
||||
else:
|
||||
latents = self.x_embedder(latents)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
latents = latents + pos_embed
|
||||
num_tokens = latents.size(1)
|
||||
shapes = [height, width]
|
||||
return latents, num_tokens, shapes
|
||||
|
||||
|
||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||
"""
|
||||
|
||||
"""
|
||||
input_is_list = isinstance(x, list)
|
||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||
|
||||
if input_img_latents is not None:
|
||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||
if input_ids is not None:
|
||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||
input_img_inx = 0
|
||||
for b_inx in input_image_sizes.keys():
|
||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||
input_img_inx += 1
|
||||
if input_img_latents is not None:
|
||||
assert input_img_inx == len(input_latents)
|
||||
|
||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||
else:
|
||||
input_emb = torch.cat([time_token, x], dim=1)
|
||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||
if input_is_list:
|
||||
image_embedding = output[:, -max(num_tokens):]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = []
|
||||
for i in range(x.size(0)):
|
||||
latent = x[i:i+1, :num_tokens[i]]
|
||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||
latents.append(latent)
|
||||
else:
|
||||
image_embedding = output[:, -num_tokens:]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||
|
||||
if return_past_key_values:
|
||||
return latents, past_key_values
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
model_out, past_key_values = self.forward(x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, past_key_values=past_key_values, return_past_key_values=True, offload_model=offload_model)
|
||||
if use_img_cfg:
|
||||
cond, uncond, img_cond = torch.split(model_out, len(model_out) // 3, dim=0)
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
else:
|
||||
cond, uncond = torch.split(model_out, len(model_out) // 2, dim=0)
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
|
||||
return torch.cat(model_out, dim=0), past_key_values
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(attention_mask)
|
||||
|
||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||
timestep = timestep.to(x[0].dtype)
|
||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||
|
||||
model_out, pask_key_values = [], []
|
||||
for i in range(len(input_ids)):
|
||||
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
||||
model_out.append(temp_out)
|
||||
pask_key_values.append(temp_pask_key_values)
|
||||
|
||||
if len(model_out) == 3:
|
||||
cond, uncond, img_cond = model_out
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
elif len(model_out) == 2:
|
||||
cond, uncond = model_out
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
else:
|
||||
return model_out[0]
|
||||
|
||||
return torch.cat(model_out, dim=0), pask_key_values
|
||||
|
||||
|
||||
|
||||
class OmniGenTransformer(OmniGenOriginalModel):
|
||||
def __init__(self):
|
||||
config = {
|
||||
"_name_or_path": "Phi-3-vision-128k-instruct",
|
||||
"architectures": [
|
||||
"Phi3ForCausalLM"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 8192,
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "phi3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 32,
|
||||
"num_key_value_heads": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": {
|
||||
"long_factor": [
|
||||
1.0299999713897705,
|
||||
1.0499999523162842,
|
||||
1.0499999523162842,
|
||||
1.0799999237060547,
|
||||
1.2299998998641968,
|
||||
1.2299998998641968,
|
||||
1.2999999523162842,
|
||||
1.4499999284744263,
|
||||
1.5999999046325684,
|
||||
1.6499998569488525,
|
||||
1.8999998569488525,
|
||||
2.859999895095825,
|
||||
3.68999981880188,
|
||||
5.419999599456787,
|
||||
5.489999771118164,
|
||||
5.489999771118164,
|
||||
9.09000015258789,
|
||||
11.579999923706055,
|
||||
15.65999984741211,
|
||||
15.769999504089355,
|
||||
15.789999961853027,
|
||||
18.360000610351562,
|
||||
21.989999771118164,
|
||||
23.079999923706055,
|
||||
30.009998321533203,
|
||||
32.35000228881836,
|
||||
32.590003967285156,
|
||||
35.56000518798828,
|
||||
39.95000457763672,
|
||||
53.840003967285156,
|
||||
56.20000457763672,
|
||||
57.95000457763672,
|
||||
59.29000473022461,
|
||||
59.77000427246094,
|
||||
59.920005798339844,
|
||||
61.190006256103516,
|
||||
61.96000671386719,
|
||||
62.50000762939453,
|
||||
63.3700065612793,
|
||||
63.48000717163086,
|
||||
63.48000717163086,
|
||||
63.66000747680664,
|
||||
63.850006103515625,
|
||||
64.08000946044922,
|
||||
64.760009765625,
|
||||
64.80001068115234,
|
||||
64.81001281738281,
|
||||
64.81001281738281
|
||||
],
|
||||
"short_factor": [
|
||||
1.05,
|
||||
1.05,
|
||||
1.05,
|
||||
1.1,
|
||||
1.1,
|
||||
1.1,
|
||||
1.2500000000000002,
|
||||
1.2500000000000002,
|
||||
1.4000000000000004,
|
||||
1.4500000000000004,
|
||||
1.5500000000000005,
|
||||
1.8500000000000008,
|
||||
1.9000000000000008,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.000000000000001,
|
||||
2.1000000000000005,
|
||||
2.1000000000000005,
|
||||
2.2,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3499999999999996,
|
||||
2.3999999999999995,
|
||||
2.3999999999999995,
|
||||
2.6499999999999986,
|
||||
2.6999999999999984,
|
||||
2.8999999999999977,
|
||||
2.9499999999999975,
|
||||
3.049999999999997,
|
||||
3.049999999999997,
|
||||
3.049999999999997
|
||||
],
|
||||
"type": "su"
|
||||
},
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 131072,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.38.1",
|
||||
"use_cache": True,
|
||||
"vocab_size": 32064,
|
||||
"_attn_implementation": "sdpa"
|
||||
}
|
||||
config = Phi3Config(**config)
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
def forward(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, padding_latent=None, past_key_values=None, return_past_key_values=True, offload_model:bool=False):
|
||||
input_is_list = isinstance(x, list)
|
||||
x, num_tokens, shapes = self.patch_multiple_resolutions(x, padding_latent)
|
||||
time_token = self.time_token(timestep, dtype=x[0].dtype).unsqueeze(1)
|
||||
|
||||
if input_img_latents is not None:
|
||||
input_latents, _, _ = self.patch_multiple_resolutions(input_img_latents, is_input_images=True)
|
||||
if input_ids is not None:
|
||||
condition_embeds = self.llm.embed_tokens(input_ids).clone()
|
||||
input_img_inx = 0
|
||||
for b_inx in input_image_sizes.keys():
|
||||
for start_inx, end_inx in input_image_sizes[b_inx]:
|
||||
condition_embeds[b_inx, start_inx: end_inx] = input_latents[input_img_inx]
|
||||
input_img_inx += 1
|
||||
if input_img_latents is not None:
|
||||
assert input_img_inx == len(input_latents)
|
||||
|
||||
input_emb = torch.cat([condition_embeds, time_token, x], dim=1)
|
||||
else:
|
||||
input_emb = torch.cat([time_token, x], dim=1)
|
||||
output = self.llm(inputs_embeds=input_emb, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, offload_model=offload_model)
|
||||
output, past_key_values = output.last_hidden_state, output.past_key_values
|
||||
if input_is_list:
|
||||
image_embedding = output[:, -max(num_tokens):]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = []
|
||||
for i in range(x.size(0)):
|
||||
latent = x[i:i+1, :num_tokens[i]]
|
||||
latent = self.unpatchify(latent, shapes[i][0], shapes[i][1])
|
||||
latents.append(latent)
|
||||
else:
|
||||
image_embedding = output[:, -num_tokens:]
|
||||
time_emb = self.t_embedder(timestep, dtype=x.dtype)
|
||||
x = self.final_layer(image_embedding, time_emb)
|
||||
latents = self.unpatchify(x, shapes[0], shapes[1])
|
||||
|
||||
if return_past_key_values:
|
||||
return latents, past_key_values
|
||||
return latents
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_with_separate_cfg(self, x, timestep, input_ids, input_img_latents, input_image_sizes, attention_mask, position_ids, cfg_scale, use_img_cfg, img_cfg_scale, past_key_values, use_kv_cache, offload_model):
|
||||
self.llm.config.use_cache = use_kv_cache
|
||||
if past_key_values is None:
|
||||
past_key_values = [None] * len(attention_mask)
|
||||
|
||||
x = torch.split(x, len(x) // len(attention_mask), dim=0)
|
||||
timestep = timestep.to(x[0].dtype)
|
||||
timestep = torch.split(timestep, len(timestep) // len(input_ids), dim=0)
|
||||
|
||||
model_out, pask_key_values = [], []
|
||||
for i in range(len(input_ids)):
|
||||
temp_out, temp_pask_key_values = self.forward(x[i], timestep[i], input_ids[i], input_img_latents[i], input_image_sizes[i], attention_mask[i], position_ids[i], past_key_values=past_key_values[i], return_past_key_values=True, offload_model=offload_model)
|
||||
model_out.append(temp_out)
|
||||
pask_key_values.append(temp_pask_key_values)
|
||||
|
||||
if len(model_out) == 3:
|
||||
cond, uncond, img_cond = model_out
|
||||
cond = uncond + img_cfg_scale * (img_cond - uncond) + cfg_scale * (cond - img_cond)
|
||||
model_out = [cond, cond, cond]
|
||||
elif len(model_out) == 2:
|
||||
cond, uncond = model_out
|
||||
cond = uncond + cfg_scale * (cond - uncond)
|
||||
model_out = [cond, cond]
|
||||
else:
|
||||
return model_out[0]
|
||||
|
||||
return torch.cat(model_out, dim=0), pask_key_values
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return OmniGenTransformerStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class OmniGenTransformerStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .sd3_dit import RMSNorm
|
||||
from .utils import hash_state_dict_keys
|
||||
from .general_modules import RMSNorm
|
||||
|
||||
|
||||
class BlockWiseControlBlock(torch.nn.Module):
|
||||
@@ -55,20 +54,3 @@ class QwenImageBlockWiseControlNet(torch.nn.Module):
|
||||
|
||||
def blockwise_forward(self, img, controlnet_conditioning, block_id):
|
||||
return self.controlnet_blocks[block_id](img, controlnet_conditioning)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageBlockWiseControlNetStateDictConverter()
|
||||
|
||||
|
||||
class QwenImageBlockWiseControlNetStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
hash_value = hash_state_dict_keys(state_dict)
|
||||
extra_kwargs = {}
|
||||
if hash_value == "a9e54e480a628f0b956a688a81c33bab":
|
||||
# inpaint controlnet
|
||||
extra_kwargs = {"additional_in_dim": 4}
|
||||
return state_dict, extra_kwargs
|
||||
|
||||
@@ -2,8 +2,7 @@ import torch, math
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional, Union, List
|
||||
from einops import rearrange
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .flux_dit import AdaLayerNorm
|
||||
from .general_modules import TimestepEmbeddings, RMSNorm, AdaLayerNorm
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
@@ -532,16 +531,3 @@ class QwenImageDiT(torch.nn.Module):
|
||||
|
||||
latents = rearrange(image, "B (H W) (C P Q) -> B C (H P) (W Q)", H=height//16, W=width//16, P=2, Q=2)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageDiTStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from transformers import Qwen2_5_VLModel
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
|
||||
@@ -6,7 +5,7 @@ from typing import Optional, Union
|
||||
class QwenImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import Qwen2_5_VLConfig
|
||||
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLModel
|
||||
config = Qwen2_5_VLConfig(**{
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
@@ -39,7 +38,7 @@ class QwenImageTextEncoder(torch.nn.Module):
|
||||
"sliding_window": 32768,
|
||||
"text_config": {
|
||||
"architectures": [
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
"Qwen2_5_VLForConditionalGeneration"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
@@ -144,6 +143,7 @@ class QwenImageTextEncoder(torch.nn.Module):
|
||||
})
|
||||
self.model = Qwen2_5_VLModel(config)
|
||||
self.lm_head = torch.nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
|
||||
self.config = config
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -166,51 +166,6 @@ class QwenImageTextEncoder(torch.nn.Module):
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each image in LLM.
|
||||
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
||||
The temporal, height and width of feature shape of each video in LLM.
|
||||
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
||||
The rope index difference between sequence length and multimodal rope.
|
||||
second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
|
||||
The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from PIL import Image
|
||||
>>> import requests
|
||||
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||
|
||||
>>> messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": "What is shown in this image?"},
|
||||
],
|
||||
},
|
||||
]
|
||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||
|
||||
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||
```"""
|
||||
|
||||
output_attentions = False
|
||||
output_hidden_states = True
|
||||
|
||||
@@ -233,23 +188,3 @@ class QwenImageTextEncoder(torch.nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
return outputs.hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageTextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for k, v in state_dict.items():
|
||||
if k.startswith("visual."):
|
||||
k = "model." + k
|
||||
elif k.startswith("model."):
|
||||
k = k.replace("model.", "model.language_model.")
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
|
||||
@@ -721,16 +721,3 @@ class QwenImageVAE(torch.nn.Module):
|
||||
x = self.decoder(x)
|
||||
x = x.squeeze(2)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return QwenImageVAEStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class QwenImageVAEStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
@@ -1,567 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from .svd_unet import TemporalTimesteps
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = torch.nn.Parameter(torch.ones((dim,)))
|
||||
else:
|
||||
self.weight = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.weight is not None:
|
||||
hidden_states = hidden_states * self.weight
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=2, in_channels=16, embed_dim=1536, pos_embed_max_size=192):
|
||||
super().__init__()
|
||||
self.pos_embed_max_size = pos_embed_max_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = torch.nn.Conv2d(in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size)
|
||||
self.pos_embed = torch.nn.Parameter(torch.zeros(1, self.pos_embed_max_size, self.pos_embed_max_size, embed_dim))
|
||||
|
||||
def cropped_pos_embed(self, height, width):
|
||||
height = height // self.patch_size
|
||||
width = width // self.patch_size
|
||||
top = (self.pos_embed_max_size - height) // 2
|
||||
left = (self.pos_embed_max_size - width) // 2
|
||||
spatial_pos_embed = self.pos_embed[:, top : top + height, left : left + width, :].flatten(1, 2)
|
||||
return spatial_pos_embed
|
||||
|
||||
def forward(self, latent):
|
||||
height, width = latent.shape[-2:]
|
||||
latent = self.proj(latent)
|
||||
latent = latent.flatten(2).transpose(1, 2)
|
||||
pos_embed = self.cropped_pos_embed(height, width)
|
||||
return latent + pos_embed
|
||||
|
||||
|
||||
class DiffusersCompatibleTimestepProj(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.linear_1 = torch.nn.Linear(dim_in, dim_out)
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear_2 = torch.nn.Linear(dim_out, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = self.act(x)
|
||||
x = self.linear_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None, diffusers_compatible_format=False, scale=1, align_dtype_to_timestep=False):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device, scale=scale, align_dtype_to_timestep=align_dtype_to_timestep)
|
||||
if diffusers_compatible_format:
|
||||
self.timestep_embedder = DiffusersCompatibleTimestepProj(dim_in, dim_out)
|
||||
else:
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
time_emb = self.time_proj(timestep).to(dtype)
|
||||
time_emb = self.timestep_embedder(time_emb)
|
||||
return time_emb
|
||||
|
||||
|
||||
|
||||
class AdaLayerNorm(torch.nn.Module):
|
||||
def __init__(self, dim, single=False, dual=False):
|
||||
super().__init__()
|
||||
self.single = single
|
||||
self.dual = dual
|
||||
self.linear = torch.nn.Linear(dim, dim * [[6, 2][single], 9][dual])
|
||||
self.norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(torch.nn.functional.silu(emb))
|
||||
if self.single:
|
||||
scale, shift = emb.unsqueeze(1).chunk(2, dim=2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
elif self.dual:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.unsqueeze(1).chunk(9, dim=2)
|
||||
norm_x = self.norm(x)
|
||||
x = norm_x * (1 + scale_msa) + shift_msa
|
||||
norm_x2 = norm_x * (1 + scale_msa2) + shift_msa2
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_x2, gate_msa2
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.unsqueeze(1).chunk(6, dim=2)
|
||||
x = self.norm(x) * (1 + scale_msa) + shift_msa
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
|
||||
class JointAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, dim_b, num_heads, head_dim, only_out_a=False, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.only_out_a = only_out_a
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.b_to_qkv = torch.nn.Linear(dim_b, dim_b * 3)
|
||||
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
if not only_out_a:
|
||||
self.b_to_out = torch.nn.Linear(dim_b, dim_b)
|
||||
|
||||
if use_rms_norm:
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_q_b = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_b = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.norm_q_a = None
|
||||
self.norm_k_a = None
|
||||
self.norm_q_b = None
|
||||
self.norm_k_b = None
|
||||
|
||||
|
||||
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
|
||||
batch_size = hidden_states.shape[0]
|
||||
qkv = to_qkv(hidden_states)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
if norm_q is not None:
|
||||
q = norm_q(q)
|
||||
if norm_k is not None:
|
||||
k = norm_k(k)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
|
||||
qa, ka, va = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
|
||||
qb, kb, vb = self.process_qkv(hidden_states_b, self.b_to_qkv, self.norm_q_b, self.norm_k_b)
|
||||
q = torch.concat([qa, qb], dim=2)
|
||||
k = torch.concat([ka, kb], dim=2)
|
||||
v = torch.concat([va, vb], dim=2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states_a, hidden_states_b = hidden_states[:, :hidden_states_a.shape[1]], hidden_states[:, hidden_states_a.shape[1]:]
|
||||
hidden_states_a = self.a_to_out(hidden_states_a)
|
||||
if self.only_out_a:
|
||||
return hidden_states_a
|
||||
else:
|
||||
hidden_states_b = self.b_to_out(hidden_states_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class SingleAttention(torch.nn.Module):
|
||||
def __init__(self, dim_a, num_heads, head_dim, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.a_to_qkv = torch.nn.Linear(dim_a, dim_a * 3)
|
||||
self.a_to_out = torch.nn.Linear(dim_a, dim_a)
|
||||
|
||||
if use_rms_norm:
|
||||
self.norm_q_a = RMSNorm(head_dim, eps=1e-6)
|
||||
self.norm_k_a = RMSNorm(head_dim, eps=1e-6)
|
||||
else:
|
||||
self.norm_q_a = None
|
||||
self.norm_k_a = None
|
||||
|
||||
|
||||
def process_qkv(self, hidden_states, to_qkv, norm_q, norm_k):
|
||||
batch_size = hidden_states.shape[0]
|
||||
qkv = to_qkv(hidden_states)
|
||||
qkv = qkv.view(batch_size, -1, 3 * self.num_heads, self.head_dim).transpose(1, 2)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
if norm_q is not None:
|
||||
q = norm_q(q)
|
||||
if norm_k is not None:
|
||||
k = norm_k(k)
|
||||
return q, k, v
|
||||
|
||||
|
||||
def forward(self, hidden_states_a):
|
||||
batch_size = hidden_states_a.shape[0]
|
||||
q, k, v = self.process_qkv(hidden_states_a, self.a_to_qkv, self.norm_q_a, self.norm_k_a)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.a_to_out(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class DualTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim, dual=True)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
self.attn2 = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False, dual=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim, dual=dual)
|
||||
self.norm1_b = AdaLayerNorm(dim)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
if dual:
|
||||
self.attn2 = SingleAttention(dim, num_attention_heads, dim // num_attention_heads, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
self.norm2_b = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_b = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
if self.norm1_a.dual:
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a, norm_hidden_states_a_2, gate_msa_a_2 = self.norm1_a(hidden_states_a, emb=temb)
|
||||
else:
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b, gate_msa_b, shift_mlp_b, scale_mlp_b, gate_mlp_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a, attn_output_b = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
if self.norm1_a.dual:
|
||||
hidden_states_a = hidden_states_a + gate_msa_a_2 * self.attn2(norm_hidden_states_a_2)
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
# Part B
|
||||
hidden_states_b = hidden_states_b + gate_msa_b * attn_output_b
|
||||
norm_hidden_states_b = self.norm2_b(hidden_states_b) * (1 + scale_mlp_b) + shift_mlp_b
|
||||
hidden_states_b = hidden_states_b + gate_mlp_b * self.ff_b(norm_hidden_states_b)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class JointTransformerFinalBlock(torch.nn.Module):
|
||||
def __init__(self, dim, num_attention_heads, use_rms_norm=False):
|
||||
super().__init__()
|
||||
self.norm1_a = AdaLayerNorm(dim)
|
||||
self.norm1_b = AdaLayerNorm(dim, single=True)
|
||||
|
||||
self.attn = JointAttention(dim, dim, num_attention_heads, dim // num_attention_heads, only_out_a=True, use_rms_norm=use_rms_norm)
|
||||
|
||||
self.norm2_a = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_a = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*4),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(dim*4, dim)
|
||||
)
|
||||
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, temb):
|
||||
norm_hidden_states_a, gate_msa_a, shift_mlp_a, scale_mlp_a, gate_mlp_a = self.norm1_a(hidden_states_a, emb=temb)
|
||||
norm_hidden_states_b = self.norm1_b(hidden_states_b, emb=temb)
|
||||
|
||||
# Attention
|
||||
attn_output_a = self.attn(norm_hidden_states_a, norm_hidden_states_b)
|
||||
|
||||
# Part A
|
||||
hidden_states_a = hidden_states_a + gate_msa_a * attn_output_a
|
||||
norm_hidden_states_a = self.norm2_a(hidden_states_a) * (1 + scale_mlp_a) + shift_mlp_a
|
||||
hidden_states_a = hidden_states_a + gate_mlp_a * self.ff_a(norm_hidden_states_a)
|
||||
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
|
||||
class SD3DiT(torch.nn.Module):
|
||||
def __init__(self, embed_dim=1536, num_layers=24, use_rms_norm=False, num_dual_blocks=0, pos_embed_max_size=192):
|
||||
super().__init__()
|
||||
self.pos_embedder = PatchEmbed(patch_size=2, in_channels=16, embed_dim=embed_dim, pos_embed_max_size=pos_embed_max_size)
|
||||
self.time_embedder = TimestepEmbeddings(256, embed_dim)
|
||||
self.pooled_text_embedder = torch.nn.Sequential(torch.nn.Linear(2048, embed_dim), torch.nn.SiLU(), torch.nn.Linear(embed_dim, embed_dim))
|
||||
self.context_embedder = torch.nn.Linear(4096, embed_dim)
|
||||
self.blocks = torch.nn.ModuleList([JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm, dual=True) for _ in range(num_dual_blocks)]
|
||||
+ [JointTransformerBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm) for _ in range(num_layers-1-num_dual_blocks)]
|
||||
+ [JointTransformerFinalBlock(embed_dim, embed_dim//64, use_rms_norm=use_rms_norm)])
|
||||
self.norm_out = AdaLayerNorm(embed_dim, single=True)
|
||||
self.proj_out = torch.nn.Linear(embed_dim, 64)
|
||||
|
||||
def tiled_forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size=128, tile_stride=64):
|
||||
# Due to the global positional embedding, we cannot implement layer-wise tiled forward.
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x, timestep, prompt_emb, pooled_prompt_emb),
|
||||
hidden_states,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, timestep, prompt_emb, pooled_prompt_emb, tiled=False, tile_size=128, tile_stride=64, use_gradient_checkpointing=False):
|
||||
if tiled:
|
||||
return self.tiled_forward(hidden_states, timestep, prompt_emb, pooled_prompt_emb, tile_size, tile_stride)
|
||||
conditioning = self.time_embedder(timestep, hidden_states.dtype) + self.pooled_text_embedder(pooled_prompt_emb)
|
||||
prompt_emb = self.context_embedder(prompt_emb)
|
||||
|
||||
height, width = hidden_states.shape[-2:]
|
||||
hidden_states = self.pos_embedder(hidden_states)
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
|
||||
for block in self.blocks:
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, prompt_emb, conditioning,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, conditioning)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B (H W) (P Q C) -> B C (H P) (W Q)", P=2, Q=2, H=height//2, W=width//2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SD3DiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class SD3DiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def infer_architecture(self, state_dict):
|
||||
embed_dim = state_dict["blocks.0.ff_a.0.weight"].shape[1]
|
||||
num_layers = 100
|
||||
while num_layers > 0 and f"blocks.{num_layers-1}.ff_a.0.bias" not in state_dict:
|
||||
num_layers -= 1
|
||||
use_rms_norm = "blocks.0.attn.norm_q_a.weight" in state_dict
|
||||
num_dual_blocks = 0
|
||||
while f"blocks.{num_dual_blocks}.attn2.a_to_out.bias" in state_dict:
|
||||
num_dual_blocks += 1
|
||||
pos_embed_max_size = state_dict["pos_embedder.pos_embed"].shape[1]
|
||||
return {
|
||||
"embed_dim": embed_dim,
|
||||
"num_layers": num_layers,
|
||||
"use_rms_norm": use_rms_norm,
|
||||
"num_dual_blocks": num_dual_blocks,
|
||||
"pos_embed_max_size": pos_embed_max_size
|
||||
}
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"pos_embed.pos_embed": "pos_embedder.pos_embed",
|
||||
"pos_embed.proj": "pos_embedder.proj",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "norm_out.linear",
|
||||
"proj_out": "proj_out",
|
||||
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
if name == "pos_embed.pos_embed":
|
||||
param = param.reshape((1, 192, 192, param.shape[-1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in rename_dict:
|
||||
state_dict_[rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
merged_keys = [name for name in state_dict_ if ".a_to_q." in name or ".b_to_q." in name]
|
||||
for key in merged_keys:
|
||||
param = torch.concat([
|
||||
state_dict_[key.replace("to_q", "to_q")],
|
||||
state_dict_[key.replace("to_q", "to_k")],
|
||||
state_dict_[key.replace("to_q", "to_v")],
|
||||
], dim=0)
|
||||
name = key.replace("to_q", "to_qkv")
|
||||
state_dict_.pop(key.replace("to_q", "to_q"))
|
||||
state_dict_.pop(key.replace("to_q", "to_k"))
|
||||
state_dict_.pop(key.replace("to_q", "to_v"))
|
||||
state_dict_[name] = param
|
||||
return state_dict_, self.infer_architecture(state_dict_)
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"model.diffusion_model.context_embedder.bias": "context_embedder.bias",
|
||||
"model.diffusion_model.context_embedder.weight": "context_embedder.weight",
|
||||
"model.diffusion_model.final_layer.linear.bias": "proj_out.bias",
|
||||
"model.diffusion_model.final_layer.linear.weight": "proj_out.weight",
|
||||
|
||||
"model.diffusion_model.pos_embed": "pos_embedder.pos_embed",
|
||||
"model.diffusion_model.t_embedder.mlp.0.bias": "time_embedder.timestep_embedder.0.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.0.weight": "time_embedder.timestep_embedder.0.weight",
|
||||
"model.diffusion_model.t_embedder.mlp.2.bias": "time_embedder.timestep_embedder.2.bias",
|
||||
"model.diffusion_model.t_embedder.mlp.2.weight": "time_embedder.timestep_embedder.2.weight",
|
||||
"model.diffusion_model.x_embedder.proj.bias": "pos_embedder.proj.bias",
|
||||
"model.diffusion_model.x_embedder.proj.weight": "pos_embedder.proj.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.0.bias": "pooled_text_embedder.0.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.0.weight": "pooled_text_embedder.0.weight",
|
||||
"model.diffusion_model.y_embedder.mlp.2.bias": "pooled_text_embedder.2.bias",
|
||||
"model.diffusion_model.y_embedder.mlp.2.weight": "pooled_text_embedder.2.weight",
|
||||
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.weight": "blocks.23.norm1_b.linear.weight",
|
||||
"model.diffusion_model.joint_blocks.23.context_block.adaLN_modulation.1.bias": "blocks.23.norm1_b.linear.bias",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
|
||||
"model.diffusion_model.final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
|
||||
}
|
||||
for i in range(40):
|
||||
rename_dict.update({
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_b.linear.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_b.linear.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.bias": f"blocks.{i}.attn.b_to_out.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.proj.weight": f"blocks.{i}.attn.b_to_out.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.bias": [f'blocks.{i}.attn.b_to_q.bias', f'blocks.{i}.attn.b_to_k.bias', f'blocks.{i}.attn.b_to_v.bias'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.qkv.weight": [f'blocks.{i}.attn.b_to_q.weight', f'blocks.{i}.attn.b_to_k.weight', f'blocks.{i}.attn.b_to_v.weight'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.bias": f"blocks.{i}.ff_b.0.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc1.weight": f"blocks.{i}.ff_b.0.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.bias": f"blocks.{i}.ff_b.2.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.mlp.fc2.weight": f"blocks.{i}.ff_b.2.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.bias": f"blocks.{i}.norm1_a.linear.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.adaLN_modulation.1.weight": f"blocks.{i}.norm1_a.linear.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.bias": f"blocks.{i}.attn.a_to_out.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.proj.weight": f"blocks.{i}.attn.a_to_out.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.bias": [f'blocks.{i}.attn.a_to_q.bias', f'blocks.{i}.attn.a_to_k.bias', f'blocks.{i}.attn.a_to_v.bias'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.qkv.weight": [f'blocks.{i}.attn.a_to_q.weight', f'blocks.{i}.attn.a_to_k.weight', f'blocks.{i}.attn.a_to_v.weight'],
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.bias": f"blocks.{i}.ff_a.0.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc1.weight": f"blocks.{i}.ff_a.0.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.bias": f"blocks.{i}.ff_a.2.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.mlp.fc2.weight": f"blocks.{i}.ff_a.2.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_q.weight": f"blocks.{i}.attn.norm_q_b.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.context_block.attn.ln_k.weight": f"blocks.{i}.attn.norm_k_b.weight",
|
||||
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_q.weight": f"blocks.{i}.attn2.norm_q_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.ln_k.weight": f"blocks.{i}.attn2.norm_k_a.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.weight": f"blocks.{i}.attn2.a_to_qkv.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.qkv.bias": f"blocks.{i}.attn2.a_to_qkv.bias",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.weight": f"blocks.{i}.attn2.a_to_out.weight",
|
||||
f"model.diffusion_model.joint_blocks.{i}.x_block.attn2.proj.bias": f"blocks.{i}.attn2.a_to_out.bias",
|
||||
})
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "model.diffusion_model.pos_embed":
|
||||
pos_embed_max_size = int(param.shape[1] ** 0.5 + 0.4)
|
||||
param = param.reshape((1, pos_embed_max_size, pos_embed_max_size, param.shape[-1]))
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = rename_dict[name][0].replace(".a_to_q.", ".a_to_qkv.").replace(".b_to_q.", ".b_to_qkv.")
|
||||
state_dict_[name_] = param
|
||||
extra_kwargs = self.infer_architecture(state_dict_)
|
||||
num_layers = extra_kwargs["num_layers"]
|
||||
for name in [
|
||||
f"blocks.{num_layers-1}.norm1_b.linear.weight", f"blocks.{num_layers-1}.norm1_b.linear.bias", "norm_out.linear.weight", "norm_out.linear.bias",
|
||||
]:
|
||||
param = state_dict_[name]
|
||||
dim = param.shape[0] // 2
|
||||
param = torch.concat([param[dim:], param[:dim]], axis=0)
|
||||
state_dict_[name] = param
|
||||
return state_dict_, self.infer_architecture(state_dict_)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,81 +0,0 @@
|
||||
import torch
|
||||
from .sd_vae_decoder import VAEAttentionBlock, SDVAEDecoderStateDictConverter
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
|
||||
class SD3VAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(16, 512, kernel_size=3, padding=1) # Different from SD 1.x
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = sample / self.scaling_factor + self.shift_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
@@ -1,95 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_encoder import VAEAttentionBlock, SDVAEEncoderStateDictConverter
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SD3VAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 1.5305 # Different from SD 1.x
|
||||
self.shift_factor = 0.0609 # Different from SD 1.x
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 32, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states[:, :16]
|
||||
hidden_states = (hidden_states - self.shift_factor) * self.scaling_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
@@ -1,589 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class ControlNetConditioningLayer(torch.nn.Module):
|
||||
def __init__(self, channels = (3, 16, 32, 96, 256, 320)):
|
||||
super().__init__()
|
||||
self.blocks = torch.nn.ModuleList([])
|
||||
self.blocks.append(torch.nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1))
|
||||
self.blocks.append(torch.nn.SiLU())
|
||||
for i in range(1, len(channels) - 2):
|
||||
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i], kernel_size=3, padding=1))
|
||||
self.blocks.append(torch.nn.SiLU())
|
||||
self.blocks.append(torch.nn.Conv2d(channels[i], channels[i+1], kernel_size=3, padding=1, stride=2))
|
||||
self.blocks.append(torch.nn.SiLU())
|
||||
self.blocks.append(torch.nn.Conv2d(channels[-2], channels[-1], kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, conditioning):
|
||||
for block in self.blocks:
|
||||
conditioning = block(conditioning)
|
||||
return conditioning
|
||||
|
||||
|
||||
class SDControlNet(torch.nn.Module):
|
||||
def __init__(self, global_pool=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(320, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
|
||||
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(320, 320, 1280),
|
||||
AttentionBlock(8, 40, 320, 1, 768),
|
||||
PushBlock(),
|
||||
ResnetBlock(320, 320, 1280),
|
||||
AttentionBlock(8, 40, 320, 1, 768),
|
||||
PushBlock(),
|
||||
DownSampler(320),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(320, 640, 1280),
|
||||
AttentionBlock(8, 80, 640, 1, 768),
|
||||
PushBlock(),
|
||||
ResnetBlock(640, 640, 1280),
|
||||
AttentionBlock(8, 80, 640, 1, 768),
|
||||
PushBlock(),
|
||||
DownSampler(640),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(640, 1280, 1280),
|
||||
AttentionBlock(8, 160, 1280, 1, 768),
|
||||
PushBlock(),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(8, 160, 1280, 1, 768),
|
||||
PushBlock(),
|
||||
DownSampler(1280),
|
||||
PushBlock(),
|
||||
# DownBlock2D
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock(),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock(),
|
||||
# UNetMidBlock2DCrossAttn
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(8, 160, 1280, 1, 768),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock()
|
||||
])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1), bias=False),
|
||||
])
|
||||
|
||||
self.global_pool = global_pool
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, conditioning,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
time_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
time_emb = self.time_embedding(time_emb)
|
||||
time_emb = time_emb.repeat(sample.shape[0], 1)
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
|
||||
text_emb = encoder_hidden_states
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
if tiled and not isinstance(block, PushBlock):
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. ControlNet blocks
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
|
||||
# pool
|
||||
if self.global_pool:
|
||||
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDControlNetStateDictConverter()
|
||||
|
||||
|
||||
class SDControlNetStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
||||
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
||||
'ResnetBlock', 'AttentionBlock', 'PushBlock', 'ResnetBlock', 'AttentionBlock', 'PushBlock', 'DownSampler', 'PushBlock',
|
||||
'ResnetBlock', 'PushBlock', 'ResnetBlock', 'PushBlock',
|
||||
'ResnetBlock', 'AttentionBlock', 'ResnetBlock',
|
||||
'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'PopBlock', 'ResnetBlock', 'UpSampler',
|
||||
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
||||
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'UpSampler',
|
||||
'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock', 'PopBlock', 'ResnetBlock', 'AttentionBlock'
|
||||
]
|
||||
|
||||
# controlnet_rename_dict
|
||||
controlnet_rename_dict = {
|
||||
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
|
||||
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
|
||||
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
|
||||
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
|
||||
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
|
||||
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
|
||||
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
|
||||
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
|
||||
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
|
||||
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
|
||||
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
|
||||
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
||||
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
||||
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
||||
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
||||
}
|
||||
|
||||
# Rename each parameter
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif name in controlnet_rename_dict:
|
||||
names = controlnet_rename_dict[name].split(".")
|
||||
elif names[0] == "controlnet_down_blocks":
|
||||
names[0] = "controlnet_blocks"
|
||||
elif names[0] == "controlnet_mid_block":
|
||||
names = ["controlnet_blocks", "12", names[-1]]
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
||||
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
||||
if names[0] == "mid_block":
|
||||
names.insert(1, "0")
|
||||
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
names = ["blocks", str(block_id[block_type])] + names[4:]
|
||||
if "ff" in names:
|
||||
ff_index = names.index("ff")
|
||||
component = ".".join(names[ff_index:ff_index+3])
|
||||
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
||||
names = names[:ff_index] + [component] + names[ff_index+3:]
|
||||
if "to_out" in names:
|
||||
names.pop(names.index("to_out") + 1)
|
||||
else:
|
||||
raise ValueError(f"Unknown parameters: {name}")
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
if rename_dict[name] in [
|
||||
"controlnet_blocks.1.bias", "controlnet_blocks.2.bias", "controlnet_blocks.3.bias", "controlnet_blocks.5.bias", "controlnet_blocks.6.bias",
|
||||
"controlnet_blocks.8.bias", "controlnet_blocks.9.bias", "controlnet_blocks.10.bias", "controlnet_blocks.11.bias", "controlnet_blocks.12.bias"
|
||||
]:
|
||||
continue
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if "mid_block.resnets.1.time_emb_proj.weight" in state_dict:
|
||||
# For controlnets in diffusers format
|
||||
return self.from_diffusers(state_dict)
|
||||
rename_dict = {
|
||||
"control_model.time_embed.0.weight": "time_embedding.0.weight",
|
||||
"control_model.time_embed.0.bias": "time_embedding.0.bias",
|
||||
"control_model.time_embed.2.weight": "time_embedding.2.weight",
|
||||
"control_model.time_embed.2.bias": "time_embedding.2.bias",
|
||||
"control_model.input_blocks.0.0.weight": "conv_in.weight",
|
||||
"control_model.input_blocks.0.0.bias": "conv_in.bias",
|
||||
"control_model.input_blocks.1.0.in_layers.0.weight": "blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.1.0.in_layers.0.bias": "blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.1.0.in_layers.2.weight": "blocks.0.conv1.weight",
|
||||
"control_model.input_blocks.1.0.in_layers.2.bias": "blocks.0.conv1.bias",
|
||||
"control_model.input_blocks.1.0.emb_layers.1.weight": "blocks.0.time_emb_proj.weight",
|
||||
"control_model.input_blocks.1.0.emb_layers.1.bias": "blocks.0.time_emb_proj.bias",
|
||||
"control_model.input_blocks.1.0.out_layers.0.weight": "blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.1.0.out_layers.0.bias": "blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.1.0.out_layers.3.weight": "blocks.0.conv2.weight",
|
||||
"control_model.input_blocks.1.0.out_layers.3.bias": "blocks.0.conv2.bias",
|
||||
"control_model.input_blocks.1.1.norm.weight": "blocks.1.norm.weight",
|
||||
"control_model.input_blocks.1.1.norm.bias": "blocks.1.norm.bias",
|
||||
"control_model.input_blocks.1.1.proj_in.weight": "blocks.1.proj_in.weight",
|
||||
"control_model.input_blocks.1.1.proj_in.bias": "blocks.1.proj_in.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_q.weight": "blocks.1.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_k.weight": "blocks.1.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_v.weight": "blocks.1.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.1.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.1.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.1.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.1.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.weight": "blocks.1.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.ff.net.2.bias": "blocks.1.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_q.weight": "blocks.1.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_k.weight": "blocks.1.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_v.weight": "blocks.1.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.1.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.1.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.weight": "blocks.1.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm1.bias": "blocks.1.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.weight": "blocks.1.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm2.bias": "blocks.1.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.weight": "blocks.1.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.1.1.transformer_blocks.0.norm3.bias": "blocks.1.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.1.1.proj_out.weight": "blocks.1.proj_out.weight",
|
||||
"control_model.input_blocks.1.1.proj_out.bias": "blocks.1.proj_out.bias",
|
||||
"control_model.input_blocks.2.0.in_layers.0.weight": "blocks.3.norm1.weight",
|
||||
"control_model.input_blocks.2.0.in_layers.0.bias": "blocks.3.norm1.bias",
|
||||
"control_model.input_blocks.2.0.in_layers.2.weight": "blocks.3.conv1.weight",
|
||||
"control_model.input_blocks.2.0.in_layers.2.bias": "blocks.3.conv1.bias",
|
||||
"control_model.input_blocks.2.0.emb_layers.1.weight": "blocks.3.time_emb_proj.weight",
|
||||
"control_model.input_blocks.2.0.emb_layers.1.bias": "blocks.3.time_emb_proj.bias",
|
||||
"control_model.input_blocks.2.0.out_layers.0.weight": "blocks.3.norm2.weight",
|
||||
"control_model.input_blocks.2.0.out_layers.0.bias": "blocks.3.norm2.bias",
|
||||
"control_model.input_blocks.2.0.out_layers.3.weight": "blocks.3.conv2.weight",
|
||||
"control_model.input_blocks.2.0.out_layers.3.bias": "blocks.3.conv2.bias",
|
||||
"control_model.input_blocks.2.1.norm.weight": "blocks.4.norm.weight",
|
||||
"control_model.input_blocks.2.1.norm.bias": "blocks.4.norm.bias",
|
||||
"control_model.input_blocks.2.1.proj_in.weight": "blocks.4.proj_in.weight",
|
||||
"control_model.input_blocks.2.1.proj_in.bias": "blocks.4.proj_in.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_q.weight": "blocks.4.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_k.weight": "blocks.4.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_v.weight": "blocks.4.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.4.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.4.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.4.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.4.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.weight": "blocks.4.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.ff.net.2.bias": "blocks.4.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_q.weight": "blocks.4.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight": "blocks.4.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_v.weight": "blocks.4.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.4.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.4.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.weight": "blocks.4.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm1.bias": "blocks.4.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.weight": "blocks.4.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm2.bias": "blocks.4.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.weight": "blocks.4.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.2.1.transformer_blocks.0.norm3.bias": "blocks.4.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.2.1.proj_out.weight": "blocks.4.proj_out.weight",
|
||||
"control_model.input_blocks.2.1.proj_out.bias": "blocks.4.proj_out.bias",
|
||||
"control_model.input_blocks.3.0.op.weight": "blocks.6.conv.weight",
|
||||
"control_model.input_blocks.3.0.op.bias": "blocks.6.conv.bias",
|
||||
"control_model.input_blocks.4.0.in_layers.0.weight": "blocks.8.norm1.weight",
|
||||
"control_model.input_blocks.4.0.in_layers.0.bias": "blocks.8.norm1.bias",
|
||||
"control_model.input_blocks.4.0.in_layers.2.weight": "blocks.8.conv1.weight",
|
||||
"control_model.input_blocks.4.0.in_layers.2.bias": "blocks.8.conv1.bias",
|
||||
"control_model.input_blocks.4.0.emb_layers.1.weight": "blocks.8.time_emb_proj.weight",
|
||||
"control_model.input_blocks.4.0.emb_layers.1.bias": "blocks.8.time_emb_proj.bias",
|
||||
"control_model.input_blocks.4.0.out_layers.0.weight": "blocks.8.norm2.weight",
|
||||
"control_model.input_blocks.4.0.out_layers.0.bias": "blocks.8.norm2.bias",
|
||||
"control_model.input_blocks.4.0.out_layers.3.weight": "blocks.8.conv2.weight",
|
||||
"control_model.input_blocks.4.0.out_layers.3.bias": "blocks.8.conv2.bias",
|
||||
"control_model.input_blocks.4.0.skip_connection.weight": "blocks.8.conv_shortcut.weight",
|
||||
"control_model.input_blocks.4.0.skip_connection.bias": "blocks.8.conv_shortcut.bias",
|
||||
"control_model.input_blocks.4.1.norm.weight": "blocks.9.norm.weight",
|
||||
"control_model.input_blocks.4.1.norm.bias": "blocks.9.norm.bias",
|
||||
"control_model.input_blocks.4.1.proj_in.weight": "blocks.9.proj_in.weight",
|
||||
"control_model.input_blocks.4.1.proj_in.bias": "blocks.9.proj_in.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_q.weight": "blocks.9.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_k.weight": "blocks.9.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_v.weight": "blocks.9.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.9.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.9.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.9.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.9.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.weight": "blocks.9.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.ff.net.2.bias": "blocks.9.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_q.weight": "blocks.9.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight": "blocks.9.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_v.weight": "blocks.9.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.9.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.9.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.weight": "blocks.9.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm1.bias": "blocks.9.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.weight": "blocks.9.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm2.bias": "blocks.9.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.weight": "blocks.9.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.4.1.transformer_blocks.0.norm3.bias": "blocks.9.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.4.1.proj_out.weight": "blocks.9.proj_out.weight",
|
||||
"control_model.input_blocks.4.1.proj_out.bias": "blocks.9.proj_out.bias",
|
||||
"control_model.input_blocks.5.0.in_layers.0.weight": "blocks.11.norm1.weight",
|
||||
"control_model.input_blocks.5.0.in_layers.0.bias": "blocks.11.norm1.bias",
|
||||
"control_model.input_blocks.5.0.in_layers.2.weight": "blocks.11.conv1.weight",
|
||||
"control_model.input_blocks.5.0.in_layers.2.bias": "blocks.11.conv1.bias",
|
||||
"control_model.input_blocks.5.0.emb_layers.1.weight": "blocks.11.time_emb_proj.weight",
|
||||
"control_model.input_blocks.5.0.emb_layers.1.bias": "blocks.11.time_emb_proj.bias",
|
||||
"control_model.input_blocks.5.0.out_layers.0.weight": "blocks.11.norm2.weight",
|
||||
"control_model.input_blocks.5.0.out_layers.0.bias": "blocks.11.norm2.bias",
|
||||
"control_model.input_blocks.5.0.out_layers.3.weight": "blocks.11.conv2.weight",
|
||||
"control_model.input_blocks.5.0.out_layers.3.bias": "blocks.11.conv2.bias",
|
||||
"control_model.input_blocks.5.1.norm.weight": "blocks.12.norm.weight",
|
||||
"control_model.input_blocks.5.1.norm.bias": "blocks.12.norm.bias",
|
||||
"control_model.input_blocks.5.1.proj_in.weight": "blocks.12.proj_in.weight",
|
||||
"control_model.input_blocks.5.1.proj_in.bias": "blocks.12.proj_in.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_q.weight": "blocks.12.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_k.weight": "blocks.12.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_v.weight": "blocks.12.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.12.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.12.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.12.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.12.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.weight": "blocks.12.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.ff.net.2.bias": "blocks.12.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_q.weight": "blocks.12.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_k.weight": "blocks.12.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_v.weight": "blocks.12.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.12.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.12.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.weight": "blocks.12.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm1.bias": "blocks.12.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.weight": "blocks.12.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm2.bias": "blocks.12.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.weight": "blocks.12.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.5.1.transformer_blocks.0.norm3.bias": "blocks.12.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.5.1.proj_out.weight": "blocks.12.proj_out.weight",
|
||||
"control_model.input_blocks.5.1.proj_out.bias": "blocks.12.proj_out.bias",
|
||||
"control_model.input_blocks.6.0.op.weight": "blocks.14.conv.weight",
|
||||
"control_model.input_blocks.6.0.op.bias": "blocks.14.conv.bias",
|
||||
"control_model.input_blocks.7.0.in_layers.0.weight": "blocks.16.norm1.weight",
|
||||
"control_model.input_blocks.7.0.in_layers.0.bias": "blocks.16.norm1.bias",
|
||||
"control_model.input_blocks.7.0.in_layers.2.weight": "blocks.16.conv1.weight",
|
||||
"control_model.input_blocks.7.0.in_layers.2.bias": "blocks.16.conv1.bias",
|
||||
"control_model.input_blocks.7.0.emb_layers.1.weight": "blocks.16.time_emb_proj.weight",
|
||||
"control_model.input_blocks.7.0.emb_layers.1.bias": "blocks.16.time_emb_proj.bias",
|
||||
"control_model.input_blocks.7.0.out_layers.0.weight": "blocks.16.norm2.weight",
|
||||
"control_model.input_blocks.7.0.out_layers.0.bias": "blocks.16.norm2.bias",
|
||||
"control_model.input_blocks.7.0.out_layers.3.weight": "blocks.16.conv2.weight",
|
||||
"control_model.input_blocks.7.0.out_layers.3.bias": "blocks.16.conv2.bias",
|
||||
"control_model.input_blocks.7.0.skip_connection.weight": "blocks.16.conv_shortcut.weight",
|
||||
"control_model.input_blocks.7.0.skip_connection.bias": "blocks.16.conv_shortcut.bias",
|
||||
"control_model.input_blocks.7.1.norm.weight": "blocks.17.norm.weight",
|
||||
"control_model.input_blocks.7.1.norm.bias": "blocks.17.norm.bias",
|
||||
"control_model.input_blocks.7.1.proj_in.weight": "blocks.17.proj_in.weight",
|
||||
"control_model.input_blocks.7.1.proj_in.bias": "blocks.17.proj_in.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_q.weight": "blocks.17.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_k.weight": "blocks.17.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_v.weight": "blocks.17.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.17.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.17.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.17.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.17.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.weight": "blocks.17.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.ff.net.2.bias": "blocks.17.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_q.weight": "blocks.17.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_k.weight": "blocks.17.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_v.weight": "blocks.17.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.17.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.17.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.weight": "blocks.17.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm1.bias": "blocks.17.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.weight": "blocks.17.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm2.bias": "blocks.17.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.weight": "blocks.17.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.7.1.transformer_blocks.0.norm3.bias": "blocks.17.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.7.1.proj_out.weight": "blocks.17.proj_out.weight",
|
||||
"control_model.input_blocks.7.1.proj_out.bias": "blocks.17.proj_out.bias",
|
||||
"control_model.input_blocks.8.0.in_layers.0.weight": "blocks.19.norm1.weight",
|
||||
"control_model.input_blocks.8.0.in_layers.0.bias": "blocks.19.norm1.bias",
|
||||
"control_model.input_blocks.8.0.in_layers.2.weight": "blocks.19.conv1.weight",
|
||||
"control_model.input_blocks.8.0.in_layers.2.bias": "blocks.19.conv1.bias",
|
||||
"control_model.input_blocks.8.0.emb_layers.1.weight": "blocks.19.time_emb_proj.weight",
|
||||
"control_model.input_blocks.8.0.emb_layers.1.bias": "blocks.19.time_emb_proj.bias",
|
||||
"control_model.input_blocks.8.0.out_layers.0.weight": "blocks.19.norm2.weight",
|
||||
"control_model.input_blocks.8.0.out_layers.0.bias": "blocks.19.norm2.bias",
|
||||
"control_model.input_blocks.8.0.out_layers.3.weight": "blocks.19.conv2.weight",
|
||||
"control_model.input_blocks.8.0.out_layers.3.bias": "blocks.19.conv2.bias",
|
||||
"control_model.input_blocks.8.1.norm.weight": "blocks.20.norm.weight",
|
||||
"control_model.input_blocks.8.1.norm.bias": "blocks.20.norm.bias",
|
||||
"control_model.input_blocks.8.1.proj_in.weight": "blocks.20.proj_in.weight",
|
||||
"control_model.input_blocks.8.1.proj_in.bias": "blocks.20.proj_in.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_q.weight": "blocks.20.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_k.weight": "blocks.20.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_v.weight": "blocks.20.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.20.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.20.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.20.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.20.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.weight": "blocks.20.transformer_blocks.0.ff.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.ff.net.2.bias": "blocks.20.transformer_blocks.0.ff.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_q.weight": "blocks.20.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_k.weight": "blocks.20.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_v.weight": "blocks.20.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.20.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.20.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.weight": "blocks.20.transformer_blocks.0.norm1.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm1.bias": "blocks.20.transformer_blocks.0.norm1.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.weight": "blocks.20.transformer_blocks.0.norm2.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm2.bias": "blocks.20.transformer_blocks.0.norm2.bias",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.weight": "blocks.20.transformer_blocks.0.norm3.weight",
|
||||
"control_model.input_blocks.8.1.transformer_blocks.0.norm3.bias": "blocks.20.transformer_blocks.0.norm3.bias",
|
||||
"control_model.input_blocks.8.1.proj_out.weight": "blocks.20.proj_out.weight",
|
||||
"control_model.input_blocks.8.1.proj_out.bias": "blocks.20.proj_out.bias",
|
||||
"control_model.input_blocks.9.0.op.weight": "blocks.22.conv.weight",
|
||||
"control_model.input_blocks.9.0.op.bias": "blocks.22.conv.bias",
|
||||
"control_model.input_blocks.10.0.in_layers.0.weight": "blocks.24.norm1.weight",
|
||||
"control_model.input_blocks.10.0.in_layers.0.bias": "blocks.24.norm1.bias",
|
||||
"control_model.input_blocks.10.0.in_layers.2.weight": "blocks.24.conv1.weight",
|
||||
"control_model.input_blocks.10.0.in_layers.2.bias": "blocks.24.conv1.bias",
|
||||
"control_model.input_blocks.10.0.emb_layers.1.weight": "blocks.24.time_emb_proj.weight",
|
||||
"control_model.input_blocks.10.0.emb_layers.1.bias": "blocks.24.time_emb_proj.bias",
|
||||
"control_model.input_blocks.10.0.out_layers.0.weight": "blocks.24.norm2.weight",
|
||||
"control_model.input_blocks.10.0.out_layers.0.bias": "blocks.24.norm2.bias",
|
||||
"control_model.input_blocks.10.0.out_layers.3.weight": "blocks.24.conv2.weight",
|
||||
"control_model.input_blocks.10.0.out_layers.3.bias": "blocks.24.conv2.bias",
|
||||
"control_model.input_blocks.11.0.in_layers.0.weight": "blocks.26.norm1.weight",
|
||||
"control_model.input_blocks.11.0.in_layers.0.bias": "blocks.26.norm1.bias",
|
||||
"control_model.input_blocks.11.0.in_layers.2.weight": "blocks.26.conv1.weight",
|
||||
"control_model.input_blocks.11.0.in_layers.2.bias": "blocks.26.conv1.bias",
|
||||
"control_model.input_blocks.11.0.emb_layers.1.weight": "blocks.26.time_emb_proj.weight",
|
||||
"control_model.input_blocks.11.0.emb_layers.1.bias": "blocks.26.time_emb_proj.bias",
|
||||
"control_model.input_blocks.11.0.out_layers.0.weight": "blocks.26.norm2.weight",
|
||||
"control_model.input_blocks.11.0.out_layers.0.bias": "blocks.26.norm2.bias",
|
||||
"control_model.input_blocks.11.0.out_layers.3.weight": "blocks.26.conv2.weight",
|
||||
"control_model.input_blocks.11.0.out_layers.3.bias": "blocks.26.conv2.bias",
|
||||
"control_model.zero_convs.0.0.weight": "controlnet_blocks.0.weight",
|
||||
"control_model.zero_convs.0.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.1.0.weight": "controlnet_blocks.1.weight",
|
||||
"control_model.zero_convs.1.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.2.0.weight": "controlnet_blocks.2.weight",
|
||||
"control_model.zero_convs.2.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.3.0.weight": "controlnet_blocks.3.weight",
|
||||
"control_model.zero_convs.3.0.bias": "controlnet_blocks.0.bias",
|
||||
"control_model.zero_convs.4.0.weight": "controlnet_blocks.4.weight",
|
||||
"control_model.zero_convs.4.0.bias": "controlnet_blocks.4.bias",
|
||||
"control_model.zero_convs.5.0.weight": "controlnet_blocks.5.weight",
|
||||
"control_model.zero_convs.5.0.bias": "controlnet_blocks.4.bias",
|
||||
"control_model.zero_convs.6.0.weight": "controlnet_blocks.6.weight",
|
||||
"control_model.zero_convs.6.0.bias": "controlnet_blocks.4.bias",
|
||||
"control_model.zero_convs.7.0.weight": "controlnet_blocks.7.weight",
|
||||
"control_model.zero_convs.7.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.8.0.weight": "controlnet_blocks.8.weight",
|
||||
"control_model.zero_convs.8.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.9.0.weight": "controlnet_blocks.9.weight",
|
||||
"control_model.zero_convs.9.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.10.0.weight": "controlnet_blocks.10.weight",
|
||||
"control_model.zero_convs.10.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.zero_convs.11.0.weight": "controlnet_blocks.11.weight",
|
||||
"control_model.zero_convs.11.0.bias": "controlnet_blocks.7.bias",
|
||||
"control_model.input_hint_block.0.weight": "controlnet_conv_in.blocks.0.weight",
|
||||
"control_model.input_hint_block.0.bias": "controlnet_conv_in.blocks.0.bias",
|
||||
"control_model.input_hint_block.2.weight": "controlnet_conv_in.blocks.2.weight",
|
||||
"control_model.input_hint_block.2.bias": "controlnet_conv_in.blocks.2.bias",
|
||||
"control_model.input_hint_block.4.weight": "controlnet_conv_in.blocks.4.weight",
|
||||
"control_model.input_hint_block.4.bias": "controlnet_conv_in.blocks.4.bias",
|
||||
"control_model.input_hint_block.6.weight": "controlnet_conv_in.blocks.6.weight",
|
||||
"control_model.input_hint_block.6.bias": "controlnet_conv_in.blocks.6.bias",
|
||||
"control_model.input_hint_block.8.weight": "controlnet_conv_in.blocks.8.weight",
|
||||
"control_model.input_hint_block.8.bias": "controlnet_conv_in.blocks.8.bias",
|
||||
"control_model.input_hint_block.10.weight": "controlnet_conv_in.blocks.10.weight",
|
||||
"control_model.input_hint_block.10.bias": "controlnet_conv_in.blocks.10.bias",
|
||||
"control_model.input_hint_block.12.weight": "controlnet_conv_in.blocks.12.weight",
|
||||
"control_model.input_hint_block.12.bias": "controlnet_conv_in.blocks.12.bias",
|
||||
"control_model.input_hint_block.14.weight": "controlnet_conv_in.blocks.14.weight",
|
||||
"control_model.input_hint_block.14.bias": "controlnet_conv_in.blocks.14.bias",
|
||||
"control_model.middle_block.0.in_layers.0.weight": "blocks.28.norm1.weight",
|
||||
"control_model.middle_block.0.in_layers.0.bias": "blocks.28.norm1.bias",
|
||||
"control_model.middle_block.0.in_layers.2.weight": "blocks.28.conv1.weight",
|
||||
"control_model.middle_block.0.in_layers.2.bias": "blocks.28.conv1.bias",
|
||||
"control_model.middle_block.0.emb_layers.1.weight": "blocks.28.time_emb_proj.weight",
|
||||
"control_model.middle_block.0.emb_layers.1.bias": "blocks.28.time_emb_proj.bias",
|
||||
"control_model.middle_block.0.out_layers.0.weight": "blocks.28.norm2.weight",
|
||||
"control_model.middle_block.0.out_layers.0.bias": "blocks.28.norm2.bias",
|
||||
"control_model.middle_block.0.out_layers.3.weight": "blocks.28.conv2.weight",
|
||||
"control_model.middle_block.0.out_layers.3.bias": "blocks.28.conv2.bias",
|
||||
"control_model.middle_block.1.norm.weight": "blocks.29.norm.weight",
|
||||
"control_model.middle_block.1.norm.bias": "blocks.29.norm.bias",
|
||||
"control_model.middle_block.1.proj_in.weight": "blocks.29.proj_in.weight",
|
||||
"control_model.middle_block.1.proj_in.bias": "blocks.29.proj_in.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight": "blocks.29.transformer_blocks.0.attn1.to_q.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_k.weight": "blocks.29.transformer_blocks.0.attn1.to_k.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_v.weight": "blocks.29.transformer_blocks.0.attn1.to_v.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.weight": "blocks.29.transformer_blocks.0.attn1.to_out.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn1.to_out.0.bias": "blocks.29.transformer_blocks.0.attn1.to_out.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.weight": "blocks.29.transformer_blocks.0.act_fn.proj.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.0.proj.bias": "blocks.29.transformer_blocks.0.act_fn.proj.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.weight": "blocks.29.transformer_blocks.0.ff.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.ff.net.2.bias": "blocks.29.transformer_blocks.0.ff.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_q.weight": "blocks.29.transformer_blocks.0.attn2.to_q.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_k.weight": "blocks.29.transformer_blocks.0.attn2.to_k.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_v.weight": "blocks.29.transformer_blocks.0.attn2.to_v.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.weight": "blocks.29.transformer_blocks.0.attn2.to_out.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.attn2.to_out.0.bias": "blocks.29.transformer_blocks.0.attn2.to_out.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm1.weight": "blocks.29.transformer_blocks.0.norm1.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm1.bias": "blocks.29.transformer_blocks.0.norm1.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm2.weight": "blocks.29.transformer_blocks.0.norm2.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm2.bias": "blocks.29.transformer_blocks.0.norm2.bias",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm3.weight": "blocks.29.transformer_blocks.0.norm3.weight",
|
||||
"control_model.middle_block.1.transformer_blocks.0.norm3.bias": "blocks.29.transformer_blocks.0.norm3.bias",
|
||||
"control_model.middle_block.1.proj_out.weight": "blocks.29.proj_out.weight",
|
||||
"control_model.middle_block.1.proj_out.bias": "blocks.29.proj_out.bias",
|
||||
"control_model.middle_block.2.in_layers.0.weight": "blocks.30.norm1.weight",
|
||||
"control_model.middle_block.2.in_layers.0.bias": "blocks.30.norm1.bias",
|
||||
"control_model.middle_block.2.in_layers.2.weight": "blocks.30.conv1.weight",
|
||||
"control_model.middle_block.2.in_layers.2.bias": "blocks.30.conv1.bias",
|
||||
"control_model.middle_block.2.emb_layers.1.weight": "blocks.30.time_emb_proj.weight",
|
||||
"control_model.middle_block.2.emb_layers.1.bias": "blocks.30.time_emb_proj.bias",
|
||||
"control_model.middle_block.2.out_layers.0.weight": "blocks.30.norm2.weight",
|
||||
"control_model.middle_block.2.out_layers.0.bias": "blocks.30.norm2.bias",
|
||||
"control_model.middle_block.2.out_layers.3.weight": "blocks.30.conv2.weight",
|
||||
"control_model.middle_block.2.out_layers.3.bias": "blocks.30.conv2.bias",
|
||||
"control_model.middle_block_out.0.weight": "controlnet_blocks.12.weight",
|
||||
"control_model.middle_block_out.0.bias": "controlnet_blocks.7.bias",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,57 +0,0 @@
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from .sdxl_ipadapter import IpAdapterImageProjModel, IpAdapterModule, SDXLIpAdapterStateDictConverter
|
||||
from transformers import CLIPImageProcessor
|
||||
import torch
|
||||
|
||||
|
||||
class IpAdapterCLIPImageEmbedder(SVDImageEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.image_processor = CLIPImageProcessor()
|
||||
|
||||
def forward(self, image):
|
||||
pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
|
||||
pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
|
||||
return super().forward(pixel_values)
|
||||
|
||||
|
||||
class SDIpAdapter(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
shape_list = [(768, 320)] * 2 + [(768, 640)] * 2 + [(768, 1280)] * 5 + [(768, 640)] * 3 + [(768, 320)] * 3 + [(768, 1280)] * 1
|
||||
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
|
||||
self.image_proj = IpAdapterImageProjModel(cross_attention_dim=768, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
|
||||
self.set_full_adapter()
|
||||
|
||||
def set_full_adapter(self):
|
||||
block_ids = [1, 4, 9, 12, 17, 20, 40, 43, 46, 50, 53, 56, 60, 63, 66, 29]
|
||||
self.call_block_id = {(i, 0): j for j, i in enumerate(block_ids)}
|
||||
|
||||
def set_less_adapter(self):
|
||||
# IP-Adapter for SD v1.5 doesn't support this feature.
|
||||
self.set_full_adapter()
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
||||
ip_kv_dict = {}
|
||||
for (block_id, transformer_id) in self.call_block_id:
|
||||
ipadapter_id = self.call_block_id[(block_id, transformer_id)]
|
||||
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
||||
if block_id not in ip_kv_dict:
|
||||
ip_kv_dict[block_id] = {}
|
||||
ip_kv_dict[block_id][transformer_id] = {
|
||||
"ip_k": ip_k,
|
||||
"ip_v": ip_v,
|
||||
"scale": scale
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
class SDIpAdapterStateDictConverter(SDXLIpAdapterStateDictConverter):
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -1,199 +0,0 @@
|
||||
from .sd_unet import SDUNet, Attention, GEGLU
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class TemporalTransformerBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, dim, num_attention_heads, attention_head_dim, max_position_embeddings=32):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self-Attn
|
||||
self.pe1 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||
self.norm1 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.attn1 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||
|
||||
# 2. Cross-Attn
|
||||
self.pe2 = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, dim))
|
||||
self.norm2 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.attn2 = Attention(q_dim=dim, num_heads=num_attention_heads, head_dim=attention_head_dim, bias_out=True)
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = torch.nn.LayerNorm(dim, elementwise_affine=True)
|
||||
self.act_fn = GEGLU(dim, dim * 4)
|
||||
self.ff = torch.nn.Linear(dim * 4, dim)
|
||||
|
||||
|
||||
def forward(self, hidden_states, batch_size=1):
|
||||
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn1(norm_hidden_states + self.pe1[:, :norm_hidden_states.shape[1]])
|
||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
norm_hidden_states = rearrange(norm_hidden_states, "(b f) h c -> (b h) f c", b=batch_size)
|
||||
attn_output = self.attn2(norm_hidden_states + self.pe2[:, :norm_hidden_states.shape[1]])
|
||||
attn_output = rearrange(attn_output, "(b h) f c -> (b f) h c", b=batch_size)
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
ff_output = self.act_fn(norm_hidden_states)
|
||||
ff_output = self.ff(ff_output)
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class TemporalBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.proj_in = torch.nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
TemporalTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
self.proj_out = torch.nn.Linear(inner_dim, in_channels)
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, batch_size=1):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
batch_size=batch_size
|
||||
)
|
||||
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class SDMotionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.motion_modules = torch.nn.ModuleList([
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 160, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 80, 640, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
TemporalBlock(8, 40, 320, eps=1e-6),
|
||||
])
|
||||
self.call_block_id = {
|
||||
1: 0,
|
||||
4: 1,
|
||||
9: 2,
|
||||
12: 3,
|
||||
17: 4,
|
||||
20: 5,
|
||||
24: 6,
|
||||
26: 7,
|
||||
29: 8,
|
||||
32: 9,
|
||||
34: 10,
|
||||
36: 11,
|
||||
40: 12,
|
||||
43: 13,
|
||||
46: 14,
|
||||
50: 15,
|
||||
53: 16,
|
||||
56: 17,
|
||||
60: 18,
|
||||
63: 19,
|
||||
66: 20
|
||||
}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
class SDMotionModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"norm": "norm",
|
||||
"proj_in": "proj_in",
|
||||
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
||||
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
||||
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
||||
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
||||
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
||||
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
||||
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
||||
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
||||
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
||||
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
||||
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
||||
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
||||
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
||||
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
||||
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
||||
state_dict_ = {}
|
||||
last_prefix, module_id = "", -1
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
prefix_index = names.index("temporal_transformer") + 1
|
||||
prefix = ".".join(names[:prefix_index])
|
||||
if prefix != last_prefix:
|
||||
last_prefix = prefix
|
||||
module_id += 1
|
||||
middle_name = ".".join(names[prefix_index:-1])
|
||||
suffix = names[-1]
|
||||
if "pos_encoder" in names:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
||||
else:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||
state_dict_[rename] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,5 +1,96 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def low_version_attention(query, key, value, attn_bias=None):
|
||||
scale = 1 / query.shape[-1] ** 0.5
|
||||
query = query * scale
|
||||
attn = torch.matmul(query, key.transpose(-2, -1))
|
||||
if attn_bias is not None:
|
||||
attn = attn + attn_bias
|
||||
attn = attn.softmax(-1)
|
||||
return attn @ value
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.Linear(dim_inner, q_dim, bias=bias_out)
|
||||
|
||||
def interact_with_ipadapter(self, hidden_states, q, ip_k, ip_v, scale=1.0):
|
||||
batch_size = q.shape[0]
|
||||
ip_k = ip_k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_v = ip_v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
ip_hidden_states = torch.nn.functional.scaled_dot_product_attention(q, ip_k, ip_v)
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
batch_size = encoder_hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(encoder_hidden_states)
|
||||
v = self.to_v(encoder_hidden_states)
|
||||
|
||||
q = rearrange(q, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
k = rearrange(k, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
v = rearrange(v, "b f (n d) -> (b n) f d", n=self.num_heads)
|
||||
|
||||
if attn_mask is not None:
|
||||
hidden_states = low_version_attention(q, k, v, attn_bias=attn_mask)
|
||||
else:
|
||||
import xformers.ops as xops
|
||||
hidden_states = xops.memory_efficient_attention(q, k, v)
|
||||
hidden_states = rearrange(hidden_states, "(b n) f d -> b f (n d)", n=self.num_heads)
|
||||
|
||||
hidden_states = hidden_states.to(q.dtype)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class CLIPEncoderLayer(torch.nn.Module):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,336 +0,0 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import TileWorker
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
Attention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class SDVAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.18215
|
||||
self.post_quant_conv = torch.nn.Conv2d(4, 4, kernel_size=1)
|
||||
self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock2D
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
sample = sample / self.scaling_factor
|
||||
hidden_states = self.post_quant_conv(sample)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDVAEDecoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock', 'UpSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'ResnetBlock'
|
||||
]
|
||||
|
||||
# Rename each parameter
|
||||
local_rename_dict = {
|
||||
"post_quant_conv": "post_quant_conv",
|
||||
"decoder.conv_in": "conv_in",
|
||||
"decoder.mid_block.attentions.0.group_norm": "blocks.1.norm",
|
||||
"decoder.mid_block.attentions.0.to_q": "blocks.1.transformer_blocks.0.to_q",
|
||||
"decoder.mid_block.attentions.0.to_k": "blocks.1.transformer_blocks.0.to_k",
|
||||
"decoder.mid_block.attentions.0.to_v": "blocks.1.transformer_blocks.0.to_v",
|
||||
"decoder.mid_block.attentions.0.to_out.0": "blocks.1.transformer_blocks.0.to_out",
|
||||
"decoder.mid_block.resnets.0.norm1": "blocks.0.norm1",
|
||||
"decoder.mid_block.resnets.0.conv1": "blocks.0.conv1",
|
||||
"decoder.mid_block.resnets.0.norm2": "blocks.0.norm2",
|
||||
"decoder.mid_block.resnets.0.conv2": "blocks.0.conv2",
|
||||
"decoder.mid_block.resnets.1.norm1": "blocks.2.norm1",
|
||||
"decoder.mid_block.resnets.1.conv1": "blocks.2.conv1",
|
||||
"decoder.mid_block.resnets.1.norm2": "blocks.2.norm2",
|
||||
"decoder.mid_block.resnets.1.conv2": "blocks.2.conv2",
|
||||
"decoder.conv_norm_out": "conv_norm_out",
|
||||
"decoder.conv_out": "conv_out",
|
||||
}
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": 2, "DownSampler": 2, "UpSampler": 2}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
name_prefix = ".".join(names[:-1])
|
||||
if name_prefix in local_rename_dict:
|
||||
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
|
||||
elif name.startswith("decoder.up_blocks"):
|
||||
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
names = ["blocks", str(block_id[block_type])] + names[5:]
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"first_stage_model.decoder.conv_in.bias": "conv_in.bias",
|
||||
"first_stage_model.decoder.conv_in.weight": "conv_in.weight",
|
||||
"first_stage_model.decoder.conv_out.bias": "conv_out.bias",
|
||||
"first_stage_model.decoder.conv_out.weight": "conv_out.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.k.bias": "blocks.1.transformer_blocks.0.to_k.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.k.weight": "blocks.1.transformer_blocks.0.to_k.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.1.norm.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.1.norm.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.1.transformer_blocks.0.to_out.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.1.transformer_blocks.0.to_out.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.q.bias": "blocks.1.transformer_blocks.0.to_q.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.q.weight": "blocks.1.transformer_blocks.0.to_q.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.v.bias": "blocks.1.transformer_blocks.0.to_v.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.v.weight": "blocks.1.transformer_blocks.0.to_v.weight",
|
||||
"first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||
"first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||
"first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||
"first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||
"first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||
"first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||
"first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||
"first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||
"first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.2.conv1.bias",
|
||||
"first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.2.conv1.weight",
|
||||
"first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.2.conv2.bias",
|
||||
"first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.2.conv2.weight",
|
||||
"first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.2.norm1.bias",
|
||||
"first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.2.norm1.weight",
|
||||
"first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.2.norm2.bias",
|
||||
"first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.2.norm2.weight",
|
||||
"first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.15.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.15.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.15.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.15.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.15.conv_shortcut.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.15.conv_shortcut.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.15.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.15.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.15.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.15.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.16.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.16.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.16.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.16.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.16.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.16.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.16.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.16.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.17.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.17.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.17.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.17.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.17.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.17.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.17.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.17.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.11.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.11.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.11.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.11.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.11.conv_shortcut.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.11.conv_shortcut.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.11.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.11.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.11.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.11.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.12.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.12.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.12.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.12.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.12.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.12.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.12.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.12.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.14.conv.bias",
|
||||
"first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.14.conv.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.7.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.7.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.7.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.7.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.7.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.7.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.7.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.7.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.8.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.8.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.8.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.8.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.8.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.8.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.8.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.8.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.10.conv.bias",
|
||||
"first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.10.conv.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.5.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.5.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.5.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.5.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.5.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.5.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.5.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.5.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.6.conv.bias",
|
||||
"first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.6.conv.weight",
|
||||
"first_stage_model.post_quant_conv.bias": "post_quant_conv.bias",
|
||||
"first_stage_model.post_quant_conv.weight": "post_quant_conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,282 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import ResnetBlock, DownSampler
|
||||
from .sd_vae_decoder import VAEAttentionBlock
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SDVAEEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.18215
|
||||
self.quant_conv = torch.nn.Conv2d(8, 8, kernel_size=1)
|
||||
self.conv_in = torch.nn.Conv2d(3, 128, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
DownSampler(128, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(128, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
DownSampler(256, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(256, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
DownSampler(512, padding=0, extra_padding=True),
|
||||
# DownEncoderBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
# UNetMidBlock2D
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=512, num_groups=32, eps=1e-6)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(512, 8, kernel_size=3, padding=1)
|
||||
|
||||
def tiled_forward(self, sample, tile_size=64, tile_stride=32):
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: self.forward(x),
|
||||
sample,
|
||||
tile_size,
|
||||
tile_stride,
|
||||
tile_device=sample.device,
|
||||
tile_dtype=sample.dtype
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, sample, tiled=False, tile_size=64, tile_stride=32, **kwargs):
|
||||
original_dtype = sample.dtype
|
||||
sample = sample.to(dtype=next(iter(self.parameters())).dtype)
|
||||
# For VAE Decoder, we do not need to apply the tiler on each layer.
|
||||
if tiled:
|
||||
return self.tiled_forward(sample, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
# 1. pre-process
|
||||
hidden_states = self.conv_in(sample)
|
||||
time_emb = None
|
||||
text_emb = None
|
||||
res_stack = None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = self.quant_conv(hidden_states)
|
||||
hidden_states = hidden_states[:, :4]
|
||||
hidden_states *= self.scaling_factor
|
||||
hidden_states = hidden_states.to(original_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def encode_video(self, sample, batch_size=8):
|
||||
B = sample.shape[0]
|
||||
hidden_states = []
|
||||
|
||||
for i in range(0, sample.shape[2], batch_size):
|
||||
|
||||
j = min(i + batch_size, sample.shape[2])
|
||||
sample_batch = rearrange(sample[:,:,i:j], "B C T H W -> (B T) C H W")
|
||||
|
||||
hidden_states_batch = self(sample_batch)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "(B T) C H W -> B C T H W", B=B)
|
||||
|
||||
hidden_states.append(hidden_states_batch)
|
||||
|
||||
hidden_states = torch.concat(hidden_states, dim=2)
|
||||
return hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDVAEEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
||||
'ResnetBlock', 'ResnetBlock', 'DownSampler',
|
||||
'ResnetBlock', 'ResnetBlock',
|
||||
'ResnetBlock', 'VAEAttentionBlock', 'ResnetBlock'
|
||||
]
|
||||
|
||||
# Rename each parameter
|
||||
local_rename_dict = {
|
||||
"quant_conv": "quant_conv",
|
||||
"encoder.conv_in": "conv_in",
|
||||
"encoder.mid_block.attentions.0.group_norm": "blocks.12.norm",
|
||||
"encoder.mid_block.attentions.0.to_q": "blocks.12.transformer_blocks.0.to_q",
|
||||
"encoder.mid_block.attentions.0.to_k": "blocks.12.transformer_blocks.0.to_k",
|
||||
"encoder.mid_block.attentions.0.to_v": "blocks.12.transformer_blocks.0.to_v",
|
||||
"encoder.mid_block.attentions.0.to_out.0": "blocks.12.transformer_blocks.0.to_out",
|
||||
"encoder.mid_block.resnets.0.norm1": "blocks.11.norm1",
|
||||
"encoder.mid_block.resnets.0.conv1": "blocks.11.conv1",
|
||||
"encoder.mid_block.resnets.0.norm2": "blocks.11.norm2",
|
||||
"encoder.mid_block.resnets.0.conv2": "blocks.11.conv2",
|
||||
"encoder.mid_block.resnets.1.norm1": "blocks.13.norm1",
|
||||
"encoder.mid_block.resnets.1.conv1": "blocks.13.conv1",
|
||||
"encoder.mid_block.resnets.1.norm2": "blocks.13.norm2",
|
||||
"encoder.mid_block.resnets.1.conv2": "blocks.13.conv2",
|
||||
"encoder.conv_norm_out": "conv_norm_out",
|
||||
"encoder.conv_out": "conv_out",
|
||||
}
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
name_prefix = ".".join(names[:-1])
|
||||
if name_prefix in local_rename_dict:
|
||||
rename_dict[name] = local_rename_dict[name_prefix] + "." + names[-1]
|
||||
elif name.startswith("encoder.down_blocks"):
|
||||
block_type = {"resnets": "ResnetBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[3]]
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:5])
|
||||
names = ["blocks", str(block_id[block_type])] + names[5:]
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"first_stage_model.encoder.conv_in.bias": "conv_in.bias",
|
||||
"first_stage_model.encoder.conv_in.weight": "conv_in.weight",
|
||||
"first_stage_model.encoder.conv_out.bias": "conv_out.bias",
|
||||
"first_stage_model.encoder.conv_out.weight": "conv_out.weight",
|
||||
"first_stage_model.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||
"first_stage_model.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||
"first_stage_model.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||
"first_stage_model.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||
"first_stage_model.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||
"first_stage_model.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||
"first_stage_model.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||
"first_stage_model.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||
"first_stage_model.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||
"first_stage_model.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||
"first_stage_model.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||
"first_stage_model.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||
"first_stage_model.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||
"first_stage_model.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||
"first_stage_model.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||
"first_stage_model.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||
"first_stage_model.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||
"first_stage_model.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||
"first_stage_model.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"first_stage_model.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"first_stage_model.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"first_stage_model.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"first_stage_model.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||
"first_stage_model.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||
"first_stage_model.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"first_stage_model.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"first_stage_model.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"first_stage_model.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"first_stage_model.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"first_stage_model.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"first_stage_model.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"first_stage_model.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"first_stage_model.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"first_stage_model.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"first_stage_model.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"first_stage_model.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"first_stage_model.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||
"first_stage_model.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||
"first_stage_model.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||
"first_stage_model.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||
"first_stage_model.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||
"first_stage_model.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||
"first_stage_model.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||
"first_stage_model.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||
"first_stage_model.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||
"first_stage_model.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||
"first_stage_model.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||
"first_stage_model.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||
"first_stage_model.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"first_stage_model.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"first_stage_model.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"first_stage_model.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"first_stage_model.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"first_stage_model.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"first_stage_model.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"first_stage_model.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"first_stage_model.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||
"first_stage_model.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||
"first_stage_model.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||
"first_stage_model.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||
"first_stage_model.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||
"first_stage_model.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||
"first_stage_model.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||
"first_stage_model.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||
"first_stage_model.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||
"first_stage_model.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||
"first_stage_model.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||
"first_stage_model.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||
"first_stage_model.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||
"first_stage_model.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||
"first_stage_model.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||
"first_stage_model.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||
"first_stage_model.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||
"first_stage_model.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||
"first_stage_model.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||
"first_stage_model.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||
"first_stage_model.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||
"first_stage_model.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||
"first_stage_model.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||
"first_stage_model.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||
"first_stage_model.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||
"first_stage_model.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||
"first_stage_model.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||
"first_stage_model.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||
"first_stage_model.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||
"first_stage_model.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||
"first_stage_model.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||
"first_stage_model.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||
"first_stage_model.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||
"first_stage_model.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||
"first_stage_model.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||
"first_stage_model.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||
"first_stage_model.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"first_stage_model.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"first_stage_model.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"first_stage_model.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"first_stage_model.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"first_stage_model.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"first_stage_model.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"first_stage_model.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"first_stage_model.encoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"first_stage_model.encoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"first_stage_model.quant_conv.bias": "quant_conv.bias",
|
||||
"first_stage_model.quant_conv.weight": "quant_conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,318 +0,0 @@
|
||||
import torch
|
||||
from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, DownSampler
|
||||
from .sdxl_unet import SDXLUNet
|
||||
from .tiler import TileWorker
|
||||
from .sd_controlnet import ControlNetConditioningLayer
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
|
||||
class QuickGELU(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
|
||||
class ResidualAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = torch.nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = torch.nn.LayerNorm(d_model)
|
||||
self.mlp = torch.nn.Sequential(OrderedDict([
|
||||
("c_fc", torch.nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", torch.nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.ln_2 = torch.nn.LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class SDXLControlNetUnion(torch.nn.Module):
|
||||
def __init__(self, global_pool=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(320, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.control_type_proj = Timesteps(256)
|
||||
self.control_type_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(256 * 8, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
|
||||
self.controlnet_conv_in = ControlNetConditioningLayer(channels=(3, 16, 32, 96, 256, 320))
|
||||
self.controlnet_transformer = ResidualAttentionBlock(320, 8)
|
||||
self.task_embedding = torch.nn.Parameter(torch.randn(8, 320))
|
||||
self.spatial_ch_projs = torch.nn.Linear(320, 320)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
ResnetBlock(320, 320, 1280),
|
||||
PushBlock(),
|
||||
ResnetBlock(320, 320, 1280),
|
||||
PushBlock(),
|
||||
DownSampler(320),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(320, 640, 1280),
|
||||
AttentionBlock(10, 64, 640, 2, 2048),
|
||||
PushBlock(),
|
||||
ResnetBlock(640, 640, 1280),
|
||||
AttentionBlock(10, 64, 640, 2, 2048),
|
||||
PushBlock(),
|
||||
DownSampler(640),
|
||||
PushBlock(),
|
||||
# CrossAttnDownBlock2D
|
||||
ResnetBlock(640, 1280, 1280),
|
||||
AttentionBlock(20, 64, 1280, 10, 2048),
|
||||
PushBlock(),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(20, 64, 1280, 10, 2048),
|
||||
PushBlock(),
|
||||
# UNetMidBlock2DCrossAttn
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
AttentionBlock(20, 64, 1280, 10, 2048),
|
||||
ResnetBlock(1280, 1280, 1280),
|
||||
PushBlock()
|
||||
])
|
||||
|
||||
self.controlnet_blocks = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(320, 320, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(640, 640, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
torch.nn.Conv2d(1280, 1280, kernel_size=(1, 1)),
|
||||
])
|
||||
|
||||
self.global_pool = global_pool
|
||||
|
||||
# 0 -- openpose
|
||||
# 1 -- depth
|
||||
# 2 -- hed/pidi/scribble/ted
|
||||
# 3 -- canny/lineart/anime_lineart/mlsd
|
||||
# 4 -- normal
|
||||
# 5 -- segment
|
||||
# 6 -- tile
|
||||
# 7 -- repaint
|
||||
self.task_id = {
|
||||
"openpose": 0,
|
||||
"depth": 1,
|
||||
"softedge": 2,
|
||||
"canny": 3,
|
||||
"lineart": 3,
|
||||
"lineart_anime": 3,
|
||||
"tile": 6,
|
||||
"inpaint": 7
|
||||
}
|
||||
|
||||
|
||||
def fuse_condition_to_input(self, hidden_states, task_id, conditioning):
|
||||
controlnet_cond = self.controlnet_conv_in(conditioning)
|
||||
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
|
||||
feat_seq = feat_seq + self.task_embedding[task_id]
|
||||
x = torch.stack([feat_seq, torch.mean(hidden_states, dim=(2, 3))], dim=1)
|
||||
x = self.controlnet_transformer(x)
|
||||
|
||||
alpha = self.spatial_ch_projs(x[:,0]).unsqueeze(-1).unsqueeze(-1)
|
||||
controlnet_cond_fuser = controlnet_cond + alpha
|
||||
|
||||
hidden_states = hidden_states + controlnet_cond_fuser
|
||||
return hidden_states
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states,
|
||||
conditioning, processor_id, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
unet:SDXLUNet=None,
|
||||
**kwargs
|
||||
):
|
||||
task_id = self.task_id[processor_id]
|
||||
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep).to(sample.dtype)
|
||||
t_emb = self.time_embedding(t_emb)
|
||||
|
||||
time_embeds = self.add_time_proj(add_time_id)
|
||||
time_embeds = time_embeds.reshape((add_text_embeds.shape[0], -1))
|
||||
add_embeds = torch.concat([add_text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(sample.dtype)
|
||||
if unet is not None and unet.is_kolors:
|
||||
add_embeds = unet.add_time_embedding(add_embeds)
|
||||
else:
|
||||
add_embeds = self.add_time_embedding(add_embeds)
|
||||
|
||||
control_type = torch.zeros((sample.shape[0], 8), dtype=sample.dtype, device=sample.device)
|
||||
control_type[:, task_id] = 1
|
||||
control_embeds = self.control_type_proj(control_type.flatten())
|
||||
control_embeds = control_embeds.reshape((sample.shape[0], -1))
|
||||
control_embeds = control_embeds.to(sample.dtype)
|
||||
control_embeds = self.control_type_embedding(control_embeds)
|
||||
time_emb = t_emb + add_embeds + control_embeds
|
||||
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
|
||||
text_emb = encoder_hidden_states
|
||||
if unet is not None and unet.is_kolors:
|
||||
text_emb = unet.text_intermediate_proj(text_emb)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
if tiled and not isinstance(block, PushBlock):
|
||||
_, _, inter_height, _ = hidden_states.shape
|
||||
resize_scale = inter_height / height
|
||||
hidden_states = TileWorker().tiled_forward(
|
||||
lambda x: block(x, time_emb, text_emb, res_stack)[0],
|
||||
hidden_states,
|
||||
int(tile_size * resize_scale),
|
||||
int(tile_stride * resize_scale),
|
||||
tile_device=hidden_states.device,
|
||||
tile_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
hidden_states, _, _, _ = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 4. ControlNet blocks
|
||||
controlnet_res_stack = [block(res) for block, res in zip(self.controlnet_blocks, res_stack)]
|
||||
|
||||
# pool
|
||||
if self.global_pool:
|
||||
controlnet_res_stack = [res.mean(dim=(2, 3), keepdim=True) for res in controlnet_res_stack]
|
||||
|
||||
return controlnet_res_stack
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLControlNetUnionStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class SDXLControlNetUnionStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
# architecture
|
||||
block_types = [
|
||||
"ResnetBlock", "PushBlock", "ResnetBlock", "PushBlock", "DownSampler", "PushBlock",
|
||||
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock", "DownSampler", "PushBlock",
|
||||
"ResnetBlock", "AttentionBlock", "PushBlock", "ResnetBlock", "AttentionBlock", "PushBlock",
|
||||
"ResnetBlock", "AttentionBlock", "ResnetBlock", "PushBlock"
|
||||
]
|
||||
|
||||
# controlnet_rename_dict
|
||||
controlnet_rename_dict = {
|
||||
"controlnet_cond_embedding.conv_in.weight": "controlnet_conv_in.blocks.0.weight",
|
||||
"controlnet_cond_embedding.conv_in.bias": "controlnet_conv_in.blocks.0.bias",
|
||||
"controlnet_cond_embedding.blocks.0.weight": "controlnet_conv_in.blocks.2.weight",
|
||||
"controlnet_cond_embedding.blocks.0.bias": "controlnet_conv_in.blocks.2.bias",
|
||||
"controlnet_cond_embedding.blocks.1.weight": "controlnet_conv_in.blocks.4.weight",
|
||||
"controlnet_cond_embedding.blocks.1.bias": "controlnet_conv_in.blocks.4.bias",
|
||||
"controlnet_cond_embedding.blocks.2.weight": "controlnet_conv_in.blocks.6.weight",
|
||||
"controlnet_cond_embedding.blocks.2.bias": "controlnet_conv_in.blocks.6.bias",
|
||||
"controlnet_cond_embedding.blocks.3.weight": "controlnet_conv_in.blocks.8.weight",
|
||||
"controlnet_cond_embedding.blocks.3.bias": "controlnet_conv_in.blocks.8.bias",
|
||||
"controlnet_cond_embedding.blocks.4.weight": "controlnet_conv_in.blocks.10.weight",
|
||||
"controlnet_cond_embedding.blocks.4.bias": "controlnet_conv_in.blocks.10.bias",
|
||||
"controlnet_cond_embedding.blocks.5.weight": "controlnet_conv_in.blocks.12.weight",
|
||||
"controlnet_cond_embedding.blocks.5.bias": "controlnet_conv_in.blocks.12.bias",
|
||||
"controlnet_cond_embedding.conv_out.weight": "controlnet_conv_in.blocks.14.weight",
|
||||
"controlnet_cond_embedding.conv_out.bias": "controlnet_conv_in.blocks.14.bias",
|
||||
"control_add_embedding.linear_1.weight": "control_type_embedding.0.weight",
|
||||
"control_add_embedding.linear_1.bias": "control_type_embedding.0.bias",
|
||||
"control_add_embedding.linear_2.weight": "control_type_embedding.2.weight",
|
||||
"control_add_embedding.linear_2.bias": "control_type_embedding.2.bias",
|
||||
}
|
||||
|
||||
# Rename each parameter
|
||||
name_list = sorted([name for name in state_dict])
|
||||
rename_dict = {}
|
||||
block_id = {"ResnetBlock": -1, "AttentionBlock": -1, "DownSampler": -1, "UpSampler": -1}
|
||||
last_block_type_with_id = {"ResnetBlock": "", "AttentionBlock": "", "DownSampler": "", "UpSampler": ""}
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out", "task_embedding", "spatial_ch_projs"]:
|
||||
pass
|
||||
elif name in controlnet_rename_dict:
|
||||
names = controlnet_rename_dict[name].split(".")
|
||||
elif names[0] == "controlnet_down_blocks":
|
||||
names[0] = "controlnet_blocks"
|
||||
elif names[0] == "controlnet_mid_block":
|
||||
names = ["controlnet_blocks", "9", names[-1]]
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
names[1] = {"linear_1": "0", "linear_2": "2"}[names[1]]
|
||||
elif names[0] == "control_add_embedding":
|
||||
names[0] = "control_type_embedding"
|
||||
elif names[0] == "transformer_layes":
|
||||
names[0] = "controlnet_transformer"
|
||||
names.pop(1)
|
||||
elif names[0] in ["down_blocks", "mid_block", "up_blocks"]:
|
||||
if names[0] == "mid_block":
|
||||
names.insert(1, "0")
|
||||
block_type = {"resnets": "ResnetBlock", "attentions": "AttentionBlock", "downsamplers": "DownSampler", "upsamplers": "UpSampler"}[names[2]]
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
if block_type_with_id != last_block_type_with_id[block_type]:
|
||||
block_id[block_type] += 1
|
||||
last_block_type_with_id[block_type] = block_type_with_id
|
||||
while block_id[block_type] < len(block_types) and block_types[block_id[block_type]] != block_type:
|
||||
block_id[block_type] += 1
|
||||
block_type_with_id = ".".join(names[:4])
|
||||
names = ["blocks", str(block_id[block_type])] + names[4:]
|
||||
if "ff" in names:
|
||||
ff_index = names.index("ff")
|
||||
component = ".".join(names[ff_index:ff_index+3])
|
||||
component = {"ff.net.0": "act_fn", "ff.net.2": "ff"}[component]
|
||||
names = names[:ff_index] + [component] + names[ff_index+3:]
|
||||
if "to_out" in names:
|
||||
names.pop(names.index("to_out") + 1)
|
||||
else:
|
||||
print(name, state_dict[name].shape)
|
||||
# raise ValueError(f"Unknown parameters: {name}")
|
||||
rename_dict[name] = ".".join(names)
|
||||
|
||||
# Convert state_dict
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name not in rename_dict:
|
||||
continue
|
||||
if ".proj_in." in name or ".proj_out." in name:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,122 +0,0 @@
|
||||
from .svd_image_encoder import SVDImageEncoder
|
||||
from transformers import CLIPImageProcessor
|
||||
import torch
|
||||
|
||||
|
||||
class IpAdapterXLCLIPImageEmbedder(SVDImageEncoder):
|
||||
def __init__(self):
|
||||
super().__init__(embed_dim=1664, encoder_intermediate_size=8192, projection_dim=1280, num_encoder_layers=48, num_heads=16, head_dim=104)
|
||||
self.image_processor = CLIPImageProcessor()
|
||||
|
||||
def forward(self, image):
|
||||
pixel_values = self.image_processor(images=image, return_tensors="pt").pixel_values
|
||||
pixel_values = pixel_values.to(device=self.embeddings.class_embedding.device, dtype=self.embeddings.class_embedding.dtype)
|
||||
return super().forward(pixel_values)
|
||||
|
||||
|
||||
class IpAdapterImageProjModel(torch.nn.Module):
|
||||
def __init__(self, cross_attention_dim=2048, clip_embeddings_dim=1280, clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.clip_extra_context_tokens = clip_extra_context_tokens
|
||||
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
|
||||
self.norm = torch.nn.LayerNorm(cross_attention_dim)
|
||||
|
||||
def forward(self, image_embeds):
|
||||
clip_extra_context_tokens = self.proj(image_embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)
|
||||
clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
|
||||
return clip_extra_context_tokens
|
||||
|
||||
|
||||
class IpAdapterModule(torch.nn.Module):
|
||||
def __init__(self, input_dim, output_dim):
|
||||
super().__init__()
|
||||
self.to_k_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
self.to_v_ip = torch.nn.Linear(input_dim, output_dim, bias=False)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
ip_k = self.to_k_ip(hidden_states)
|
||||
ip_v = self.to_v_ip(hidden_states)
|
||||
return ip_k, ip_v
|
||||
|
||||
|
||||
class SDXLIpAdapter(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
shape_list = [(2048, 640)] * 4 + [(2048, 1280)] * 50 + [(2048, 640)] * 6 + [(2048, 1280)] * 10
|
||||
self.ipadapter_modules = torch.nn.ModuleList([IpAdapterModule(*shape) for shape in shape_list])
|
||||
self.image_proj = IpAdapterImageProjModel()
|
||||
self.set_full_adapter()
|
||||
|
||||
def set_full_adapter(self):
|
||||
map_list = sum([
|
||||
[(7, i) for i in range(2)],
|
||||
[(10, i) for i in range(2)],
|
||||
[(15, i) for i in range(10)],
|
||||
[(18, i) for i in range(10)],
|
||||
[(25, i) for i in range(10)],
|
||||
[(28, i) for i in range(10)],
|
||||
[(31, i) for i in range(10)],
|
||||
[(35, i) for i in range(2)],
|
||||
[(38, i) for i in range(2)],
|
||||
[(41, i) for i in range(2)],
|
||||
[(21, i) for i in range(10)],
|
||||
], [])
|
||||
self.call_block_id = {i: j for j, i in enumerate(map_list)}
|
||||
|
||||
def set_less_adapter(self):
|
||||
map_list = sum([
|
||||
[(7, i) for i in range(2)],
|
||||
[(10, i) for i in range(2)],
|
||||
[(15, i) for i in range(10)],
|
||||
[(18, i) for i in range(10)],
|
||||
[(25, i) for i in range(10)],
|
||||
[(28, i) for i in range(10)],
|
||||
[(31, i) for i in range(10)],
|
||||
[(35, i) for i in range(2)],
|
||||
[(38, i) for i in range(2)],
|
||||
[(41, i) for i in range(2)],
|
||||
[(21, i) for i in range(10)],
|
||||
], [])
|
||||
self.call_block_id = {i: j for j, i in enumerate(map_list) if j>=34 and j<44}
|
||||
|
||||
def forward(self, hidden_states, scale=1.0):
|
||||
hidden_states = self.image_proj(hidden_states)
|
||||
hidden_states = hidden_states.view(1, -1, hidden_states.shape[-1])
|
||||
ip_kv_dict = {}
|
||||
for (block_id, transformer_id) in self.call_block_id:
|
||||
ipadapter_id = self.call_block_id[(block_id, transformer_id)]
|
||||
ip_k, ip_v = self.ipadapter_modules[ipadapter_id](hidden_states)
|
||||
if block_id not in ip_kv_dict:
|
||||
ip_kv_dict[block_id] = {}
|
||||
ip_kv_dict[block_id][transformer_id] = {
|
||||
"ip_k": ip_k,
|
||||
"ip_v": ip_v,
|
||||
"scale": scale
|
||||
}
|
||||
return ip_kv_dict
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLIpAdapterStateDictConverter()
|
||||
|
||||
|
||||
class SDXLIpAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict["ip_adapter"]:
|
||||
names = name.split(".")
|
||||
layer_id = str(int(names[0]) // 2)
|
||||
name_ = ".".join(["ipadapter_modules"] + [layer_id] + names[1:])
|
||||
state_dict_[name_] = state_dict["ip_adapter"][name]
|
||||
for name in state_dict["image_proj"]:
|
||||
name_ = "image_proj." + name
|
||||
state_dict_[name_] = state_dict["image_proj"][name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
from .sd_motion import TemporalBlock
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class SDXLMotionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.motion_modules = torch.nn.ModuleList([
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
TemporalBlock(8, 1280//8, 1280, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
TemporalBlock(8, 640//8, 640, eps=1e-6),
|
||||
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
TemporalBlock(8, 320//8, 320, eps=1e-6),
|
||||
])
|
||||
self.call_block_id = {
|
||||
0: 0,
|
||||
2: 1,
|
||||
7: 2,
|
||||
10: 3,
|
||||
15: 4,
|
||||
18: 5,
|
||||
25: 6,
|
||||
28: 7,
|
||||
31: 8,
|
||||
35: 9,
|
||||
38: 10,
|
||||
41: 11,
|
||||
44: 12,
|
||||
46: 13,
|
||||
48: 14,
|
||||
}
|
||||
|
||||
def forward(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDMotionModelStateDictConverter()
|
||||
|
||||
|
||||
class SDMotionModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"norm": "norm",
|
||||
"proj_in": "proj_in",
|
||||
"transformer_blocks.0.attention_blocks.0.to_q": "transformer_blocks.0.attn1.to_q",
|
||||
"transformer_blocks.0.attention_blocks.0.to_k": "transformer_blocks.0.attn1.to_k",
|
||||
"transformer_blocks.0.attention_blocks.0.to_v": "transformer_blocks.0.attn1.to_v",
|
||||
"transformer_blocks.0.attention_blocks.0.to_out.0": "transformer_blocks.0.attn1.to_out",
|
||||
"transformer_blocks.0.attention_blocks.0.pos_encoder": "transformer_blocks.0.pe1",
|
||||
"transformer_blocks.0.attention_blocks.1.to_q": "transformer_blocks.0.attn2.to_q",
|
||||
"transformer_blocks.0.attention_blocks.1.to_k": "transformer_blocks.0.attn2.to_k",
|
||||
"transformer_blocks.0.attention_blocks.1.to_v": "transformer_blocks.0.attn2.to_v",
|
||||
"transformer_blocks.0.attention_blocks.1.to_out.0": "transformer_blocks.0.attn2.to_out",
|
||||
"transformer_blocks.0.attention_blocks.1.pos_encoder": "transformer_blocks.0.pe2",
|
||||
"transformer_blocks.0.norms.0": "transformer_blocks.0.norm1",
|
||||
"transformer_blocks.0.norms.1": "transformer_blocks.0.norm2",
|
||||
"transformer_blocks.0.ff.net.0.proj": "transformer_blocks.0.act_fn.proj",
|
||||
"transformer_blocks.0.ff.net.2": "transformer_blocks.0.ff",
|
||||
"transformer_blocks.0.ff_norm": "transformer_blocks.0.norm3",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
name_list = sorted([i for i in state_dict if i.startswith("down_blocks.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("mid_block.")])
|
||||
name_list += sorted([i for i in state_dict if i.startswith("up_blocks.")])
|
||||
state_dict_ = {}
|
||||
last_prefix, module_id = "", -1
|
||||
for name in name_list:
|
||||
names = name.split(".")
|
||||
prefix_index = names.index("temporal_transformer") + 1
|
||||
prefix = ".".join(names[:prefix_index])
|
||||
if prefix != last_prefix:
|
||||
last_prefix = prefix
|
||||
module_id += 1
|
||||
middle_name = ".".join(names[prefix_index:-1])
|
||||
suffix = names[-1]
|
||||
if "pos_encoder" in names:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name]])
|
||||
else:
|
||||
rename = ".".join(["motion_modules", str(module_id), rename_dict[middle_name], suffix])
|
||||
state_dict_[rename] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -1,759 +0,0 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class SDXLTextEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=11, encoder_intermediate_size=3072):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# The text encoder is different to that in Stable Diffusion 1.x.
|
||||
# It does not include final_layer_norm.
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=1):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
break
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLTextEncoder2(torch.nn.Module):
|
||||
def __init__(self, embed_dim=1280, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=32, encoder_intermediate_size=5120):
|
||||
super().__init__()
|
||||
|
||||
# token_embedding
|
||||
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))
|
||||
|
||||
# encoders
|
||||
self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=20, head_dim=64, use_quick_gelu=False) for _ in range(num_encoder_layers)])
|
||||
|
||||
# attn_mask
|
||||
self.attn_mask = self.attention_mask(max_position_embeddings)
|
||||
|
||||
# final_layer_norm
|
||||
self.final_layer_norm = torch.nn.LayerNorm(embed_dim)
|
||||
|
||||
# text_projection
|
||||
self.text_projection = torch.nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
|
||||
def attention_mask(self, length):
|
||||
mask = torch.empty(length, length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1)
|
||||
return mask
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
hidden_states = embeds
|
||||
embeds = self.final_layer_norm(embeds)
|
||||
pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
|
||||
pooled_embeds = self.text_projection(pooled_embeds)
|
||||
return pooled_embeds, hidden_states
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLTextEncoder2StateDictConverter()
|
||||
|
||||
|
||||
class SDXLTextEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"text_model.final_layer_norm.bias": "final_layer_norm.bias"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("text_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"conditioner.embedders.0.transformer.text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.bias": "encoders.0.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm1.weight": "encoders.0.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.bias": "encoders.0.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.layer_norm2.weight": "encoders.0.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.bias": "encoders.0.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc1.weight": "encoders.0.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.bias": "encoders.0.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.mlp.fc2.weight": "encoders.0.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.bias": "encoders.0.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.k_proj.weight": "encoders.0.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.bias": "encoders.0.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.q_proj.weight": "encoders.0.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.bias": "encoders.0.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.0.self_attn.v_proj.weight": "encoders.0.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.bias": "encoders.1.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm1.weight": "encoders.1.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.bias": "encoders.1.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.layer_norm2.weight": "encoders.1.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.bias": "encoders.1.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc1.weight": "encoders.1.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.bias": "encoders.1.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.mlp.fc2.weight": "encoders.1.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.bias": "encoders.1.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.k_proj.weight": "encoders.1.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.bias": "encoders.1.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.q_proj.weight": "encoders.1.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.bias": "encoders.1.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.1.self_attn.v_proj.weight": "encoders.1.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.bias": "encoders.10.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm1.weight": "encoders.10.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.bias": "encoders.10.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.layer_norm2.weight": "encoders.10.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.bias": "encoders.10.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc1.weight": "encoders.10.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.bias": "encoders.10.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.mlp.fc2.weight": "encoders.10.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.bias": "encoders.10.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.k_proj.weight": "encoders.10.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.bias": "encoders.10.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.q_proj.weight": "encoders.10.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.bias": "encoders.10.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.10.self_attn.v_proj.weight": "encoders.10.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.bias": "encoders.2.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm1.weight": "encoders.2.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.bias": "encoders.2.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.layer_norm2.weight": "encoders.2.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.bias": "encoders.2.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc1.weight": "encoders.2.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.bias": "encoders.2.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.mlp.fc2.weight": "encoders.2.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.bias": "encoders.2.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.k_proj.weight": "encoders.2.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.bias": "encoders.2.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.q_proj.weight": "encoders.2.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.bias": "encoders.2.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.2.self_attn.v_proj.weight": "encoders.2.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.bias": "encoders.3.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm1.weight": "encoders.3.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.bias": "encoders.3.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.layer_norm2.weight": "encoders.3.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.bias": "encoders.3.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc1.weight": "encoders.3.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.bias": "encoders.3.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.mlp.fc2.weight": "encoders.3.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.bias": "encoders.3.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.k_proj.weight": "encoders.3.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.bias": "encoders.3.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.q_proj.weight": "encoders.3.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.bias": "encoders.3.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.3.self_attn.v_proj.weight": "encoders.3.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.bias": "encoders.4.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm1.weight": "encoders.4.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.bias": "encoders.4.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.layer_norm2.weight": "encoders.4.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.bias": "encoders.4.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc1.weight": "encoders.4.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.bias": "encoders.4.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.mlp.fc2.weight": "encoders.4.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.bias": "encoders.4.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.k_proj.weight": "encoders.4.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.bias": "encoders.4.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.q_proj.weight": "encoders.4.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.bias": "encoders.4.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.4.self_attn.v_proj.weight": "encoders.4.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.bias": "encoders.5.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm1.weight": "encoders.5.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.bias": "encoders.5.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.layer_norm2.weight": "encoders.5.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.bias": "encoders.5.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc1.weight": "encoders.5.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.bias": "encoders.5.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.mlp.fc2.weight": "encoders.5.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.bias": "encoders.5.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.k_proj.weight": "encoders.5.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.bias": "encoders.5.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.q_proj.weight": "encoders.5.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.bias": "encoders.5.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.5.self_attn.v_proj.weight": "encoders.5.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.bias": "encoders.6.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm1.weight": "encoders.6.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.bias": "encoders.6.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.layer_norm2.weight": "encoders.6.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.bias": "encoders.6.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc1.weight": "encoders.6.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.bias": "encoders.6.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.mlp.fc2.weight": "encoders.6.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.bias": "encoders.6.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.k_proj.weight": "encoders.6.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.bias": "encoders.6.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.q_proj.weight": "encoders.6.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.bias": "encoders.6.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.6.self_attn.v_proj.weight": "encoders.6.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.bias": "encoders.7.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm1.weight": "encoders.7.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.bias": "encoders.7.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.layer_norm2.weight": "encoders.7.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.bias": "encoders.7.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc1.weight": "encoders.7.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.bias": "encoders.7.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.mlp.fc2.weight": "encoders.7.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.bias": "encoders.7.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.k_proj.weight": "encoders.7.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.bias": "encoders.7.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.q_proj.weight": "encoders.7.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.bias": "encoders.7.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.7.self_attn.v_proj.weight": "encoders.7.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.bias": "encoders.8.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm1.weight": "encoders.8.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.bias": "encoders.8.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.layer_norm2.weight": "encoders.8.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.bias": "encoders.8.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc1.weight": "encoders.8.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.bias": "encoders.8.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.mlp.fc2.weight": "encoders.8.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.bias": "encoders.8.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.k_proj.weight": "encoders.8.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.bias": "encoders.8.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.q_proj.weight": "encoders.8.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.bias": "encoders.8.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.8.self_attn.v_proj.weight": "encoders.8.attn.to_v.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.bias": "encoders.9.layer_norm1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm1.weight": "encoders.9.layer_norm1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.bias": "encoders.9.layer_norm2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.layer_norm2.weight": "encoders.9.layer_norm2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.bias": "encoders.9.fc1.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc1.weight": "encoders.9.fc1.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.bias": "encoders.9.fc2.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.mlp.fc2.weight": "encoders.9.fc2.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.bias": "encoders.9.attn.to_k.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.k_proj.weight": "encoders.9.attn.to_k.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.bias": "encoders.9.attn.to_q.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.q_proj.weight": "encoders.9.attn.to_q.weight",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.bias": "encoders.9.attn.to_v.bias",
|
||||
"conditioner.embedders.0.transformer.text_model.encoder.layers.9.self_attn.v_proj.weight": "encoders.9.attn.to_v.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "conditioner.embedders.0.transformer.text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
class SDXLTextEncoder2StateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"text_model.embeddings.token_embedding.weight": "token_embedding.weight",
|
||||
"text_model.embeddings.position_embedding.weight": "position_embeds",
|
||||
"text_model.final_layer_norm.weight": "final_layer_norm.weight",
|
||||
"text_model.final_layer_norm.bias": "final_layer_norm.bias",
|
||||
"text_projection.weight": "text_projection.weight"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "text_model.embeddings.position_embedding.weight":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("text_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"conditioner.embedders.1.model.ln_final.bias": "final_layer_norm.bias",
|
||||
"conditioner.embedders.1.model.ln_final.weight": "final_layer_norm.weight",
|
||||
"conditioner.embedders.1.model.positional_embedding": "position_embeds",
|
||||
"conditioner.embedders.1.model.token_embedding.weight": "token_embedding.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
||||
"conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
||||
"conditioner.embedders.1.model.text_projection": "text_projection.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "conditioner.embedders.1.model.positional_embedding":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
elif name == "conditioner.embedders.1.model.text_projection":
|
||||
param = param.T
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
length = param.shape[0] // 3
|
||||
for i, rename in enumerate(rename_dict[name]):
|
||||
state_dict_[rename] = param[i*length: i*length+length]
|
||||
return state_dict_
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +0,0 @@
|
||||
from .sd_vae_decoder import SDVAEDecoder, SDVAEDecoderStateDictConverter
|
||||
|
||||
|
||||
class SDXLVAEDecoder(SDVAEDecoder):
|
||||
def __init__(self, upcast_to_float32=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEDecoderStateDictConverter(SDVAEDecoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict = super().from_diffusers(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = super().from_civitai(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
@@ -1,24 +0,0 @@
|
||||
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SDXLVAEEncoder(SDVAEEncoder):
|
||||
def __init__(self, upcast_to_float32=True):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SDXLVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SDXLVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict = super().from_diffusers(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = super().from_civitai(state_dict)
|
||||
return state_dict, {"upcast_to_float32": True}
|
||||
@@ -661,23 +661,3 @@ class Qwen2Connector(torch.nn.Module):
|
||||
global_out=self.global_proj_out(x_mean)
|
||||
encoder_hidden_states = self.S(x,t,mask)
|
||||
return encoder_hidden_states,global_out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return Qwen2ConnectorStateDictConverter()
|
||||
|
||||
|
||||
class Qwen2ConnectorStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("connector."):
|
||||
name_ = name[len("connector."):]
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
@@ -1,25 +1,15 @@
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
from .qwen_image_text_encoder import QwenImageTextEncoder
|
||||
|
||||
|
||||
class Qwen25VL_7b_Embedder(torch.nn.Module):
|
||||
def __init__(self, model_path, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||
super(Qwen25VL_7b_Embedder, self).__init__()
|
||||
class Step1xEditEmbedder(torch.nn.Module):
|
||||
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=dtype,
|
||||
).to(torch.cuda.current_device())
|
||||
|
||||
self.model.requires_grad_(False)
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_path, min_pixels=256 * 28 * 28, max_pixels=324 * 28 * 28
|
||||
)
|
||||
|
||||
Qwen25VL_7b_PREFIX = '''Given a user prompt, generate an "Enhanced prompt" that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:
|
||||
- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.
|
||||
- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.\n
|
||||
@@ -30,11 +20,56 @@ Please generate only the enhanced description for the prompt below and avoid inc
|
||||
User Prompt:'''
|
||||
|
||||
self.prefix = Qwen25VL_7b_PREFIX
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
return Qwen25VL_7b_Embedder(path, dtype=torch_dtype, device=device)
|
||||
def model_forward(
|
||||
self,
|
||||
model: QwenImageTextEncoder,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||
rope_deltas: Optional[torch.LongTensor] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||
**kwargs,
|
||||
):
|
||||
output_attentions = output_attentions if output_attentions is not None else model.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else model.config.output_hidden_states
|
||||
)
|
||||
|
||||
outputs = model.model(
|
||||
input_ids=input_ids,
|
||||
pixel_values=pixel_values,
|
||||
pixel_values_videos=pixel_values_videos,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=True,
|
||||
cache_position=cache_position,
|
||||
**kwargs,
|
||||
)
|
||||
return outputs.hidden_states
|
||||
|
||||
def forward(self, caption, ref_images):
|
||||
text_list = caption
|
||||
embs = torch.zeros(
|
||||
@@ -44,22 +79,12 @@ User Prompt:'''
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
hidden_states = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
self.model.config.hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
masks = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
dtype=torch.long,
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
input_ids_list = []
|
||||
attention_mask_list = []
|
||||
emb_list = []
|
||||
|
||||
def split_string(s):
|
||||
s = s.replace("“", '"').replace("”", '"').replace("'", '''"''') # use english quotes
|
||||
@@ -145,7 +170,8 @@ User Prompt:'''
|
||||
.to("cuda")
|
||||
)
|
||||
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
||||
outputs = self.model(
|
||||
outputs = self.model_forward(
|
||||
self.model,
|
||||
input_ids=inputs.input_ids,
|
||||
attention_mask=inputs.attention_mask,
|
||||
pixel_values=inputs.pixel_values.to("cuda"),
|
||||
@@ -153,7 +179,7 @@ User Prompt:'''
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
emb = outputs["hidden_states"][-1]
|
||||
emb = outputs[-1]
|
||||
|
||||
embs[idx, : min(self.max_length, emb.shape[1] - 217)] = emb[0, 217:][
|
||||
: self.max_length
|
||||
@@ -165,4 +191,4 @@ User Prompt:'''
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
return embs, masks
|
||||
return embs, masks
|
||||
@@ -1,940 +0,0 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
from typing import Dict, Optional, Tuple, Union, List
|
||||
import torch, math
|
||||
from torch import nn
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
elementwise_affine=True,
|
||||
eps: float = 1e-6,
|
||||
device=None,
|
||||
dtype=None,
|
||||
):
|
||||
"""
|
||||
Initialize the RMSNorm normalization layer.
|
||||
|
||||
Args:
|
||||
dim (int): The dimension of the input tensor.
|
||||
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
|
||||
|
||||
Attributes:
|
||||
eps (float): A small value added to the denominator for numerical stability.
|
||||
weight (nn.Parameter): Learnable scaling parameter.
|
||||
|
||||
"""
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
if elementwise_affine:
|
||||
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
|
||||
|
||||
def _norm(self, x):
|
||||
"""
|
||||
Apply the RMSNorm normalization to the input tensor.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The normalized tensor.
|
||||
|
||||
"""
|
||||
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass through the RMSNorm layer.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tensor after applying RMSNorm.
|
||||
|
||||
"""
|
||||
output = self._norm(x.float()).type_as(x)
|
||||
if hasattr(self, "weight"):
|
||||
output = output * self.weight
|
||||
return output
|
||||
|
||||
|
||||
ACTIVATION_FUNCTIONS = {
|
||||
"swish": nn.SiLU(),
|
||||
"silu": nn.SiLU(),
|
||||
"mish": nn.Mish(),
|
||||
"gelu": nn.GELU(),
|
||||
"relu": nn.ReLU(),
|
||||
}
|
||||
|
||||
|
||||
def get_activation(act_fn: str) -> nn.Module:
|
||||
"""Helper function to get activation function from string.
|
||||
|
||||
Args:
|
||||
act_fn (str): Name of activation function.
|
||||
|
||||
Returns:
|
||||
nn.Module: Activation function.
|
||||
"""
|
||||
|
||||
act_fn = act_fn.lower()
|
||||
if act_fn in ACTIVATION_FUNCTIONS:
|
||||
return ACTIVATION_FUNCTIONS[act_fn]
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation function: {act_fn}")
|
||||
|
||||
|
||||
def get_timestep_embedding(
|
||||
timesteps: torch.Tensor,
|
||||
embedding_dim: int,
|
||||
flip_sin_to_cos: bool = False,
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
||||
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
emb = scale * emb
|
||||
|
||||
# concat sine and cosine embeddings
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
|
||||
# flip sine and cosine embeddings
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
|
||||
# zero pad
|
||||
if embedding_dim % 2 == 1:
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
class Timesteps(nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
timesteps,
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
time_embed_dim: int,
|
||||
act_fn: str = "silu",
|
||||
out_dim: int = None,
|
||||
post_act_fn: Optional[str] = None,
|
||||
cond_proj_dim=None,
|
||||
sample_proj_bias=True
|
||||
):
|
||||
super().__init__()
|
||||
linear_cls = nn.Linear
|
||||
|
||||
self.linear_1 = linear_cls(
|
||||
in_channels,
|
||||
time_embed_dim,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if cond_proj_dim is not None:
|
||||
self.cond_proj = linear_cls(
|
||||
cond_proj_dim,
|
||||
in_channels,
|
||||
bias=False,
|
||||
)
|
||||
else:
|
||||
self.cond_proj = None
|
||||
|
||||
self.act = get_activation(act_fn)
|
||||
|
||||
if out_dim is not None:
|
||||
time_embed_dim_out = out_dim
|
||||
else:
|
||||
time_embed_dim_out = time_embed_dim
|
||||
|
||||
self.linear_2 = linear_cls(
|
||||
time_embed_dim,
|
||||
time_embed_dim_out,
|
||||
bias=sample_proj_bias,
|
||||
)
|
||||
|
||||
if post_act_fn is None:
|
||||
self.post_act = None
|
||||
else:
|
||||
self.post_act = get_activation(post_act_fn)
|
||||
|
||||
def forward(self, sample, condition=None):
|
||||
if condition is not None:
|
||||
sample = sample + self.cond_proj(condition)
|
||||
sample = self.linear_1(sample)
|
||||
|
||||
if self.act is not None:
|
||||
sample = self.act(sample)
|
||||
|
||||
sample = self.linear_2(sample)
|
||||
|
||||
if self.post_act is not None:
|
||||
sample = self.post_act(sample)
|
||||
return sample
|
||||
|
||||
|
||||
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.outdim = size_emb_dim
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
if self.use_additional_conditions:
|
||||
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
|
||||
self.nframe_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.fps_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
|
||||
def forward(self, timestep, resolution=None, nframe=None, fps=None):
|
||||
hidden_dtype = timestep.dtype
|
||||
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
if self.use_additional_conditions:
|
||||
batch_size = timestep.shape[0]
|
||||
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
|
||||
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
|
||||
nframe_emb = self.additional_condition_proj(nframe.flatten()).to(hidden_dtype)
|
||||
nframe_emb = self.nframe_embedder(nframe_emb).reshape(batch_size, -1)
|
||||
conditioning = timesteps_emb + resolution_emb + nframe_emb
|
||||
|
||||
if fps is not None:
|
||||
fps_emb = self.additional_condition_proj(fps.flatten()).to(hidden_dtype)
|
||||
fps_emb = self.fps_embedder(fps_emb).reshape(batch_size, -1)
|
||||
conditioning = conditioning + fps_emb
|
||||
else:
|
||||
conditioning = timesteps_emb
|
||||
|
||||
return conditioning
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
|
||||
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
|
||||
"""
|
||||
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, time_step_rescale=1000):
|
||||
super().__init__()
|
||||
|
||||
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim, size_emb_dim=embedding_dim // 2, use_additional_conditions=use_additional_conditions
|
||||
)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
|
||||
self.time_step_rescale = time_step_rescale ## timestep usually in [0, 1], we rescale it to [0,1000] for stability
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
embedded_timestep = self.emb(timestep*self.time_step_rescale, **added_cond_kwargs)
|
||||
|
||||
out = self.linear(self.silu(embedded_timestep))
|
||||
|
||||
return out, embedded_timestep
|
||||
|
||||
|
||||
class PixArtAlphaTextProjection(nn.Module):
|
||||
"""
|
||||
Projects caption embeddings. Also handles dropout for classifier-free guidance.
|
||||
|
||||
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, hidden_size):
|
||||
super().__init__()
|
||||
self.linear_1 = nn.Linear(
|
||||
in_features,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
self.act_1 = nn.GELU(approximate="tanh")
|
||||
self.linear_2 = nn.Linear(
|
||||
hidden_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
)
|
||||
|
||||
def forward(self, caption):
|
||||
hidden_states = self.linear_1(caption)
|
||||
hidden_states = self.act_1(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def attn_processor(self, attn_type):
|
||||
if attn_type == 'torch':
|
||||
return self.torch_attn_func
|
||||
elif attn_type == 'parallel':
|
||||
return self.parallel_attn_func
|
||||
else:
|
||||
raise Exception('Not supported attention type...')
|
||||
|
||||
def torch_attn_func(
|
||||
self,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
attn_mask=None,
|
||||
causal=False,
|
||||
drop_rate=0.0,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
if attn_mask is not None and attn_mask.dtype != torch.bool:
|
||||
attn_mask = attn_mask.to(q.dtype)
|
||||
|
||||
if attn_mask is not None and attn_mask.ndim == 3: ## no head
|
||||
n_heads = q.shape[2]
|
||||
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
|
||||
|
||||
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v))
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.to(q.device)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
|
||||
)
|
||||
x = rearrange(x, 'b h s d -> b s h d')
|
||||
return x
|
||||
|
||||
|
||||
class RoPE1D:
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
self.base = freq
|
||||
self.F0 = F0
|
||||
self.scaling_factor = scaling_factor
|
||||
self.cache = {}
|
||||
|
||||
def get_cos_sin(self, D, seq_len, device, dtype):
|
||||
if (D, seq_len, device, dtype) not in self.cache:
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
|
||||
t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
|
||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
||||
cos = freqs.cos() # (Seq, Dim)
|
||||
sin = freqs.sin()
|
||||
self.cache[D, seq_len, device, dtype] = (cos, sin)
|
||||
return self.cache[D, seq_len, device, dtype]
|
||||
|
||||
@staticmethod
|
||||
def rotate_half(x):
|
||||
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def apply_rope1d(self, tokens, pos1d, cos, sin):
|
||||
assert pos1d.ndim == 2
|
||||
cos = torch.nn.functional.embedding(pos1d, cos)[:, :, None, :]
|
||||
sin = torch.nn.functional.embedding(pos1d, sin)[:, :, None, :]
|
||||
return (tokens * cos) + (self.rotate_half(tokens) * sin)
|
||||
|
||||
def __call__(self, tokens, positions):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* positions: batch_size x ntokens (t position of each token)
|
||||
output:
|
||||
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
D = tokens.size(3)
|
||||
assert positions.ndim == 2 # Batch, Seq
|
||||
cos, sin = self.get_cos_sin(D, int(positions.max()) + 1, tokens.device, tokens.dtype)
|
||||
tokens = self.apply_rope1d(tokens, positions, cos, sin)
|
||||
return tokens
|
||||
|
||||
|
||||
class RoPE3D(RoPE1D):
|
||||
def __init__(self, freq=1e4, F0=1.0, scaling_factor=1.0):
|
||||
super(RoPE3D, self).__init__(freq, F0, scaling_factor)
|
||||
self.position_cache = {}
|
||||
|
||||
def get_mesh_3d(self, rope_positions, bsz):
|
||||
f, h, w = rope_positions
|
||||
|
||||
if f"{f}-{h}-{w}" not in self.position_cache:
|
||||
x = torch.arange(f, device='cpu')
|
||||
y = torch.arange(h, device='cpu')
|
||||
z = torch.arange(w, device='cpu')
|
||||
self.position_cache[f"{f}-{h}-{w}"] = torch.cartesian_prod(x, y, z).view(1, f*h*w, 3).expand(bsz, -1, 3)
|
||||
return self.position_cache[f"{f}-{h}-{w}"]
|
||||
|
||||
def __call__(self, tokens, rope_positions, ch_split, parallel=False):
|
||||
"""
|
||||
input:
|
||||
* tokens: batch_size x ntokens x nheads x dim
|
||||
* rope_positions: list of (f, h, w)
|
||||
output:
|
||||
* tokens after applying RoPE2D (batch_size x ntokens x nheads x dim)
|
||||
"""
|
||||
assert sum(ch_split) == tokens.size(-1);
|
||||
|
||||
mesh_grid = self.get_mesh_3d(rope_positions, bsz=tokens.shape[0])
|
||||
out = []
|
||||
for i, (D, x) in enumerate(zip(ch_split, torch.split(tokens, ch_split, dim=-1))):
|
||||
cos, sin = self.get_cos_sin(D, int(mesh_grid.max()) + 1, tokens.device, tokens.dtype)
|
||||
|
||||
if parallel:
|
||||
pass
|
||||
else:
|
||||
mesh = mesh_grid[:, :, i].clone()
|
||||
x = self.apply_rope1d(x, mesh.to(tokens.device), cos, sin)
|
||||
out.append(x)
|
||||
|
||||
tokens = torch.cat(out, dim=-1)
|
||||
return tokens
|
||||
|
||||
|
||||
class SelfAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim*3, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_rope = with_rope
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
if self.with_rope:
|
||||
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0)
|
||||
self.rope_ch_split = [64, 32, 32]
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
self.parallel = attn_type=='parallel'
|
||||
|
||||
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True):
|
||||
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
cu_seqlens=None,
|
||||
max_seqlen=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None
|
||||
):
|
||||
xqkv = self.wqkv(x)
|
||||
xqkv = xqkv.view(*x.shape[:-1], self.n_heads, 3*self.head_dim)
|
||||
|
||||
xq, xk, xv = torch.split(xqkv, [self.head_dim]*3, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
if self.with_rope:
|
||||
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CrossAttention(Attention):
|
||||
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.n_heads = hidden_dim // head_dim
|
||||
|
||||
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
self.wkv = nn.Linear(hidden_dim, hidden_dim*2, bias=bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias)
|
||||
|
||||
self.with_qk_norm = with_qk_norm
|
||||
if self.with_qk_norm:
|
||||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True)
|
||||
|
||||
self.core_attention = self.attn_processor(attn_type=attn_type)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attn_mask=None
|
||||
):
|
||||
xq = self.wq(x)
|
||||
xq = xq.view(*xq.shape[:-1], self.n_heads, self.head_dim)
|
||||
|
||||
xkv = self.wkv(encoder_hidden_states)
|
||||
xkv = xkv.view(*xkv.shape[:-1], self.n_heads, 2*self.head_dim)
|
||||
|
||||
xk, xv = torch.split(xkv, [self.head_dim]*2, dim=-1) ## seq_len, n, dim
|
||||
|
||||
if self.with_qk_norm:
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
output = self.core_attention(
|
||||
xq,
|
||||
xk,
|
||||
xv,
|
||||
attn_mask=attn_mask
|
||||
)
|
||||
|
||||
output = rearrange(output, 'b s h d -> b s (h d)')
|
||||
output = self.wo(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class GELU(nn.Module):
|
||||
r"""
|
||||
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
||||
|
||||
Parameters:
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out, bias=bias)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
|
||||
return torch.nn.functional.gelu(gate, approximate=self.approximate)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.gelu(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
inner_dim: Optional[int] = None,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim*mult if inner_dim is None else inner_dim
|
||||
dim_out = dim if dim_out is None else dim_out
|
||||
self.net = nn.ModuleList([
|
||||
GELU(dim, inner_dim, approximate="tanh", bias=bias),
|
||||
nn.Identity(),
|
||||
nn.Linear(inner_dim, dim_out, bias=bias)
|
||||
])
|
||||
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
||||
for module in self.net:
|
||||
hidden_states = module(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
def modulate(x, scale, shift):
|
||||
x = x * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
def gate(x, gate):
|
||||
x = gate * x
|
||||
return x
|
||||
|
||||
|
||||
class StepVideoTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
attention_bias (:
|
||||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
|
||||
only_cross_attention (`bool`, *optional*):
|
||||
Whether to use only cross-attention layers. In this case two cross attention layers are used.
|
||||
double_self_attention (`bool`, *optional*):
|
||||
Whether to use two self-attention layers. In this case no cross attention layers are used.
|
||||
upcast_attention (`bool`, *optional*):
|
||||
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
|
||||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use learnable elementwise affine parameters for normalization.
|
||||
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
|
||||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
|
||||
final_dropout (`bool` *optional*, defaults to False):
|
||||
Whether to apply a final dropout after the last feed-forward layer.
|
||||
attention_type (`str`, *optional*, defaults to `"default"`):
|
||||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
|
||||
positional_embeddings (`str`, *optional*, defaults to `None`):
|
||||
The type of positional embeddings to apply to.
|
||||
num_positional_embeddings (`int`, *optional*, defaults to `None`):
|
||||
The maximum number of positional embeddings to apply.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
attention_head_dim: int,
|
||||
norm_eps: float = 1e-5,
|
||||
ff_inner_dim: Optional[int] = None,
|
||||
ff_bias: bool = False,
|
||||
attention_type: str = 'parallel'
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.norm1 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type=attention_type)
|
||||
|
||||
self.norm2 = nn.LayerNorm(dim, eps=norm_eps)
|
||||
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch')
|
||||
|
||||
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) /dim**0.5)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
attn_mask = None,
|
||||
rope_positions: list = None,
|
||||
) -> torch.Tensor:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
||||
torch.clone(chunk) for chunk in (self.scale_shift_table[None].to(dtype=q.dtype, device=q.device) + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1)
|
||||
)
|
||||
|
||||
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa)
|
||||
|
||||
attn_q = self.attn1(
|
||||
scale_shift_q,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
q = gate(attn_q, gate_msa) + q
|
||||
|
||||
attn_q = self.attn2(
|
||||
q,
|
||||
kv,
|
||||
attn_mask
|
||||
)
|
||||
|
||||
q = attn_q + q
|
||||
|
||||
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp)
|
||||
|
||||
ff_output = self.ff(scale_shift_q)
|
||||
|
||||
q = gate(ff_output, gate_mlp) + q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=64,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent).to(latent.dtype)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
if self.layer_norm:
|
||||
latent = self.norm(latent)
|
||||
|
||||
return latent
|
||||
|
||||
|
||||
class StepVideoModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 48,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 64,
|
||||
out_channels: Optional[int] = 64,
|
||||
num_layers: int = 48,
|
||||
dropout: float = 0.0,
|
||||
patch_size: int = 1,
|
||||
norm_type: str = "ada_norm_single",
|
||||
norm_elementwise_affine: bool = False,
|
||||
norm_eps: float = 1e-6,
|
||||
use_additional_conditions: Optional[bool] = False,
|
||||
caption_channels: Optional[Union[int, List, Tuple]] = [6144, 1024],
|
||||
attention_type: Optional[str] = "torch",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Set some common variables used across the board.
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
|
||||
self.use_additional_conditions = use_additional_conditions
|
||||
|
||||
self.pos_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=self.inner_dim,
|
||||
)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
StepVideoTransformerBlock(
|
||||
dim=self.inner_dim,
|
||||
attention_head_dim=attention_head_dim,
|
||||
attention_type=attention_type
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 3. Output blocks.
|
||||
self.norm_out = nn.LayerNorm(self.inner_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
|
||||
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.adaln_single = AdaLayerNormSingle(
|
||||
self.inner_dim, use_additional_conditions=self.use_additional_conditions
|
||||
)
|
||||
|
||||
if isinstance(caption_channels, int):
|
||||
caption_channel = caption_channels
|
||||
else:
|
||||
caption_channel, clip_channel = caption_channels
|
||||
self.clip_projection = nn.Linear(clip_channel, self.inner_dim)
|
||||
|
||||
self.caption_norm = nn.LayerNorm(caption_channel, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
self.caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channel, hidden_size=self.inner_dim
|
||||
)
|
||||
|
||||
self.parallel = attention_type=='parallel'
|
||||
|
||||
def patchfy(self, hidden_states):
|
||||
hidden_states = rearrange(hidden_states, 'b f c h w -> (b f) c h w')
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def prepare_attn_mask(self, encoder_attention_mask, encoder_hidden_states, q_seqlen):
|
||||
kv_seqlens = encoder_attention_mask.sum(dim=1).int()
|
||||
mask = torch.zeros([len(kv_seqlens), q_seqlen, max(kv_seqlens)], dtype=torch.bool, device=encoder_attention_mask.device)
|
||||
encoder_hidden_states = encoder_hidden_states[:,: max(kv_seqlens)]
|
||||
for i, kv_len in enumerate(kv_seqlens):
|
||||
mask[i, :, :kv_len] = 1
|
||||
return encoder_hidden_states, mask
|
||||
|
||||
|
||||
def block_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
rope_positions=None,
|
||||
attn_mask=None,
|
||||
parallel=True
|
||||
):
|
||||
for block in tqdm(self.transformer_blocks, desc="Transformer blocks"):
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
attn_mask=attn_mask,
|
||||
rope_positions=rope_positions
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states_2: Optional[torch.Tensor] = None,
|
||||
timestep: Optional[torch.LongTensor] = None,
|
||||
added_cond_kwargs: Dict[str, torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
fps: torch.Tensor=None,
|
||||
return_dict: bool = False,
|
||||
):
|
||||
assert hidden_states.ndim==5; "hidden_states's shape should be (bsz, f, ch, h ,w)"
|
||||
|
||||
bsz, frame, _, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
hidden_states = self.patchfy(hidden_states)
|
||||
len_frame = hidden_states.shape[1]
|
||||
|
||||
if self.use_additional_conditions:
|
||||
added_cond_kwargs = {
|
||||
"resolution": torch.tensor([(height, width)]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"nframe": torch.tensor([frame]*bsz, device=hidden_states.device, dtype=hidden_states.dtype),
|
||||
"fps": fps
|
||||
}
|
||||
else:
|
||||
added_cond_kwargs = {}
|
||||
|
||||
timestep, embedded_timestep = self.adaln_single(
|
||||
timestep, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
|
||||
encoder_hidden_states = self.caption_projection(self.caption_norm(encoder_hidden_states))
|
||||
|
||||
if encoder_hidden_states_2 is not None and hasattr(self, 'clip_projection'):
|
||||
clip_embedding = self.clip_projection(encoder_hidden_states_2)
|
||||
encoder_hidden_states = torch.cat([clip_embedding, encoder_hidden_states], dim=1)
|
||||
|
||||
hidden_states = rearrange(hidden_states, '(b f) l d-> b (f l) d', b=bsz, f=frame, l=len_frame).contiguous()
|
||||
encoder_hidden_states, attn_mask = self.prepare_attn_mask(encoder_attention_mask, encoder_hidden_states, q_seqlen=frame*len_frame)
|
||||
|
||||
hidden_states = self.block_forward(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
rope_positions=[frame, height, width],
|
||||
attn_mask=attn_mask,
|
||||
parallel=self.parallel
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'b (f l) d -> (b f) l d', b=bsz, f=frame, l=len_frame)
|
||||
|
||||
embedded_timestep = repeat(embedded_timestep, 'b d -> (b f) d', f=frame).contiguous()
|
||||
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=embedded_timestep.dtype, device=embedded_timestep.device) + embedded_timestep[:, None]).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
# Modulation
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
|
||||
hidden_states = rearrange(hidden_states, 'n h w p q c -> n c h p w q')
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
output = rearrange(output, '(b f) c h w -> b f c h w', f=frame)
|
||||
|
||||
if return_dict:
|
||||
return {'x': output}
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return StepVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
class StepVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
|
||||
@@ -1,553 +0,0 @@
|
||||
# Copyright 2025 StepFun Inc. All Rights Reserved.
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
# ==============================================================================
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .stepvideo_dit import RMSNorm
|
||||
from safetensors.torch import load_file
|
||||
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from einops import rearrange
|
||||
import json
|
||||
from typing import List
|
||||
from functools import wraps
|
||||
import warnings
|
||||
|
||||
|
||||
|
||||
class EmptyInitOnDevice(torch.overrides.TorchFunctionMode):
|
||||
def __init__(self, device=None):
|
||||
self.device = device
|
||||
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if getattr(func, '__module__', None) == 'torch.nn.init':
|
||||
if 'tensor' in kwargs:
|
||||
return kwargs['tensor']
|
||||
else:
|
||||
return args[0]
|
||||
if self.device is not None and func in torch.utils._device._device_constructors() and kwargs.get('device') is None:
|
||||
kwargs['device'] = self.device
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def with_empty_init(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with EmptyInitOnDevice('cpu'):
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
|
||||
class LLaMaEmbedding(nn.Module):
|
||||
"""Language model embeddings.
|
||||
|
||||
Arguments:
|
||||
hidden_size: hidden size
|
||||
vocab_size: vocabulary size
|
||||
max_sequence_length: maximum size of sequence. This
|
||||
is used for positional embedding
|
||||
embedding_dropout_prob: dropout probability for embeddings
|
||||
init_method: weight initialization method
|
||||
num_tokentypes: size of the token-type embeddings. 0 value
|
||||
will ignore this embedding
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
cfg,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = cfg.hidden_size
|
||||
self.params_dtype = cfg.params_dtype
|
||||
self.fp32_residual_connection = cfg.fp32_residual_connection
|
||||
self.embedding_weights_in_fp32 = cfg.embedding_weights_in_fp32
|
||||
self.word_embeddings = torch.nn.Embedding(
|
||||
cfg.padded_vocab_size, self.hidden_size,
|
||||
)
|
||||
self.embedding_dropout = torch.nn.Dropout(cfg.hidden_dropout)
|
||||
|
||||
def forward(self, input_ids):
|
||||
# Embeddings.
|
||||
if self.embedding_weights_in_fp32:
|
||||
self.word_embeddings = self.word_embeddings.to(torch.float32)
|
||||
embeddings = self.word_embeddings(input_ids)
|
||||
if self.embedding_weights_in_fp32:
|
||||
embeddings = embeddings.to(self.params_dtype)
|
||||
self.word_embeddings = self.word_embeddings.to(self.params_dtype)
|
||||
|
||||
# Data format change to avoid explicit transposes : [b s h] --> [s b h].
|
||||
embeddings = embeddings.transpose(0, 1).contiguous()
|
||||
|
||||
# If the input flag for fp32 residual connection is set, convert for float.
|
||||
if self.fp32_residual_connection:
|
||||
embeddings = embeddings.float()
|
||||
|
||||
# Dropout.
|
||||
embeddings = self.embedding_dropout(embeddings)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
|
||||
class StepChatTokenizer:
|
||||
"""Step Chat Tokenizer"""
|
||||
|
||||
def __init__(
|
||||
self, model_file, name="StepChatTokenizer",
|
||||
bot_token="<|BOT|>", # Begin of Turn
|
||||
eot_token="<|EOT|>", # End of Turn
|
||||
call_start_token="<|CALL_START|>", # Call Start
|
||||
call_end_token="<|CALL_END|>", # Call End
|
||||
think_start_token="<|THINK_START|>", # Think Start
|
||||
think_end_token="<|THINK_END|>", # Think End
|
||||
mask_start_token="<|MASK_1e69f|>", # Mask start
|
||||
mask_end_token="<|UNMASK_1e69f|>", # Mask end
|
||||
):
|
||||
import sentencepiece
|
||||
|
||||
self._tokenizer = sentencepiece.SentencePieceProcessor(model_file=model_file)
|
||||
|
||||
self._vocab = {}
|
||||
self._inv_vocab = {}
|
||||
|
||||
self._special_tokens = {}
|
||||
self._inv_special_tokens = {}
|
||||
|
||||
self._t5_tokens = []
|
||||
|
||||
for idx in range(self._tokenizer.get_piece_size()):
|
||||
text = self._tokenizer.id_to_piece(idx)
|
||||
self._inv_vocab[idx] = text
|
||||
self._vocab[text] = idx
|
||||
|
||||
if self._tokenizer.is_control(idx) or self._tokenizer.is_unknown(idx):
|
||||
self._special_tokens[text] = idx
|
||||
self._inv_special_tokens[idx] = text
|
||||
|
||||
self._unk_id = self._tokenizer.unk_id()
|
||||
self._bos_id = self._tokenizer.bos_id()
|
||||
self._eos_id = self._tokenizer.eos_id()
|
||||
|
||||
for token in [
|
||||
bot_token, eot_token, call_start_token, call_end_token,
|
||||
think_start_token, think_end_token
|
||||
]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
assert token in self._special_tokens, f"Token '{token}' is not a special token"
|
||||
|
||||
for token in [mask_start_token, mask_end_token]:
|
||||
assert token in self._vocab, f"Token '{token}' not found in tokenizer"
|
||||
|
||||
self._bot_id = self._tokenizer.piece_to_id(bot_token)
|
||||
self._eot_id = self._tokenizer.piece_to_id(eot_token)
|
||||
self._call_start_id = self._tokenizer.piece_to_id(call_start_token)
|
||||
self._call_end_id = self._tokenizer.piece_to_id(call_end_token)
|
||||
self._think_start_id = self._tokenizer.piece_to_id(think_start_token)
|
||||
self._think_end_id = self._tokenizer.piece_to_id(think_end_token)
|
||||
self._mask_start_id = self._tokenizer.piece_to_id(mask_start_token)
|
||||
self._mask_end_id = self._tokenizer.piece_to_id(mask_end_token)
|
||||
|
||||
self._underline_id = self._tokenizer.piece_to_id("\u2581")
|
||||
|
||||
@property
|
||||
def vocab(self):
|
||||
return self._vocab
|
||||
|
||||
@property
|
||||
def inv_vocab(self):
|
||||
return self._inv_vocab
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self._tokenizer.vocab_size()
|
||||
|
||||
def tokenize(self, text: str) -> List[int]:
|
||||
return self._tokenizer.encode_as_ids(text)
|
||||
|
||||
def detokenize(self, token_ids: List[int]) -> str:
|
||||
return self._tokenizer.decode_ids(token_ids)
|
||||
|
||||
|
||||
class Tokens:
|
||||
def __init__(self, input_ids, cu_input_ids, attention_mask, cu_seqlens, max_seq_len) -> None:
|
||||
self.input_ids = input_ids
|
||||
self.attention_mask = attention_mask
|
||||
self.cu_input_ids = cu_input_ids
|
||||
self.cu_seqlens = cu_seqlens
|
||||
self.max_seq_len = max_seq_len
|
||||
def to(self, device):
|
||||
self.input_ids = self.input_ids.to(device)
|
||||
self.attention_mask = self.attention_mask.to(device)
|
||||
self.cu_input_ids = self.cu_input_ids.to(device)
|
||||
self.cu_seqlens = self.cu_seqlens.to(device)
|
||||
return self
|
||||
|
||||
class Wrapped_StepChatTokenizer(StepChatTokenizer):
|
||||
def __call__(self, text, max_length=320, padding="max_length", truncation=True, return_tensors="pt"):
|
||||
# [bos, ..., eos, pad, pad, ..., pad]
|
||||
self.BOS = 1
|
||||
self.EOS = 2
|
||||
self.PAD = 2
|
||||
out_tokens = []
|
||||
attn_mask = []
|
||||
if len(text) == 0:
|
||||
part_tokens = [self.BOS] + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
else:
|
||||
for part in text:
|
||||
part_tokens = self.tokenize(part)
|
||||
part_tokens = part_tokens[:(max_length - 2)] # leave 2 space for bos and eos
|
||||
part_tokens = [self.BOS] + part_tokens + [self.EOS]
|
||||
valid_size = len(part_tokens)
|
||||
if len(part_tokens) < max_length:
|
||||
part_tokens += [self.PAD] * (max_length - valid_size)
|
||||
out_tokens.append(part_tokens)
|
||||
attn_mask.append([1]*valid_size+[0]*(max_length-valid_size))
|
||||
|
||||
out_tokens = torch.tensor(out_tokens, dtype=torch.long)
|
||||
attn_mask = torch.tensor(attn_mask, dtype=torch.long)
|
||||
|
||||
# padding y based on tp size
|
||||
padded_len = 0
|
||||
padded_flag = True if padded_len > 0 else False
|
||||
if padded_flag:
|
||||
pad_tokens = torch.tensor([[self.PAD] * max_length], device=out_tokens.device)
|
||||
pad_attn_mask = torch.tensor([[1]*padded_len+[0]*(max_length-padded_len)], device=attn_mask.device)
|
||||
out_tokens = torch.cat([out_tokens, pad_tokens], dim=0)
|
||||
attn_mask = torch.cat([attn_mask, pad_attn_mask], dim=0)
|
||||
|
||||
# cu_seqlens
|
||||
cu_out_tokens = out_tokens.masked_select(attn_mask != 0).unsqueeze(0)
|
||||
seqlen = attn_mask.sum(dim=1).tolist()
|
||||
cu_seqlens = torch.cumsum(torch.tensor([0]+seqlen), 0).to(device=out_tokens.device,dtype=torch.int32)
|
||||
max_seq_len = max(seqlen)
|
||||
return Tokens(out_tokens, cu_out_tokens, attn_mask, cu_seqlens, max_seq_len)
|
||||
|
||||
|
||||
|
||||
def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=True,
|
||||
return_attn_probs=False, tp_group_rank=0, tp_group_size=1):
|
||||
softmax_scale = q.size(-1) ** (-0.5) if softmax_scale is None else softmax_scale
|
||||
if hasattr(torch.ops.Optimus, "fwd"):
|
||||
results = torch.ops.Optimus.fwd(q, k, v, None, dropout_p, softmax_scale, causal, return_attn_probs, None, tp_group_rank, tp_group_size)[0]
|
||||
else:
|
||||
warnings.warn("Cannot load `torch.ops.Optimus.fwd`. Using `torch.nn.functional.scaled_dot_product_attention` instead.")
|
||||
results = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True, scale=softmax_scale).transpose(1, 2)
|
||||
return results
|
||||
|
||||
|
||||
class FlashSelfAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
attention_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
|
||||
def forward(self, q, k, v, cu_seqlens=None, max_seq_len=None):
|
||||
if cu_seqlens is None:
|
||||
output = flash_attn_func(q, k, v, dropout_p=self.dropout_p)
|
||||
else:
|
||||
raise ValueError('cu_seqlens is not supported!')
|
||||
|
||||
return output
|
||||
|
||||
|
||||
|
||||
def safediv(n, d):
|
||||
q, r = divmod(n, d)
|
||||
assert r == 0
|
||||
return q
|
||||
|
||||
|
||||
class MultiQueryAttention(nn.Module):
|
||||
def __init__(self, cfg, layer_id=None):
|
||||
super().__init__()
|
||||
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.max_seq_len = cfg.seq_length
|
||||
self.use_flash_attention = cfg.use_flash_attn
|
||||
assert self.use_flash_attention, 'FlashAttention is required!'
|
||||
|
||||
self.n_groups = cfg.num_attention_groups
|
||||
self.tp_size = 1
|
||||
self.n_local_heads = cfg.num_attention_heads
|
||||
self.n_local_groups = self.n_groups
|
||||
|
||||
self.wqkv = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size + self.head_dim * 2 * self.n_groups,
|
||||
bias=False,
|
||||
)
|
||||
self.wo = nn.Linear(
|
||||
cfg.hidden_size,
|
||||
cfg.hidden_size,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
assert self.use_flash_attention, 'non-Flash attention not supported yet.'
|
||||
self.core_attention = FlashSelfAttention(attention_dropout=cfg.attention_dropout)
|
||||
|
||||
self.layer_id = layer_id
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
seqlen, bsz, dim = x.shape
|
||||
xqkv = self.wqkv(x)
|
||||
|
||||
xq, xkv = torch.split(
|
||||
xqkv,
|
||||
(dim // self.tp_size,
|
||||
self.head_dim*2*self.n_groups // self.tp_size
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# gather on 1st dimension
|
||||
xq = xq.view(seqlen, bsz, self.n_local_heads, self.head_dim)
|
||||
xkv = xkv.view(seqlen, bsz, self.n_local_groups, 2 * self.head_dim)
|
||||
xk, xv = xkv.chunk(2, -1)
|
||||
|
||||
# rotary embedding + flash attn
|
||||
xq = rearrange(xq, "s b h d -> b s h d")
|
||||
xk = rearrange(xk, "s b h d -> b s h d")
|
||||
xv = rearrange(xv, "s b h d -> b s h d")
|
||||
|
||||
q_per_kv = self.n_local_heads // self.n_local_groups
|
||||
if q_per_kv > 1:
|
||||
b, s, h, d = xk.size()
|
||||
if h == 1:
|
||||
xk = xk.expand(b, s, q_per_kv, d)
|
||||
xv = xv.expand(b, s, q_per_kv, d)
|
||||
else:
|
||||
''' To cover the cases where h > 1, we have
|
||||
the following implementation, which is equivalent to:
|
||||
xk = xk.repeat_interleave(q_per_kv, dim=-2)
|
||||
xv = xv.repeat_interleave(q_per_kv, dim=-2)
|
||||
but can avoid calling aten::item() that involves cpu.
|
||||
'''
|
||||
idx = torch.arange(q_per_kv * h, device=xk.device).reshape(q_per_kv, -1).permute(1, 0).flatten()
|
||||
xk = torch.index_select(xk.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
xv = torch.index_select(xv.repeat(1, 1, q_per_kv, 1), 2, idx).contiguous()
|
||||
|
||||
if self.use_flash_attention:
|
||||
output = self.core_attention(xq, xk, xv,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seq_len=max_seq_len)
|
||||
# reduce-scatter only support first dimension now
|
||||
output = rearrange(output, "b s h d -> s b (h d)").contiguous()
|
||||
else:
|
||||
xq, xk, xv = [
|
||||
rearrange(x, "b s ... -> s b ...").contiguous()
|
||||
for x in (xq, xk, xv)
|
||||
]
|
||||
output = self.core_attention(xq, xk, xv, mask)
|
||||
output = self.wo(output)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cfg,
|
||||
dim: int,
|
||||
hidden_dim: int,
|
||||
layer_id: int,
|
||||
multiple_of: int=256,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
||||
def swiglu(x):
|
||||
x = torch.chunk(x, 2, dim=-1)
|
||||
return F.silu(x[0]) * x[1]
|
||||
self.swiglu = swiglu
|
||||
|
||||
self.w1 = nn.Linear(
|
||||
dim,
|
||||
2 * hidden_dim,
|
||||
bias=False,
|
||||
)
|
||||
self.w2 = nn.Linear(
|
||||
hidden_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.swiglu(self.w1(x))
|
||||
output = self.w2(x)
|
||||
return output
|
||||
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self, cfg, layer_id: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.n_heads = cfg.num_attention_heads
|
||||
self.dim = cfg.hidden_size
|
||||
self.head_dim = cfg.hidden_size // cfg.num_attention_heads
|
||||
self.attention = MultiQueryAttention(
|
||||
cfg,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(
|
||||
cfg,
|
||||
dim=cfg.hidden_size,
|
||||
hidden_dim=cfg.ffn_hidden_size,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.layer_id = layer_id
|
||||
self.attention_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
self.ffn_norm = RMSNorm(
|
||||
cfg.hidden_size,
|
||||
eps=cfg.layernorm_epsilon,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
cu_seqlens: Optional[torch.Tensor],
|
||||
max_seq_len: Optional[torch.Tensor],
|
||||
):
|
||||
residual = self.attention.forward(
|
||||
self.attention_norm(x), mask,
|
||||
cu_seqlens, max_seq_len
|
||||
)
|
||||
h = x + residual
|
||||
ffn_res = self.feed_forward.forward(self.ffn_norm(h))
|
||||
out = h + ffn_res
|
||||
return out
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
max_seq_size=8192,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_layers = config.num_layers
|
||||
self.layers = self._build_layers(config)
|
||||
|
||||
def _build_layers(self, config):
|
||||
layers = torch.nn.ModuleList()
|
||||
for layer_id in range(self.num_layers):
|
||||
layers.append(
|
||||
TransformerBlock(
|
||||
config,
|
||||
layer_id=layer_id + 1 ,
|
||||
)
|
||||
)
|
||||
return layers
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens=None,
|
||||
max_seq_len=None,
|
||||
):
|
||||
|
||||
if max_seq_len is not None and not isinstance(max_seq_len, torch.Tensor):
|
||||
max_seq_len = torch.tensor(max_seq_len, dtype=torch.int32, device="cpu")
|
||||
|
||||
for lid, layer in enumerate(self.layers):
|
||||
hidden_states = layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
cu_seqlens,
|
||||
max_seq_len,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Step1Model(PreTrainedModel):
|
||||
config_class=PretrainedConfig
|
||||
@with_empty_init
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = LLaMaEmbedding(config)
|
||||
self.transformer = Transformer(config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.transformer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
class STEP1TextEncoder(torch.nn.Module):
|
||||
def __init__(self, model_dir, max_length=320):
|
||||
super(STEP1TextEncoder, self).__init__()
|
||||
self.max_length = max_length
|
||||
self.text_tokenizer = Wrapped_StepChatTokenizer(os.path.join(model_dir, 'step1_chat_tokenizer.model'))
|
||||
text_encoder = Step1Model.from_pretrained(model_dir)
|
||||
self.text_encoder = text_encoder.eval().to(torch.bfloat16)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(path, torch_dtype=torch.bfloat16):
|
||||
model = STEP1TextEncoder(path).to(torch_dtype)
|
||||
return model
|
||||
|
||||
@torch.no_grad
|
||||
def forward(self, prompts, with_mask=True, max_length=None, device="cuda"):
|
||||
self.device = device
|
||||
with torch.no_grad(), torch.amp.autocast(dtype=torch.bfloat16, device_type=device):
|
||||
if type(prompts) is str:
|
||||
prompts = [prompts]
|
||||
|
||||
txt_tokens = self.text_tokenizer(
|
||||
prompts, max_length=max_length or self.max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
y = self.text_encoder(
|
||||
txt_tokens.input_ids.to(self.device),
|
||||
attention_mask=txt_tokens.attention_mask.to(self.device) if with_mask else None
|
||||
)
|
||||
y_mask = txt_tokens.attention_mask
|
||||
return y.transpose(0,1), y_mask
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,505 +0,0 @@
|
||||
import torch
|
||||
from .sd_text_encoder import CLIPEncoderLayer
|
||||
|
||||
|
||||
class CLIPVisionEmbeddings(torch.nn.Module):
|
||||
def __init__(self, embed_dim=1280, image_size=224, patch_size=14, num_channels=3):
|
||||
super().__init__()
|
||||
|
||||
# class_embeds (This is a fixed tensor)
|
||||
self.class_embedding = torch.nn.Parameter(torch.randn(1, 1, embed_dim))
|
||||
|
||||
# position_embeds
|
||||
self.patch_embedding = torch.nn.Conv2d(in_channels=num_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
|
||||
# position_embeds (This is a fixed tensor)
|
||||
self.position_embeds = torch.nn.Parameter(torch.zeros(1, (image_size // patch_size) ** 2 + 1, embed_dim))
|
||||
|
||||
def forward(self, pixel_values):
|
||||
batch_size = pixel_values.shape[0]
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
class_embeds = self.class_embedding.repeat(batch_size, 1, 1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + self.position_embeds
|
||||
return embeddings
|
||||
|
||||
|
||||
class SVDImageEncoder(torch.nn.Module):
|
||||
def __init__(self, embed_dim=1280, layer_norm_eps=1e-5, num_encoder_layers=32, encoder_intermediate_size=5120, projection_dim=1024, num_heads=16, head_dim=80):
|
||||
super().__init__()
|
||||
self.embeddings = CLIPVisionEmbeddings(embed_dim=embed_dim)
|
||||
self.pre_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
CLIPEncoderLayer(embed_dim, encoder_intermediate_size, num_heads=num_heads, head_dim=head_dim, use_quick_gelu=False)
|
||||
for _ in range(num_encoder_layers)])
|
||||
self.post_layernorm = torch.nn.LayerNorm(embed_dim, eps=layer_norm_eps)
|
||||
self.visual_projection = torch.nn.Linear(embed_dim, projection_dim, bias=False)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
embeds = self.embeddings(pixel_values)
|
||||
embeds = self.pre_layernorm(embeds)
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds)
|
||||
embeds = self.post_layernorm(embeds[:, 0, :])
|
||||
embeds = self.visual_projection(embeds)
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SVDImageEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SVDImageEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"vision_model.embeddings.patch_embedding.weight": "embeddings.patch_embedding.weight",
|
||||
"vision_model.embeddings.class_embedding": "embeddings.class_embedding",
|
||||
"vision_model.embeddings.position_embedding.weight": "embeddings.position_embeds",
|
||||
"vision_model.pre_layrnorm.weight": "pre_layernorm.weight",
|
||||
"vision_model.pre_layrnorm.bias": "pre_layernorm.bias",
|
||||
"vision_model.post_layernorm.weight": "post_layernorm.weight",
|
||||
"vision_model.post_layernorm.bias": "post_layernorm.bias",
|
||||
"visual_projection.weight": "visual_projection.weight"
|
||||
}
|
||||
attn_rename_dict = {
|
||||
"self_attn.q_proj": "attn.to_q",
|
||||
"self_attn.k_proj": "attn.to_k",
|
||||
"self_attn.v_proj": "attn.to_v",
|
||||
"self_attn.out_proj": "attn.to_out",
|
||||
"layer_norm1": "layer_norm1",
|
||||
"layer_norm2": "layer_norm2",
|
||||
"mlp.fc1": "fc1",
|
||||
"mlp.fc2": "fc2",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "vision_model.embeddings.class_embedding":
|
||||
param = state_dict[name].view(1, 1, -1)
|
||||
elif name == "vision_model.embeddings.position_embedding.weight":
|
||||
param = state_dict[name].unsqueeze(0)
|
||||
state_dict_[rename_dict[name]] = param
|
||||
elif name.startswith("vision_model.encoder.layers."):
|
||||
param = state_dict[name]
|
||||
names = name.split(".")
|
||||
layer_id, layer_type, tail = names[3], ".".join(names[4:-1]), names[-1]
|
||||
name_ = ".".join(["encoders", layer_id, attn_rename_dict[layer_type], tail])
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"conditioner.embedders.0.open_clip.model.visual.class_embedding": "embeddings.class_embedding",
|
||||
"conditioner.embedders.0.open_clip.model.visual.conv1.weight": "embeddings.patch_embedding.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.ln_post.bias": "post_layernorm.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.ln_post.weight": "post_layernorm.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.ln_pre.bias": "pre_layernorm.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.ln_pre.weight": "pre_layernorm.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.positional_embedding": "embeddings.position_embeds",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_bias": ['encoders.0.attn.to_q.bias', 'encoders.0.attn.to_k.bias', 'encoders.0.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.in_proj_weight": ['encoders.0.attn.to_q.weight', 'encoders.0.attn.to_k.weight', 'encoders.0.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.bias": "encoders.0.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.attn.out_proj.weight": "encoders.0.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.bias": "encoders.0.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_1.weight": "encoders.0.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.bias": "encoders.0.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.ln_2.weight": "encoders.0.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.bias": "encoders.0.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_fc.weight": "encoders.0.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.bias": "encoders.0.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.0.mlp.c_proj.weight": "encoders.0.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_bias": ['encoders.1.attn.to_q.bias', 'encoders.1.attn.to_k.bias', 'encoders.1.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.in_proj_weight": ['encoders.1.attn.to_q.weight', 'encoders.1.attn.to_k.weight', 'encoders.1.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.bias": "encoders.1.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.attn.out_proj.weight": "encoders.1.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.bias": "encoders.1.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_1.weight": "encoders.1.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.bias": "encoders.1.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.ln_2.weight": "encoders.1.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.bias": "encoders.1.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_fc.weight": "encoders.1.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.bias": "encoders.1.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.1.mlp.c_proj.weight": "encoders.1.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_bias": ['encoders.10.attn.to_q.bias', 'encoders.10.attn.to_k.bias', 'encoders.10.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.in_proj_weight": ['encoders.10.attn.to_q.weight', 'encoders.10.attn.to_k.weight', 'encoders.10.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.bias": "encoders.10.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.attn.out_proj.weight": "encoders.10.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.bias": "encoders.10.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_1.weight": "encoders.10.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.bias": "encoders.10.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.ln_2.weight": "encoders.10.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.bias": "encoders.10.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_fc.weight": "encoders.10.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.bias": "encoders.10.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.10.mlp.c_proj.weight": "encoders.10.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_bias": ['encoders.11.attn.to_q.bias', 'encoders.11.attn.to_k.bias', 'encoders.11.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.in_proj_weight": ['encoders.11.attn.to_q.weight', 'encoders.11.attn.to_k.weight', 'encoders.11.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.bias": "encoders.11.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.attn.out_proj.weight": "encoders.11.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.bias": "encoders.11.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_1.weight": "encoders.11.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.bias": "encoders.11.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.ln_2.weight": "encoders.11.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.bias": "encoders.11.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_fc.weight": "encoders.11.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.bias": "encoders.11.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.11.mlp.c_proj.weight": "encoders.11.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_bias": ['encoders.12.attn.to_q.bias', 'encoders.12.attn.to_k.bias', 'encoders.12.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.in_proj_weight": ['encoders.12.attn.to_q.weight', 'encoders.12.attn.to_k.weight', 'encoders.12.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.bias": "encoders.12.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.attn.out_proj.weight": "encoders.12.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.bias": "encoders.12.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_1.weight": "encoders.12.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.bias": "encoders.12.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.ln_2.weight": "encoders.12.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.bias": "encoders.12.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_fc.weight": "encoders.12.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.bias": "encoders.12.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.12.mlp.c_proj.weight": "encoders.12.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_bias": ['encoders.13.attn.to_q.bias', 'encoders.13.attn.to_k.bias', 'encoders.13.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.in_proj_weight": ['encoders.13.attn.to_q.weight', 'encoders.13.attn.to_k.weight', 'encoders.13.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.bias": "encoders.13.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.attn.out_proj.weight": "encoders.13.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.bias": "encoders.13.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_1.weight": "encoders.13.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.bias": "encoders.13.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.ln_2.weight": "encoders.13.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.bias": "encoders.13.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_fc.weight": "encoders.13.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.bias": "encoders.13.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.13.mlp.c_proj.weight": "encoders.13.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_bias": ['encoders.14.attn.to_q.bias', 'encoders.14.attn.to_k.bias', 'encoders.14.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.in_proj_weight": ['encoders.14.attn.to_q.weight', 'encoders.14.attn.to_k.weight', 'encoders.14.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.bias": "encoders.14.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.attn.out_proj.weight": "encoders.14.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.bias": "encoders.14.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_1.weight": "encoders.14.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.bias": "encoders.14.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.ln_2.weight": "encoders.14.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.bias": "encoders.14.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_fc.weight": "encoders.14.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.bias": "encoders.14.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.14.mlp.c_proj.weight": "encoders.14.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_bias": ['encoders.15.attn.to_q.bias', 'encoders.15.attn.to_k.bias', 'encoders.15.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.in_proj_weight": ['encoders.15.attn.to_q.weight', 'encoders.15.attn.to_k.weight', 'encoders.15.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.bias": "encoders.15.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.attn.out_proj.weight": "encoders.15.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.bias": "encoders.15.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_1.weight": "encoders.15.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.bias": "encoders.15.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.ln_2.weight": "encoders.15.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.bias": "encoders.15.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_fc.weight": "encoders.15.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.bias": "encoders.15.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.15.mlp.c_proj.weight": "encoders.15.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_bias": ['encoders.16.attn.to_q.bias', 'encoders.16.attn.to_k.bias', 'encoders.16.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.in_proj_weight": ['encoders.16.attn.to_q.weight', 'encoders.16.attn.to_k.weight', 'encoders.16.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.bias": "encoders.16.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.attn.out_proj.weight": "encoders.16.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.bias": "encoders.16.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_1.weight": "encoders.16.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.bias": "encoders.16.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.ln_2.weight": "encoders.16.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.bias": "encoders.16.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_fc.weight": "encoders.16.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.bias": "encoders.16.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.16.mlp.c_proj.weight": "encoders.16.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_bias": ['encoders.17.attn.to_q.bias', 'encoders.17.attn.to_k.bias', 'encoders.17.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.in_proj_weight": ['encoders.17.attn.to_q.weight', 'encoders.17.attn.to_k.weight', 'encoders.17.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.bias": "encoders.17.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.attn.out_proj.weight": "encoders.17.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.bias": "encoders.17.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_1.weight": "encoders.17.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.bias": "encoders.17.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.ln_2.weight": "encoders.17.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.bias": "encoders.17.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_fc.weight": "encoders.17.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.bias": "encoders.17.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.17.mlp.c_proj.weight": "encoders.17.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_bias": ['encoders.18.attn.to_q.bias', 'encoders.18.attn.to_k.bias', 'encoders.18.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.in_proj_weight": ['encoders.18.attn.to_q.weight', 'encoders.18.attn.to_k.weight', 'encoders.18.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.bias": "encoders.18.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.attn.out_proj.weight": "encoders.18.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.bias": "encoders.18.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_1.weight": "encoders.18.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.bias": "encoders.18.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.ln_2.weight": "encoders.18.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.bias": "encoders.18.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_fc.weight": "encoders.18.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.bias": "encoders.18.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.18.mlp.c_proj.weight": "encoders.18.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_bias": ['encoders.19.attn.to_q.bias', 'encoders.19.attn.to_k.bias', 'encoders.19.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.in_proj_weight": ['encoders.19.attn.to_q.weight', 'encoders.19.attn.to_k.weight', 'encoders.19.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.bias": "encoders.19.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.attn.out_proj.weight": "encoders.19.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.bias": "encoders.19.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_1.weight": "encoders.19.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.bias": "encoders.19.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.ln_2.weight": "encoders.19.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.bias": "encoders.19.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_fc.weight": "encoders.19.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.bias": "encoders.19.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.19.mlp.c_proj.weight": "encoders.19.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_bias": ['encoders.2.attn.to_q.bias', 'encoders.2.attn.to_k.bias', 'encoders.2.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.in_proj_weight": ['encoders.2.attn.to_q.weight', 'encoders.2.attn.to_k.weight', 'encoders.2.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.bias": "encoders.2.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.attn.out_proj.weight": "encoders.2.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.bias": "encoders.2.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_1.weight": "encoders.2.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.bias": "encoders.2.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.ln_2.weight": "encoders.2.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.bias": "encoders.2.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_fc.weight": "encoders.2.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.bias": "encoders.2.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.2.mlp.c_proj.weight": "encoders.2.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_bias": ['encoders.20.attn.to_q.bias', 'encoders.20.attn.to_k.bias', 'encoders.20.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.in_proj_weight": ['encoders.20.attn.to_q.weight', 'encoders.20.attn.to_k.weight', 'encoders.20.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.bias": "encoders.20.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.attn.out_proj.weight": "encoders.20.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.bias": "encoders.20.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_1.weight": "encoders.20.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.bias": "encoders.20.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.ln_2.weight": "encoders.20.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.bias": "encoders.20.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_fc.weight": "encoders.20.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.bias": "encoders.20.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.20.mlp.c_proj.weight": "encoders.20.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_bias": ['encoders.21.attn.to_q.bias', 'encoders.21.attn.to_k.bias', 'encoders.21.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.in_proj_weight": ['encoders.21.attn.to_q.weight', 'encoders.21.attn.to_k.weight', 'encoders.21.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.bias": "encoders.21.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.attn.out_proj.weight": "encoders.21.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.bias": "encoders.21.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_1.weight": "encoders.21.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.bias": "encoders.21.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.ln_2.weight": "encoders.21.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.bias": "encoders.21.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_fc.weight": "encoders.21.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.bias": "encoders.21.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.21.mlp.c_proj.weight": "encoders.21.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_bias": ['encoders.22.attn.to_q.bias', 'encoders.22.attn.to_k.bias', 'encoders.22.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.in_proj_weight": ['encoders.22.attn.to_q.weight', 'encoders.22.attn.to_k.weight', 'encoders.22.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.bias": "encoders.22.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.attn.out_proj.weight": "encoders.22.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.bias": "encoders.22.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_1.weight": "encoders.22.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.bias": "encoders.22.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.ln_2.weight": "encoders.22.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.bias": "encoders.22.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_fc.weight": "encoders.22.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.bias": "encoders.22.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.22.mlp.c_proj.weight": "encoders.22.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_bias": ['encoders.23.attn.to_q.bias', 'encoders.23.attn.to_k.bias', 'encoders.23.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.in_proj_weight": ['encoders.23.attn.to_q.weight', 'encoders.23.attn.to_k.weight', 'encoders.23.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.bias": "encoders.23.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.attn.out_proj.weight": "encoders.23.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.bias": "encoders.23.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_1.weight": "encoders.23.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.bias": "encoders.23.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.ln_2.weight": "encoders.23.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.bias": "encoders.23.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_fc.weight": "encoders.23.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.bias": "encoders.23.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.23.mlp.c_proj.weight": "encoders.23.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_bias": ['encoders.24.attn.to_q.bias', 'encoders.24.attn.to_k.bias', 'encoders.24.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.in_proj_weight": ['encoders.24.attn.to_q.weight', 'encoders.24.attn.to_k.weight', 'encoders.24.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.bias": "encoders.24.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.attn.out_proj.weight": "encoders.24.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.bias": "encoders.24.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_1.weight": "encoders.24.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.bias": "encoders.24.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.ln_2.weight": "encoders.24.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.bias": "encoders.24.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_fc.weight": "encoders.24.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.bias": "encoders.24.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.24.mlp.c_proj.weight": "encoders.24.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_bias": ['encoders.25.attn.to_q.bias', 'encoders.25.attn.to_k.bias', 'encoders.25.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.in_proj_weight": ['encoders.25.attn.to_q.weight', 'encoders.25.attn.to_k.weight', 'encoders.25.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.bias": "encoders.25.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.attn.out_proj.weight": "encoders.25.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.bias": "encoders.25.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_1.weight": "encoders.25.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.bias": "encoders.25.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.ln_2.weight": "encoders.25.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.bias": "encoders.25.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_fc.weight": "encoders.25.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.bias": "encoders.25.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.25.mlp.c_proj.weight": "encoders.25.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_bias": ['encoders.26.attn.to_q.bias', 'encoders.26.attn.to_k.bias', 'encoders.26.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.in_proj_weight": ['encoders.26.attn.to_q.weight', 'encoders.26.attn.to_k.weight', 'encoders.26.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.bias": "encoders.26.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.attn.out_proj.weight": "encoders.26.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.bias": "encoders.26.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_1.weight": "encoders.26.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.bias": "encoders.26.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.ln_2.weight": "encoders.26.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.bias": "encoders.26.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_fc.weight": "encoders.26.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.bias": "encoders.26.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.26.mlp.c_proj.weight": "encoders.26.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_bias": ['encoders.27.attn.to_q.bias', 'encoders.27.attn.to_k.bias', 'encoders.27.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.in_proj_weight": ['encoders.27.attn.to_q.weight', 'encoders.27.attn.to_k.weight', 'encoders.27.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.bias": "encoders.27.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.attn.out_proj.weight": "encoders.27.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.bias": "encoders.27.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_1.weight": "encoders.27.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.bias": "encoders.27.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.ln_2.weight": "encoders.27.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.bias": "encoders.27.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_fc.weight": "encoders.27.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.bias": "encoders.27.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.27.mlp.c_proj.weight": "encoders.27.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_bias": ['encoders.28.attn.to_q.bias', 'encoders.28.attn.to_k.bias', 'encoders.28.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.in_proj_weight": ['encoders.28.attn.to_q.weight', 'encoders.28.attn.to_k.weight', 'encoders.28.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.bias": "encoders.28.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.attn.out_proj.weight": "encoders.28.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.bias": "encoders.28.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_1.weight": "encoders.28.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.bias": "encoders.28.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.ln_2.weight": "encoders.28.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.bias": "encoders.28.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_fc.weight": "encoders.28.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.bias": "encoders.28.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.28.mlp.c_proj.weight": "encoders.28.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_bias": ['encoders.29.attn.to_q.bias', 'encoders.29.attn.to_k.bias', 'encoders.29.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.in_proj_weight": ['encoders.29.attn.to_q.weight', 'encoders.29.attn.to_k.weight', 'encoders.29.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.bias": "encoders.29.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.attn.out_proj.weight": "encoders.29.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.bias": "encoders.29.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_1.weight": "encoders.29.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.bias": "encoders.29.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.ln_2.weight": "encoders.29.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.bias": "encoders.29.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_fc.weight": "encoders.29.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.bias": "encoders.29.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.29.mlp.c_proj.weight": "encoders.29.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_bias": ['encoders.3.attn.to_q.bias', 'encoders.3.attn.to_k.bias', 'encoders.3.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.in_proj_weight": ['encoders.3.attn.to_q.weight', 'encoders.3.attn.to_k.weight', 'encoders.3.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.bias": "encoders.3.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.attn.out_proj.weight": "encoders.3.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.bias": "encoders.3.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_1.weight": "encoders.3.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.bias": "encoders.3.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.ln_2.weight": "encoders.3.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.bias": "encoders.3.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_fc.weight": "encoders.3.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.bias": "encoders.3.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.3.mlp.c_proj.weight": "encoders.3.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_bias": ['encoders.30.attn.to_q.bias', 'encoders.30.attn.to_k.bias', 'encoders.30.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.in_proj_weight": ['encoders.30.attn.to_q.weight', 'encoders.30.attn.to_k.weight', 'encoders.30.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.bias": "encoders.30.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.attn.out_proj.weight": "encoders.30.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.bias": "encoders.30.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_1.weight": "encoders.30.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.bias": "encoders.30.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.ln_2.weight": "encoders.30.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.bias": "encoders.30.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_fc.weight": "encoders.30.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.bias": "encoders.30.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.30.mlp.c_proj.weight": "encoders.30.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_bias": ['encoders.31.attn.to_q.bias', 'encoders.31.attn.to_k.bias', 'encoders.31.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.in_proj_weight": ['encoders.31.attn.to_q.weight', 'encoders.31.attn.to_k.weight', 'encoders.31.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.bias": "encoders.31.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.attn.out_proj.weight": "encoders.31.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.bias": "encoders.31.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_1.weight": "encoders.31.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.bias": "encoders.31.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.ln_2.weight": "encoders.31.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.bias": "encoders.31.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_fc.weight": "encoders.31.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.bias": "encoders.31.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.31.mlp.c_proj.weight": "encoders.31.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_bias": ['encoders.4.attn.to_q.bias', 'encoders.4.attn.to_k.bias', 'encoders.4.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.in_proj_weight": ['encoders.4.attn.to_q.weight', 'encoders.4.attn.to_k.weight', 'encoders.4.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.bias": "encoders.4.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.attn.out_proj.weight": "encoders.4.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.bias": "encoders.4.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_1.weight": "encoders.4.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.bias": "encoders.4.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.ln_2.weight": "encoders.4.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.bias": "encoders.4.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_fc.weight": "encoders.4.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.bias": "encoders.4.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.4.mlp.c_proj.weight": "encoders.4.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_bias": ['encoders.5.attn.to_q.bias', 'encoders.5.attn.to_k.bias', 'encoders.5.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.in_proj_weight": ['encoders.5.attn.to_q.weight', 'encoders.5.attn.to_k.weight', 'encoders.5.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.bias": "encoders.5.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.attn.out_proj.weight": "encoders.5.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.bias": "encoders.5.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_1.weight": "encoders.5.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.bias": "encoders.5.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.ln_2.weight": "encoders.5.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.bias": "encoders.5.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_fc.weight": "encoders.5.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.bias": "encoders.5.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.5.mlp.c_proj.weight": "encoders.5.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_bias": ['encoders.6.attn.to_q.bias', 'encoders.6.attn.to_k.bias', 'encoders.6.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.in_proj_weight": ['encoders.6.attn.to_q.weight', 'encoders.6.attn.to_k.weight', 'encoders.6.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.bias": "encoders.6.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.attn.out_proj.weight": "encoders.6.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.bias": "encoders.6.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_1.weight": "encoders.6.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.bias": "encoders.6.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.ln_2.weight": "encoders.6.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.bias": "encoders.6.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_fc.weight": "encoders.6.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.bias": "encoders.6.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.6.mlp.c_proj.weight": "encoders.6.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_bias": ['encoders.7.attn.to_q.bias', 'encoders.7.attn.to_k.bias', 'encoders.7.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.in_proj_weight": ['encoders.7.attn.to_q.weight', 'encoders.7.attn.to_k.weight', 'encoders.7.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.bias": "encoders.7.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.attn.out_proj.weight": "encoders.7.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.bias": "encoders.7.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_1.weight": "encoders.7.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.bias": "encoders.7.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.ln_2.weight": "encoders.7.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.bias": "encoders.7.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_fc.weight": "encoders.7.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.bias": "encoders.7.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.7.mlp.c_proj.weight": "encoders.7.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_bias": ['encoders.8.attn.to_q.bias', 'encoders.8.attn.to_k.bias', 'encoders.8.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.in_proj_weight": ['encoders.8.attn.to_q.weight', 'encoders.8.attn.to_k.weight', 'encoders.8.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.bias": "encoders.8.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.attn.out_proj.weight": "encoders.8.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.bias": "encoders.8.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_1.weight": "encoders.8.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.bias": "encoders.8.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.ln_2.weight": "encoders.8.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.bias": "encoders.8.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_fc.weight": "encoders.8.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.bias": "encoders.8.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.8.mlp.c_proj.weight": "encoders.8.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_bias": ['encoders.9.attn.to_q.bias', 'encoders.9.attn.to_k.bias', 'encoders.9.attn.to_v.bias'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.in_proj_weight": ['encoders.9.attn.to_q.weight', 'encoders.9.attn.to_k.weight', 'encoders.9.attn.to_v.weight'],
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.bias": "encoders.9.attn.to_out.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.attn.out_proj.weight": "encoders.9.attn.to_out.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.bias": "encoders.9.layer_norm1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_1.weight": "encoders.9.layer_norm1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.bias": "encoders.9.layer_norm2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.ln_2.weight": "encoders.9.layer_norm2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.bias": "encoders.9.fc1.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_fc.weight": "encoders.9.fc1.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.bias": "encoders.9.fc2.bias",
|
||||
"conditioner.embedders.0.open_clip.model.visual.transformer.resblocks.9.mlp.c_proj.weight": "encoders.9.fc2.weight",
|
||||
"conditioner.embedders.0.open_clip.model.visual.proj": "visual_projection.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if name == "conditioner.embedders.0.open_clip.model.visual.class_embedding":
|
||||
param = param.reshape((1, 1, param.shape[0]))
|
||||
elif name == "conditioner.embedders.0.open_clip.model.visual.positional_embedding":
|
||||
param = param.reshape((1, param.shape[0], param.shape[1]))
|
||||
elif name == "conditioner.embedders.0.open_clip.model.visual.proj":
|
||||
param = param.T
|
||||
if isinstance(rename_dict[name], str):
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
length = param.shape[0] // 3
|
||||
for i, rename in enumerate(rename_dict[name]):
|
||||
state_dict_[rename] = param[i*length: i*length+length]
|
||||
return state_dict_
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,578 +0,0 @@
|
||||
import torch
|
||||
from .attention import Attention
|
||||
from .sd_unet import ResnetBlock, UpSampler
|
||||
from .tiler import TileWorker
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class VAEAttentionBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, num_attention_heads, attention_head_dim, in_channels, num_layers=1, norm_num_groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList([
|
||||
Attention(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
bias_q=True,
|
||||
bias_kv=True,
|
||||
bias_out=True
|
||||
)
|
||||
for d in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack):
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class TemporalResnetBlock(torch.nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, groups=32, eps=1e-5):
|
||||
super().__init__()
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = torch.nn.Conv3d(in_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = torch.nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=(1, 0, 0))
|
||||
self.nonlinearity = torch.nn.SiLU()
|
||||
self.mix_factor = torch.nn.Parameter(torch.Tensor([0.5]))
|
||||
|
||||
def forward(self, hidden_states, time_emb, text_emb, res_stack, **kwargs):
|
||||
x_spatial = hidden_states
|
||||
x = rearrange(hidden_states, "T C H W -> 1 C T H W")
|
||||
x = self.norm1(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = self.nonlinearity(x)
|
||||
x = self.conv2(x)
|
||||
x_temporal = hidden_states + x[0].permute(1, 0, 2, 3)
|
||||
alpha = torch.sigmoid(self.mix_factor)
|
||||
hidden_states = alpha * x_temporal + (1 - alpha) * x_spatial
|
||||
return hidden_states, time_emb, text_emb, res_stack
|
||||
|
||||
|
||||
class SVDVAEDecoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.18215
|
||||
self.conv_in = torch.nn.Conv2d(4, 512, kernel_size=3, padding=1)
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# UNetMidBlock
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
VAEAttentionBlock(1, 512, 512, 1, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
# UpDecoderBlock
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
ResnetBlock(512, 512, eps=1e-6),
|
||||
TemporalResnetBlock(512, 512, eps=1e-6),
|
||||
UpSampler(512),
|
||||
# UpDecoderBlock
|
||||
ResnetBlock(512, 256, eps=1e-6),
|
||||
TemporalResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
TemporalResnetBlock(256, 256, eps=1e-6),
|
||||
ResnetBlock(256, 256, eps=1e-6),
|
||||
TemporalResnetBlock(256, 256, eps=1e-6),
|
||||
UpSampler(256),
|
||||
# UpDecoderBlock
|
||||
ResnetBlock(256, 128, eps=1e-6),
|
||||
TemporalResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
TemporalResnetBlock(128, 128, eps=1e-6),
|
||||
ResnetBlock(128, 128, eps=1e-6),
|
||||
TemporalResnetBlock(128, 128, eps=1e-6),
|
||||
])
|
||||
|
||||
self.conv_norm_out = torch.nn.GroupNorm(num_channels=128, num_groups=32, eps=1e-5)
|
||||
self.conv_act = torch.nn.SiLU()
|
||||
self.conv_out = torch.nn.Conv2d(128, 3, kernel_size=3, padding=1)
|
||||
self.time_conv_out = torch.nn.Conv3d(3, 3, kernel_size=(3, 1, 1), padding=(1, 0, 0))
|
||||
|
||||
|
||||
def forward(self, sample):
|
||||
# 1. pre-process
|
||||
hidden_states = rearrange(sample, "C T H W -> T C H W")
|
||||
hidden_states = hidden_states / self.scaling_factor
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
time_emb, text_emb, res_stack = None, None, None
|
||||
|
||||
# 2. blocks
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(hidden_states, time_emb, text_emb, res_stack)
|
||||
|
||||
# 3. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "T C H W -> C T H W")
|
||||
hidden_states = self.time_conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, T, H, W = data.shape
|
||||
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
||||
border_width = (T + H + W) // 6
|
||||
pad = torch.ones_like(t) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else t + 1,
|
||||
pad if is_bound[1] else T - t,
|
||||
pad if is_bound[2] else h + 1,
|
||||
pad if is_bound[3] else H - h,
|
||||
pad if is_bound[4] else w + 1,
|
||||
pad if is_bound[5] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "T H W -> 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def decode_video(
|
||||
self, sample,
|
||||
batch_time=8, batch_height=128, batch_width=128,
|
||||
stride_time=4, stride_height=32, stride_width=32,
|
||||
progress_bar=lambda x:x
|
||||
):
|
||||
sample = sample.permute(1, 0, 2, 3)
|
||||
data_device = sample.device
|
||||
computation_device = self.conv_in.weight.device
|
||||
torch_dtype = sample.dtype
|
||||
_, T, H, W = sample.shape
|
||||
|
||||
weight = torch.zeros((1, T, H*8, W*8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((3, T, H*8, W*8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_time):
|
||||
for h in range(0, H, stride_height):
|
||||
for w in range(0, W, stride_width):
|
||||
if (t-stride_time >= 0 and t-stride_time+batch_time >= T)\
|
||||
or (h-stride_height >= 0 and h-stride_height+batch_height >= H)\
|
||||
or (w-stride_width >= 0 and w-stride_width+batch_width >= W):
|
||||
continue
|
||||
tasks.append((t, t+batch_time, h, h+batch_height, w, w+batch_width))
|
||||
|
||||
# Run
|
||||
for tl, tr, hl, hr, wl, wr in progress_bar(tasks):
|
||||
sample_batch = sample[:, tl:tr, hl:hr, wl:wr].to(computation_device)
|
||||
sample_batch = self.forward(sample_batch).to(data_device)
|
||||
mask = self.build_mask(sample_batch, is_bound=(tl==0, tr>=T, hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += sample_batch * mask
|
||||
weight[:, tl:tr, hl*8:hr*8, wl*8:wr*8] += mask
|
||||
values /= weight
|
||||
return values
|
||||
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SVDVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class SVDVAEDecoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
static_rename_dict = {
|
||||
"decoder.conv_in": "conv_in",
|
||||
"decoder.mid_block.attentions.0.group_norm": "blocks.2.norm",
|
||||
"decoder.mid_block.attentions.0.to_q": "blocks.2.transformer_blocks.0.to_q",
|
||||
"decoder.mid_block.attentions.0.to_k": "blocks.2.transformer_blocks.0.to_k",
|
||||
"decoder.mid_block.attentions.0.to_v": "blocks.2.transformer_blocks.0.to_v",
|
||||
"decoder.mid_block.attentions.0.to_out.0": "blocks.2.transformer_blocks.0.to_out",
|
||||
"decoder.up_blocks.0.upsamplers.0.conv": "blocks.11.conv",
|
||||
"decoder.up_blocks.1.upsamplers.0.conv": "blocks.18.conv",
|
||||
"decoder.up_blocks.2.upsamplers.0.conv": "blocks.25.conv",
|
||||
"decoder.conv_norm_out": "conv_norm_out",
|
||||
"decoder.conv_out": "conv_out",
|
||||
"decoder.time_conv_out": "time_conv_out"
|
||||
}
|
||||
prefix_rename_dict = {
|
||||
"decoder.mid_block.resnets.0.spatial_res_block": "blocks.0",
|
||||
"decoder.mid_block.resnets.0.temporal_res_block": "blocks.1",
|
||||
"decoder.mid_block.resnets.0.time_mixer": "blocks.1",
|
||||
"decoder.mid_block.resnets.1.spatial_res_block": "blocks.3",
|
||||
"decoder.mid_block.resnets.1.temporal_res_block": "blocks.4",
|
||||
"decoder.mid_block.resnets.1.time_mixer": "blocks.4",
|
||||
|
||||
"decoder.up_blocks.0.resnets.0.spatial_res_block": "blocks.5",
|
||||
"decoder.up_blocks.0.resnets.0.temporal_res_block": "blocks.6",
|
||||
"decoder.up_blocks.0.resnets.0.time_mixer": "blocks.6",
|
||||
"decoder.up_blocks.0.resnets.1.spatial_res_block": "blocks.7",
|
||||
"decoder.up_blocks.0.resnets.1.temporal_res_block": "blocks.8",
|
||||
"decoder.up_blocks.0.resnets.1.time_mixer": "blocks.8",
|
||||
"decoder.up_blocks.0.resnets.2.spatial_res_block": "blocks.9",
|
||||
"decoder.up_blocks.0.resnets.2.temporal_res_block": "blocks.10",
|
||||
"decoder.up_blocks.0.resnets.2.time_mixer": "blocks.10",
|
||||
|
||||
"decoder.up_blocks.1.resnets.0.spatial_res_block": "blocks.12",
|
||||
"decoder.up_blocks.1.resnets.0.temporal_res_block": "blocks.13",
|
||||
"decoder.up_blocks.1.resnets.0.time_mixer": "blocks.13",
|
||||
"decoder.up_blocks.1.resnets.1.spatial_res_block": "blocks.14",
|
||||
"decoder.up_blocks.1.resnets.1.temporal_res_block": "blocks.15",
|
||||
"decoder.up_blocks.1.resnets.1.time_mixer": "blocks.15",
|
||||
"decoder.up_blocks.1.resnets.2.spatial_res_block": "blocks.16",
|
||||
"decoder.up_blocks.1.resnets.2.temporal_res_block": "blocks.17",
|
||||
"decoder.up_blocks.1.resnets.2.time_mixer": "blocks.17",
|
||||
|
||||
"decoder.up_blocks.2.resnets.0.spatial_res_block": "blocks.19",
|
||||
"decoder.up_blocks.2.resnets.0.temporal_res_block": "blocks.20",
|
||||
"decoder.up_blocks.2.resnets.0.time_mixer": "blocks.20",
|
||||
"decoder.up_blocks.2.resnets.1.spatial_res_block": "blocks.21",
|
||||
"decoder.up_blocks.2.resnets.1.temporal_res_block": "blocks.22",
|
||||
"decoder.up_blocks.2.resnets.1.time_mixer": "blocks.22",
|
||||
"decoder.up_blocks.2.resnets.2.spatial_res_block": "blocks.23",
|
||||
"decoder.up_blocks.2.resnets.2.temporal_res_block": "blocks.24",
|
||||
"decoder.up_blocks.2.resnets.2.time_mixer": "blocks.24",
|
||||
|
||||
"decoder.up_blocks.3.resnets.0.spatial_res_block": "blocks.26",
|
||||
"decoder.up_blocks.3.resnets.0.temporal_res_block": "blocks.27",
|
||||
"decoder.up_blocks.3.resnets.0.time_mixer": "blocks.27",
|
||||
"decoder.up_blocks.3.resnets.1.spatial_res_block": "blocks.28",
|
||||
"decoder.up_blocks.3.resnets.1.temporal_res_block": "blocks.29",
|
||||
"decoder.up_blocks.3.resnets.1.time_mixer": "blocks.29",
|
||||
"decoder.up_blocks.3.resnets.2.spatial_res_block": "blocks.30",
|
||||
"decoder.up_blocks.3.resnets.2.temporal_res_block": "blocks.31",
|
||||
"decoder.up_blocks.3.resnets.2.time_mixer": "blocks.31",
|
||||
}
|
||||
suffix_rename_dict = {
|
||||
"norm1.weight": "norm1.weight",
|
||||
"conv1.weight": "conv1.weight",
|
||||
"norm2.weight": "norm2.weight",
|
||||
"conv2.weight": "conv2.weight",
|
||||
"conv_shortcut.weight": "conv_shortcut.weight",
|
||||
"norm1.bias": "norm1.bias",
|
||||
"conv1.bias": "conv1.bias",
|
||||
"norm2.bias": "norm2.bias",
|
||||
"conv2.bias": "conv2.bias",
|
||||
"conv_shortcut.bias": "conv_shortcut.bias",
|
||||
"mix_factor": "mix_factor",
|
||||
}
|
||||
|
||||
state_dict_ = {}
|
||||
for name in static_rename_dict:
|
||||
state_dict_[static_rename_dict[name] + ".weight"] = state_dict[name + ".weight"]
|
||||
state_dict_[static_rename_dict[name] + ".bias"] = state_dict[name + ".bias"]
|
||||
for prefix_name in prefix_rename_dict:
|
||||
for suffix_name in suffix_rename_dict:
|
||||
name = prefix_name + "." + suffix_name
|
||||
name_ = prefix_rename_dict[prefix_name] + "." + suffix_rename_dict[suffix_name]
|
||||
if name in state_dict:
|
||||
state_dict_[name_] = state_dict[name]
|
||||
|
||||
return state_dict_
|
||||
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"first_stage_model.decoder.conv_in.bias": "conv_in.bias",
|
||||
"first_stage_model.decoder.conv_in.weight": "conv_in.weight",
|
||||
"first_stage_model.decoder.conv_out.bias": "conv_out.bias",
|
||||
"first_stage_model.decoder.conv_out.time_mix_conv.bias": "time_conv_out.bias",
|
||||
"first_stage_model.decoder.conv_out.time_mix_conv.weight": "time_conv_out.weight",
|
||||
"first_stage_model.decoder.conv_out.weight": "conv_out.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.k.bias": "blocks.2.transformer_blocks.0.to_k.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.k.weight": "blocks.2.transformer_blocks.0.to_k.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.norm.bias": "blocks.2.norm.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.norm.weight": "blocks.2.norm.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.proj_out.bias": "blocks.2.transformer_blocks.0.to_out.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.proj_out.weight": "blocks.2.transformer_blocks.0.to_out.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.q.bias": "blocks.2.transformer_blocks.0.to_q.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.q.weight": "blocks.2.transformer_blocks.0.to_q.weight",
|
||||
"first_stage_model.decoder.mid.attn_1.v.bias": "blocks.2.transformer_blocks.0.to_v.bias",
|
||||
"first_stage_model.decoder.mid.attn_1.v.weight": "blocks.2.transformer_blocks.0.to_v.weight",
|
||||
"first_stage_model.decoder.mid.block_1.conv1.bias": "blocks.0.conv1.bias",
|
||||
"first_stage_model.decoder.mid.block_1.conv1.weight": "blocks.0.conv1.weight",
|
||||
"first_stage_model.decoder.mid.block_1.conv2.bias": "blocks.0.conv2.bias",
|
||||
"first_stage_model.decoder.mid.block_1.conv2.weight": "blocks.0.conv2.weight",
|
||||
"first_stage_model.decoder.mid.block_1.mix_factor": "blocks.1.mix_factor",
|
||||
"first_stage_model.decoder.mid.block_1.norm1.bias": "blocks.0.norm1.bias",
|
||||
"first_stage_model.decoder.mid.block_1.norm1.weight": "blocks.0.norm1.weight",
|
||||
"first_stage_model.decoder.mid.block_1.norm2.bias": "blocks.0.norm2.bias",
|
||||
"first_stage_model.decoder.mid.block_1.norm2.weight": "blocks.0.norm2.weight",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.bias": "blocks.1.norm1.bias",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.0.weight": "blocks.1.norm1.weight",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.bias": "blocks.1.conv1.bias",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.in_layers.2.weight": "blocks.1.conv1.weight",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.bias": "blocks.1.norm2.bias",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.0.weight": "blocks.1.norm2.weight",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.bias": "blocks.1.conv2.bias",
|
||||
"first_stage_model.decoder.mid.block_1.time_stack.out_layers.3.weight": "blocks.1.conv2.weight",
|
||||
"first_stage_model.decoder.mid.block_2.conv1.bias": "blocks.3.conv1.bias",
|
||||
"first_stage_model.decoder.mid.block_2.conv1.weight": "blocks.3.conv1.weight",
|
||||
"first_stage_model.decoder.mid.block_2.conv2.bias": "blocks.3.conv2.bias",
|
||||
"first_stage_model.decoder.mid.block_2.conv2.weight": "blocks.3.conv2.weight",
|
||||
"first_stage_model.decoder.mid.block_2.mix_factor": "blocks.4.mix_factor",
|
||||
"first_stage_model.decoder.mid.block_2.norm1.bias": "blocks.3.norm1.bias",
|
||||
"first_stage_model.decoder.mid.block_2.norm1.weight": "blocks.3.norm1.weight",
|
||||
"first_stage_model.decoder.mid.block_2.norm2.bias": "blocks.3.norm2.bias",
|
||||
"first_stage_model.decoder.mid.block_2.norm2.weight": "blocks.3.norm2.weight",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.bias": "blocks.4.norm1.bias",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.0.weight": "blocks.4.norm1.weight",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.bias": "blocks.4.conv1.bias",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.in_layers.2.weight": "blocks.4.conv1.weight",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.bias": "blocks.4.norm2.bias",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.0.weight": "blocks.4.norm2.weight",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.bias": "blocks.4.conv2.bias",
|
||||
"first_stage_model.decoder.mid.block_2.time_stack.out_layers.3.weight": "blocks.4.conv2.weight",
|
||||
"first_stage_model.decoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"first_stage_model.decoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.conv1.bias": "blocks.26.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.conv1.weight": "blocks.26.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.conv2.bias": "blocks.26.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.conv2.weight": "blocks.26.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.mix_factor": "blocks.27.mix_factor",
|
||||
"first_stage_model.decoder.up.0.block.0.nin_shortcut.bias": "blocks.26.conv_shortcut.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.nin_shortcut.weight": "blocks.26.conv_shortcut.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.norm1.bias": "blocks.26.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.norm1.weight": "blocks.26.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.norm2.bias": "blocks.26.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.norm2.weight": "blocks.26.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.bias": "blocks.27.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.0.weight": "blocks.27.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.bias": "blocks.27.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.in_layers.2.weight": "blocks.27.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.bias": "blocks.27.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.0.weight": "blocks.27.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.bias": "blocks.27.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.0.time_stack.out_layers.3.weight": "blocks.27.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.conv1.bias": "blocks.28.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.conv1.weight": "blocks.28.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.conv2.bias": "blocks.28.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.conv2.weight": "blocks.28.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.mix_factor": "blocks.29.mix_factor",
|
||||
"first_stage_model.decoder.up.0.block.1.norm1.bias": "blocks.28.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.norm1.weight": "blocks.28.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.norm2.bias": "blocks.28.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.norm2.weight": "blocks.28.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.bias": "blocks.29.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.0.weight": "blocks.29.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.bias": "blocks.29.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.in_layers.2.weight": "blocks.29.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.bias": "blocks.29.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.0.weight": "blocks.29.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.bias": "blocks.29.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.1.time_stack.out_layers.3.weight": "blocks.29.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.conv1.bias": "blocks.30.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.conv1.weight": "blocks.30.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.conv2.bias": "blocks.30.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.conv2.weight": "blocks.30.conv2.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.mix_factor": "blocks.31.mix_factor",
|
||||
"first_stage_model.decoder.up.0.block.2.norm1.bias": "blocks.30.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.norm1.weight": "blocks.30.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.norm2.bias": "blocks.30.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.norm2.weight": "blocks.30.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.bias": "blocks.31.norm1.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.0.weight": "blocks.31.norm1.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.bias": "blocks.31.conv1.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.in_layers.2.weight": "blocks.31.conv1.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.bias": "blocks.31.norm2.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.0.weight": "blocks.31.norm2.weight",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.bias": "blocks.31.conv2.bias",
|
||||
"first_stage_model.decoder.up.0.block.2.time_stack.out_layers.3.weight": "blocks.31.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.conv1.bias": "blocks.19.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.conv1.weight": "blocks.19.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.conv2.bias": "blocks.19.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.conv2.weight": "blocks.19.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.mix_factor": "blocks.20.mix_factor",
|
||||
"first_stage_model.decoder.up.1.block.0.nin_shortcut.bias": "blocks.19.conv_shortcut.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.nin_shortcut.weight": "blocks.19.conv_shortcut.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.norm1.bias": "blocks.19.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.norm1.weight": "blocks.19.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.norm2.bias": "blocks.19.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.norm2.weight": "blocks.19.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.bias": "blocks.20.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.0.weight": "blocks.20.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.bias": "blocks.20.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.in_layers.2.weight": "blocks.20.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.bias": "blocks.20.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.0.weight": "blocks.20.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.bias": "blocks.20.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.0.time_stack.out_layers.3.weight": "blocks.20.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.conv1.bias": "blocks.21.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.conv1.weight": "blocks.21.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.conv2.bias": "blocks.21.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.conv2.weight": "blocks.21.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.mix_factor": "blocks.22.mix_factor",
|
||||
"first_stage_model.decoder.up.1.block.1.norm1.bias": "blocks.21.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.norm1.weight": "blocks.21.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.norm2.bias": "blocks.21.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.norm2.weight": "blocks.21.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.bias": "blocks.22.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.0.weight": "blocks.22.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.bias": "blocks.22.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.in_layers.2.weight": "blocks.22.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.bias": "blocks.22.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.0.weight": "blocks.22.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.bias": "blocks.22.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.1.time_stack.out_layers.3.weight": "blocks.22.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.conv1.bias": "blocks.23.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.conv1.weight": "blocks.23.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.conv2.bias": "blocks.23.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.conv2.weight": "blocks.23.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.mix_factor": "blocks.24.mix_factor",
|
||||
"first_stage_model.decoder.up.1.block.2.norm1.bias": "blocks.23.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.norm1.weight": "blocks.23.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.norm2.bias": "blocks.23.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.norm2.weight": "blocks.23.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.bias": "blocks.24.norm1.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.0.weight": "blocks.24.norm1.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.bias": "blocks.24.conv1.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.in_layers.2.weight": "blocks.24.conv1.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.bias": "blocks.24.norm2.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.0.weight": "blocks.24.norm2.weight",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.bias": "blocks.24.conv2.bias",
|
||||
"first_stage_model.decoder.up.1.block.2.time_stack.out_layers.3.weight": "blocks.24.conv2.weight",
|
||||
"first_stage_model.decoder.up.1.upsample.conv.bias": "blocks.25.conv.bias",
|
||||
"first_stage_model.decoder.up.1.upsample.conv.weight": "blocks.25.conv.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.conv1.bias": "blocks.12.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.conv1.weight": "blocks.12.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.conv2.bias": "blocks.12.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.conv2.weight": "blocks.12.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.mix_factor": "blocks.13.mix_factor",
|
||||
"first_stage_model.decoder.up.2.block.0.norm1.bias": "blocks.12.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.norm1.weight": "blocks.12.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.norm2.bias": "blocks.12.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.norm2.weight": "blocks.12.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.bias": "blocks.13.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.0.weight": "blocks.13.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.bias": "blocks.13.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.in_layers.2.weight": "blocks.13.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.bias": "blocks.13.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.0.weight": "blocks.13.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.bias": "blocks.13.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.0.time_stack.out_layers.3.weight": "blocks.13.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.conv1.bias": "blocks.14.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.conv1.weight": "blocks.14.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.conv2.bias": "blocks.14.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.conv2.weight": "blocks.14.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.mix_factor": "blocks.15.mix_factor",
|
||||
"first_stage_model.decoder.up.2.block.1.norm1.bias": "blocks.14.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.norm1.weight": "blocks.14.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.norm2.bias": "blocks.14.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.norm2.weight": "blocks.14.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.bias": "blocks.15.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.0.weight": "blocks.15.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.bias": "blocks.15.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.in_layers.2.weight": "blocks.15.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.bias": "blocks.15.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.0.weight": "blocks.15.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.bias": "blocks.15.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.1.time_stack.out_layers.3.weight": "blocks.15.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.conv1.bias": "blocks.16.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.conv1.weight": "blocks.16.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.conv2.bias": "blocks.16.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.conv2.weight": "blocks.16.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.mix_factor": "blocks.17.mix_factor",
|
||||
"first_stage_model.decoder.up.2.block.2.norm1.bias": "blocks.16.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.norm1.weight": "blocks.16.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.norm2.bias": "blocks.16.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.norm2.weight": "blocks.16.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.bias": "blocks.17.norm1.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.0.weight": "blocks.17.norm1.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.bias": "blocks.17.conv1.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.in_layers.2.weight": "blocks.17.conv1.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.bias": "blocks.17.norm2.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.0.weight": "blocks.17.norm2.weight",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.bias": "blocks.17.conv2.bias",
|
||||
"first_stage_model.decoder.up.2.block.2.time_stack.out_layers.3.weight": "blocks.17.conv2.weight",
|
||||
"first_stage_model.decoder.up.2.upsample.conv.bias": "blocks.18.conv.bias",
|
||||
"first_stage_model.decoder.up.2.upsample.conv.weight": "blocks.18.conv.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.conv1.bias": "blocks.5.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.conv1.weight": "blocks.5.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.conv2.bias": "blocks.5.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.conv2.weight": "blocks.5.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.mix_factor": "blocks.6.mix_factor",
|
||||
"first_stage_model.decoder.up.3.block.0.norm1.bias": "blocks.5.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.norm1.weight": "blocks.5.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.norm2.bias": "blocks.5.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.norm2.weight": "blocks.5.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.bias": "blocks.6.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.0.weight": "blocks.6.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.bias": "blocks.6.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.in_layers.2.weight": "blocks.6.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.bias": "blocks.6.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.0.weight": "blocks.6.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.bias": "blocks.6.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.0.time_stack.out_layers.3.weight": "blocks.6.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.mix_factor": "blocks.8.mix_factor",
|
||||
"first_stage_model.decoder.up.3.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.bias": "blocks.8.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.0.weight": "blocks.8.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.bias": "blocks.8.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.in_layers.2.weight": "blocks.8.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.bias": "blocks.8.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.0.weight": "blocks.8.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.bias": "blocks.8.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.1.time_stack.out_layers.3.weight": "blocks.8.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.conv1.bias": "blocks.9.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.conv1.weight": "blocks.9.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.conv2.bias": "blocks.9.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.conv2.weight": "blocks.9.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.mix_factor": "blocks.10.mix_factor",
|
||||
"first_stage_model.decoder.up.3.block.2.norm1.bias": "blocks.9.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.norm1.weight": "blocks.9.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.norm2.bias": "blocks.9.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.norm2.weight": "blocks.9.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.bias": "blocks.10.norm1.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.0.weight": "blocks.10.norm1.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.bias": "blocks.10.conv1.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.in_layers.2.weight": "blocks.10.conv1.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.bias": "blocks.10.norm2.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.0.weight": "blocks.10.norm2.weight",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.bias": "blocks.10.conv2.bias",
|
||||
"first_stage_model.decoder.up.3.block.2.time_stack.out_layers.3.weight": "blocks.10.conv2.weight",
|
||||
"first_stage_model.decoder.up.3.upsample.conv.bias": "blocks.11.conv.bias",
|
||||
"first_stage_model.decoder.up.3.upsample.conv.weight": "blocks.11.conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "blocks.2.transformer_blocks.0" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,139 +0,0 @@
|
||||
from .sd_vae_encoder import SDVAEEncoderStateDictConverter, SDVAEEncoder
|
||||
|
||||
|
||||
class SVDVAEEncoder(SDVAEEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.scaling_factor = 0.13025
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return SVDVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class SVDVAEEncoderStateDictConverter(SDVAEEncoderStateDictConverter):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return super().from_diffusers(state_dict)
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
rename_dict = {
|
||||
"conditioner.embedders.3.encoder.encoder.conv_in.bias": "conv_in.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.conv_in.weight": "conv_in.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.conv_out.bias": "conv_out.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.conv_out.weight": "conv_out.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.bias": "blocks.0.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv1.weight": "blocks.0.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.bias": "blocks.0.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.conv2.weight": "blocks.0.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.bias": "blocks.0.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm1.weight": "blocks.0.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.bias": "blocks.0.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.0.norm2.weight": "blocks.0.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.bias": "blocks.1.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv1.weight": "blocks.1.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.bias": "blocks.1.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.conv2.weight": "blocks.1.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.bias": "blocks.1.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm1.weight": "blocks.1.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.bias": "blocks.1.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.block.1.norm2.weight": "blocks.1.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.bias": "blocks.2.conv.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.0.downsample.conv.weight": "blocks.2.conv.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.bias": "blocks.3.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv1.weight": "blocks.3.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.bias": "blocks.3.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.conv2.weight": "blocks.3.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.bias": "blocks.3.conv_shortcut.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.nin_shortcut.weight": "blocks.3.conv_shortcut.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.bias": "blocks.3.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm1.weight": "blocks.3.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.bias": "blocks.3.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.0.norm2.weight": "blocks.3.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.bias": "blocks.4.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv1.weight": "blocks.4.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.bias": "blocks.4.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.conv2.weight": "blocks.4.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.bias": "blocks.4.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm1.weight": "blocks.4.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.bias": "blocks.4.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.block.1.norm2.weight": "blocks.4.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.bias": "blocks.5.conv.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.1.downsample.conv.weight": "blocks.5.conv.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.bias": "blocks.6.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv1.weight": "blocks.6.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.bias": "blocks.6.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.conv2.weight": "blocks.6.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.bias": "blocks.6.conv_shortcut.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.nin_shortcut.weight": "blocks.6.conv_shortcut.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.bias": "blocks.6.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm1.weight": "blocks.6.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.bias": "blocks.6.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.0.norm2.weight": "blocks.6.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.bias": "blocks.7.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv1.weight": "blocks.7.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.bias": "blocks.7.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.conv2.weight": "blocks.7.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.bias": "blocks.7.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm1.weight": "blocks.7.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.bias": "blocks.7.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.block.1.norm2.weight": "blocks.7.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.bias": "blocks.8.conv.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.2.downsample.conv.weight": "blocks.8.conv.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.bias": "blocks.9.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv1.weight": "blocks.9.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.bias": "blocks.9.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.conv2.weight": "blocks.9.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.bias": "blocks.9.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm1.weight": "blocks.9.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.bias": "blocks.9.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.0.norm2.weight": "blocks.9.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.bias": "blocks.10.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv1.weight": "blocks.10.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.bias": "blocks.10.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.conv2.weight": "blocks.10.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.bias": "blocks.10.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm1.weight": "blocks.10.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.bias": "blocks.10.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.down.3.block.1.norm2.weight": "blocks.10.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.bias": "blocks.12.transformer_blocks.0.to_k.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.k.weight": "blocks.12.transformer_blocks.0.to_k.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.bias": "blocks.12.norm.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.norm.weight": "blocks.12.norm.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.bias": "blocks.12.transformer_blocks.0.to_out.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.proj_out.weight": "blocks.12.transformer_blocks.0.to_out.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.q.bias": "blocks.12.transformer_blocks.0.to_q.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.q.weight": "blocks.12.transformer_blocks.0.to_q.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.v.bias": "blocks.12.transformer_blocks.0.to_v.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.attn_1.v.weight": "blocks.12.transformer_blocks.0.to_v.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.bias": "blocks.11.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv1.weight": "blocks.11.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.bias": "blocks.11.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.conv2.weight": "blocks.11.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.bias": "blocks.11.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm1.weight": "blocks.11.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.bias": "blocks.11.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_1.norm2.weight": "blocks.11.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.bias": "blocks.13.conv1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv1.weight": "blocks.13.conv1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.bias": "blocks.13.conv2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.conv2.weight": "blocks.13.conv2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.bias": "blocks.13.norm1.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm1.weight": "blocks.13.norm1.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.bias": "blocks.13.norm2.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.mid.block_2.norm2.weight": "blocks.13.norm2.weight",
|
||||
"conditioner.embedders.3.encoder.encoder.norm_out.bias": "conv_norm_out.bias",
|
||||
"conditioner.embedders.3.encoder.encoder.norm_out.weight": "conv_norm_out.weight",
|
||||
"conditioner.embedders.3.encoder.quant_conv.bias": "quant_conv.bias",
|
||||
"conditioner.embedders.3.encoder.quant_conv.weight": "quant_conv.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name in rename_dict:
|
||||
param = state_dict[name]
|
||||
if "transformer_blocks" in rename_dict[name]:
|
||||
param = param.squeeze()
|
||||
state_dict_[rename_dict[name]] = param
|
||||
return state_dict_
|
||||
@@ -1,234 +0,0 @@
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
class TileWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def mask(self, height, width, border_width):
|
||||
# Create a mask with shape (height, width).
|
||||
# The centre area is filled with 1, and the border line is filled with values in range (0, 1].
|
||||
x = torch.arange(height).repeat(width, 1).T
|
||||
y = torch.arange(width).repeat(height, 1)
|
||||
mask = torch.stack([x + 1, height - x, y + 1, width - y]).min(dim=0).values
|
||||
mask = (mask / border_width).clip(0, 1)
|
||||
return mask
|
||||
|
||||
|
||||
def tile(self, model_input, tile_size, tile_stride, tile_device, tile_dtype):
|
||||
# Convert a tensor (b, c, h, w) to (b, c, tile_size, tile_size, tile_num)
|
||||
batch_size, channel, _, _ = model_input.shape
|
||||
model_input = model_input.to(device=tile_device, dtype=tile_dtype)
|
||||
unfold_operator = torch.nn.Unfold(
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
model_input = unfold_operator(model_input)
|
||||
model_input = model_input.view((batch_size, channel, tile_size, tile_size, -1))
|
||||
|
||||
return model_input
|
||||
|
||||
|
||||
def tiled_inference(self, forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype):
|
||||
# Call y=forward_fn(x) for each tile
|
||||
tile_num = model_input.shape[-1]
|
||||
model_output_stack = []
|
||||
|
||||
for tile_id in range(0, tile_num, tile_batch_size):
|
||||
|
||||
# process input
|
||||
tile_id_ = min(tile_id + tile_batch_size, tile_num)
|
||||
x = model_input[:, :, :, :, tile_id: tile_id_]
|
||||
x = x.to(device=inference_device, dtype=inference_dtype)
|
||||
x = rearrange(x, "b c h w n -> (n b) c h w")
|
||||
|
||||
# process output
|
||||
y = forward_fn(x)
|
||||
y = rearrange(y, "(n b) c h w -> b c h w n", n=tile_id_-tile_id)
|
||||
y = y.to(device=tile_device, dtype=tile_dtype)
|
||||
model_output_stack.append(y)
|
||||
|
||||
model_output = torch.concat(model_output_stack, dim=-1)
|
||||
return model_output
|
||||
|
||||
|
||||
def io_scale(self, model_output, tile_size):
|
||||
# Determine the size modification happened in forward_fn
|
||||
# We only consider the same scale on height and width.
|
||||
io_scale = model_output.shape[2] / tile_size
|
||||
return io_scale
|
||||
|
||||
|
||||
def untile(self, model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype):
|
||||
# The reversed function of tile
|
||||
mask = self.mask(tile_size, tile_size, border_width)
|
||||
mask = mask.to(device=tile_device, dtype=tile_dtype)
|
||||
mask = rearrange(mask, "h w -> 1 1 h w 1")
|
||||
model_output = model_output * mask
|
||||
|
||||
fold_operator = torch.nn.Fold(
|
||||
output_size=(height, width),
|
||||
kernel_size=(tile_size, tile_size),
|
||||
stride=(tile_stride, tile_stride)
|
||||
)
|
||||
mask = repeat(mask[0, 0, :, :, 0], "h w -> 1 (h w) n", n=model_output.shape[-1])
|
||||
model_output = rearrange(model_output, "b c h w n -> b (c h w) n")
|
||||
model_output = fold_operator(model_output) / fold_operator(mask)
|
||||
|
||||
return model_output
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_batch_size=1, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
inference_device, inference_dtype = model_input.device, model_input.dtype
|
||||
height, width = model_input.shape[2], model_input.shape[3]
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
|
||||
# tile
|
||||
model_input = self.tile(model_input, tile_size, tile_stride, tile_device, tile_dtype)
|
||||
|
||||
# inference
|
||||
model_output = self.tiled_inference(forward_fn, model_input, tile_batch_size, inference_device, inference_dtype, tile_device, tile_dtype)
|
||||
|
||||
# resize
|
||||
io_scale = self.io_scale(model_output, tile_size)
|
||||
height, width = int(height*io_scale), int(width*io_scale)
|
||||
tile_size, tile_stride = int(tile_size*io_scale), int(tile_stride*io_scale)
|
||||
border_width = int(border_width*io_scale)
|
||||
|
||||
# untile
|
||||
model_output = self.untile(model_output, height, width, tile_size, tile_stride, border_width, tile_device, tile_dtype)
|
||||
|
||||
# Done!
|
||||
model_output = model_output.to(device=inference_device, dtype=inference_dtype)
|
||||
return model_output
|
||||
|
||||
|
||||
|
||||
class FastTileWorker:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, _, H, W = data.shape
|
||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else h + 1,
|
||||
pad if is_bound[1] else H - h,
|
||||
pad if is_bound[2] else w + 1,
|
||||
pad if is_bound[3] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "H W -> 1 H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_forward(self, forward_fn, model_input, tile_size, tile_stride, tile_device="cpu", tile_dtype=torch.float32, border_width=None):
|
||||
# Prepare
|
||||
B, C, H, W = model_input.shape
|
||||
border_width = int(tile_stride*0.5) if border_width is None else border_width
|
||||
weight = torch.zeros((1, 1, H, W), dtype=tile_dtype, device=tile_device)
|
||||
values = torch.zeros((B, C, H, W), dtype=tile_dtype, device=tile_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = H - tile_size, H
|
||||
if w_ > W: w, w_ = W - tile_size, W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
# Forward
|
||||
hidden_states_batch = forward_fn(hl, hr, wl, wr).to(dtype=tile_dtype, device=tile_device)
|
||||
|
||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
||||
weight[:, :, hl:hr, wl:wr] += mask
|
||||
values /= weight
|
||||
return values
|
||||
|
||||
|
||||
|
||||
class TileWorker2Dto3D:
|
||||
"""
|
||||
Process 3D tensors, but only enable TileWorker on 2D.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def build_mask(self, T, H, W, dtype, device, is_bound, border_width):
|
||||
t = repeat(torch.arange(T), "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(torch.arange(H), "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> T H W", T=T, H=H, W=W)
|
||||
border_width = (H + W) // 4 if border_width is None else border_width
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else t + 1,
|
||||
pad if is_bound[1] else T - t,
|
||||
pad if is_bound[2] else h + 1,
|
||||
pad if is_bound[3] else H - h,
|
||||
pad if is_bound[4] else w + 1,
|
||||
pad if is_bound[5] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=dtype, device=device)
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tiled_forward(
|
||||
self,
|
||||
forward_fn,
|
||||
model_input,
|
||||
tile_size, tile_stride,
|
||||
tile_device="cpu", tile_dtype=torch.float32,
|
||||
computation_device="cuda", computation_dtype=torch.float32,
|
||||
border_width=None, scales=[1, 1, 1, 1],
|
||||
progress_bar=lambda x:x
|
||||
):
|
||||
B, C, T, H, W = model_input.shape
|
||||
scale_C, scale_T, scale_H, scale_W = scales
|
||||
tile_size_H, tile_size_W = tile_size
|
||||
tile_stride_H, tile_stride_W = tile_stride
|
||||
|
||||
value = torch.zeros((B, int(C*scale_C), int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
|
||||
weight = torch.zeros((1, 1, int(T*scale_T), int(H*scale_H), int(W*scale_W)), dtype=tile_dtype, device=tile_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride_H):
|
||||
for w in range(0, W, tile_stride_W):
|
||||
if (h-tile_stride_H >= 0 and h-tile_stride_H+tile_size_H >= H) or (w-tile_stride_W >= 0 and w-tile_stride_W+tile_size_W >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size_H, w + tile_size_W
|
||||
if h_ > H: h, h_ = max(H - tile_size_H, 0), H
|
||||
if w_ > W: w, w_ = max(W - tile_size_W, 0), W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in progress_bar(tasks):
|
||||
mask = self.build_mask(
|
||||
int(T*scale_T), int((hr-hl)*scale_H), int((wr-wl)*scale_W),
|
||||
tile_dtype, tile_device,
|
||||
is_bound=(True, True, hl==0, hr>=H, wl==0, wr>=W),
|
||||
border_width=border_width
|
||||
)
|
||||
grid_input = model_input[:, :, :, hl:hr, wl:wr].to(dtype=computation_dtype, device=computation_device)
|
||||
grid_output = forward_fn(grid_input).to(dtype=tile_dtype, device=tile_device)
|
||||
value[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += grid_output * mask
|
||||
weight[:, :, :, int(hl*scale_H):int(hr*scale_H), int(wl*scale_W):int(wr*scale_W)] += mask
|
||||
value = value / weight
|
||||
return value
|
||||
@@ -1,182 +0,0 @@
|
||||
import torch, os
|
||||
from safetensors import safe_open
|
||||
from contextlib import contextmanager
|
||||
import hashlib
|
||||
|
||||
@contextmanager
|
||||
def init_weights_on_device(device = torch.device("meta"), include_buffers :bool = False):
|
||||
|
||||
old_register_parameter = torch.nn.Module.register_parameter
|
||||
if include_buffers:
|
||||
old_register_buffer = torch.nn.Module.register_buffer
|
||||
|
||||
def register_empty_parameter(module, name, param):
|
||||
old_register_parameter(module, name, param)
|
||||
if param is not None:
|
||||
param_cls = type(module._parameters[name])
|
||||
kwargs = module._parameters[name].__dict__
|
||||
kwargs["requires_grad"] = param.requires_grad
|
||||
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
|
||||
|
||||
def register_empty_buffer(module, name, buffer, persistent=True):
|
||||
old_register_buffer(module, name, buffer, persistent=persistent)
|
||||
if buffer is not None:
|
||||
module._buffers[name] = module._buffers[name].to(device)
|
||||
|
||||
def patch_tensor_constructor(fn):
|
||||
def wrapper(*args, **kwargs):
|
||||
kwargs["device"] = device
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
if include_buffers:
|
||||
tensor_constructors_to_patch = {
|
||||
torch_function_name: getattr(torch, torch_function_name)
|
||||
for torch_function_name in ["empty", "zeros", "ones", "full"]
|
||||
}
|
||||
else:
|
||||
tensor_constructors_to_patch = {}
|
||||
|
||||
try:
|
||||
torch.nn.Module.register_parameter = register_empty_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = register_empty_buffer
|
||||
for torch_function_name in tensor_constructors_to_patch.keys():
|
||||
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
|
||||
yield
|
||||
finally:
|
||||
torch.nn.Module.register_parameter = old_register_parameter
|
||||
if include_buffers:
|
||||
torch.nn.Module.register_buffer = old_register_buffer
|
||||
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
|
||||
setattr(torch, torch_function_name, old_torch_function)
|
||||
|
||||
def load_state_dict_from_folder(file_path, torch_dtype=None):
|
||||
state_dict = {}
|
||||
for file_name in os.listdir(file_path):
|
||||
if "." in file_name and file_name.split(".")[-1] in [
|
||||
"safetensors", "bin", "ckpt", "pth", "pt"
|
||||
]:
|
||||
state_dict.update(load_state_dict(os.path.join(file_path, file_name), torch_dtype=torch_dtype))
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = {}
|
||||
with safe_open(file_path, framework="pt", device=str(device)) as f:
|
||||
for k in f.keys():
|
||||
state_dict[k] = f.get_tensor(k)
|
||||
if torch_dtype is not None:
|
||||
state_dict[k] = state_dict[k].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
|
||||
state_dict = torch.load(file_path, map_location=device, weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
state_dict[i] = state_dict[i].to(torch_dtype)
|
||||
return state_dict
|
||||
|
||||
|
||||
def search_for_embeddings(state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
if isinstance(state_dict[k], torch.Tensor):
|
||||
embeddings.append(state_dict[k])
|
||||
elif isinstance(state_dict[k], dict):
|
||||
embeddings += search_for_embeddings(state_dict[k])
|
||||
return embeddings
|
||||
|
||||
|
||||
def search_parameter(param, state_dict):
|
||||
for name, param_ in state_dict.items():
|
||||
if param.numel() == param_.numel():
|
||||
if param.shape == param_.shape:
|
||||
if torch.dist(param, param_) < 1e-3:
|
||||
return name
|
||||
else:
|
||||
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
|
||||
return name
|
||||
return None
|
||||
|
||||
|
||||
def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
|
||||
matched_keys = set()
|
||||
with torch.no_grad():
|
||||
for name in source_state_dict:
|
||||
rename = search_parameter(source_state_dict[name], target_state_dict)
|
||||
if rename is not None:
|
||||
print(f'"{name}": "{rename}",')
|
||||
matched_keys.add(rename)
|
||||
elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
|
||||
length = source_state_dict[name].shape[0] // 3
|
||||
rename = []
|
||||
for i in range(3):
|
||||
rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
|
||||
if None not in rename:
|
||||
print(f'"{name}": {rename},')
|
||||
for rename_ in rename:
|
||||
matched_keys.add(rename_)
|
||||
for name in target_state_dict:
|
||||
if name not in matched_keys:
|
||||
print("Cannot find", name, target_state_dict[name].shape)
|
||||
|
||||
|
||||
def search_for_files(folder, extensions):
|
||||
files = []
|
||||
if os.path.isdir(folder):
|
||||
for file in sorted(os.listdir(folder)):
|
||||
files += search_for_files(os.path.join(folder, file), extensions)
|
||||
elif os.path.isfile(folder):
|
||||
for extension in extensions:
|
||||
if folder.endswith(extension):
|
||||
files.append(folder)
|
||||
break
|
||||
return files
|
||||
|
||||
|
||||
def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
|
||||
keys = []
|
||||
for key, value in state_dict.items():
|
||||
if isinstance(key, str):
|
||||
if isinstance(value, torch.Tensor):
|
||||
if with_shape:
|
||||
shape = "_".join(map(str, list(value.shape)))
|
||||
keys.append(key + ":" + shape)
|
||||
keys.append(key)
|
||||
elif isinstance(value, dict):
|
||||
keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape))
|
||||
keys.sort()
|
||||
keys_str = ",".join(keys)
|
||||
return keys_str
|
||||
|
||||
|
||||
def split_state_dict_with_prefix(state_dict):
|
||||
keys = sorted([key for key in state_dict if isinstance(key, str)])
|
||||
prefix_dict = {}
|
||||
for key in keys:
|
||||
prefix = key if "." not in key else key.split(".")[0]
|
||||
if prefix not in prefix_dict:
|
||||
prefix_dict[prefix] = []
|
||||
prefix_dict[prefix].append(key)
|
||||
state_dicts = []
|
||||
for prefix, keys in prefix_dict.items():
|
||||
sub_state_dict = {key: state_dict[key] for key in keys}
|
||||
state_dicts.append(sub_state_dict)
|
||||
return state_dicts
|
||||
|
||||
|
||||
def hash_state_dict_keys(state_dict, with_shape=True):
|
||||
keys_str = convert_state_dict_keys_to_single_str(state_dict, with_shape=with_shape)
|
||||
keys_str = keys_str.encode(encoding="UTF-8")
|
||||
return hashlib.md5(keys_str).hexdigest()
|
||||
@@ -375,7 +375,7 @@ class Blur(nn.Module):
|
||||
if upsample_factor > 1:
|
||||
kernel = kernel * (upsample_factor ** 2)
|
||||
|
||||
self.register_buffer('kernel', kernel)
|
||||
self.kernel = torch.nn.Parameter(kernel)
|
||||
|
||||
self.pad = pad
|
||||
|
||||
@@ -648,23 +648,3 @@ class WanAnimateAdapter(torch.nn.Module):
|
||||
residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
|
||||
x = residual_out + x
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanAnimateAdapterStateDictConverter()
|
||||
|
||||
|
||||
class WanAnimateAdapterStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("pose_patch_embedding.") or name.startswith("face_adapter") or name.startswith("face_encoder") or name.startswith("motion_encoder"):
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import torch.nn.functional as F
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .utils import hash_state_dict_keys
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
try:
|
||||
import flash_attn_interface
|
||||
@@ -405,369 +404,3 @@ class WanModel(torch.nn.Module):
|
||||
x = self.head(x, t)
|
||||
x = self.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanModelStateDictConverter()
|
||||
|
||||
|
||||
class WanModelStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
rename_dict = {
|
||||
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
||||
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
||||
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
||||
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
||||
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
||||
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
||||
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
||||
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
||||
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
||||
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
"proj_out.bias": "head.head.bias",
|
||||
"proj_out.weight": "head.head.weight",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||
if name_ in rename_dict:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
if hash_state_dict_keys(state_dict_) == "cb104773c6c2cb6df4f9529ad5c60d0b":
|
||||
config = {
|
||||
"model_type": "t2v",
|
||||
"patch_size": (1, 2, 2),
|
||||
"text_len": 512,
|
||||
"in_dim": 16,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"window_size": (-1, -1),
|
||||
"qk_norm": True,
|
||||
"cross_attn_norm": True,
|
||||
"eps": 1e-6,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict_) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {name: param for name, param in state_dict.items() if not name.startswith("vace")}
|
||||
state_dict = {name: param for name, param in state_dict.items() if name.split(".")[0] not in ["pose_patch_embedding", "face_adapter", "face_encoder", "motion_encoder"]}
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("model."):
|
||||
name = name[len("model."):]
|
||||
state_dict_[name] = param
|
||||
state_dict = state_dict_
|
||||
if hash_state_dict_keys(state_dict) == "9269f8db9040a9d860eaca435be61814":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 16,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "aafcfd9672c3a2456dc46e1cb6e52c70":
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 16,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6bfcfb3b342cb286ce886889d519a77e":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "6d6ccde6845b95ad9114ab993d917893":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "349723183fc063b2bfc10bb2835cf677":
|
||||
# 1.3B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "efa44cddf936c70abd0ea28b6cbe946c":
|
||||
# 14B PAI control
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "3ef3b1f8e1dab83d5b71fd7b617f859f":
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_image_pos_emb": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "70ddad9d3a133785da5ea371aae09504":
|
||||
# 1.3B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "26bde73488a92e64cc20b0a7485b9e5b":
|
||||
# 14B PAI control v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "ac6a5aa74f4a0aab6f64eb9a72f19901":
|
||||
# 1.3B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 1536,
|
||||
"ffn_dim": 8960,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 12,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "b61c605c2adbd23124d152ed28e049ae":
|
||||
# 14B PAI control-camera v1.1
|
||||
config = {
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 32,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "1f5ab7703c6fc803fdded85ff040c316":
|
||||
# Wan-AI/Wan2.2-TI2V-5B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 48,
|
||||
"dim": 3072,
|
||||
"ffn_dim": 14336,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 48,
|
||||
"num_heads": 24,
|
||||
"num_layers": 30,
|
||||
"eps": 1e-6,
|
||||
"seperated_timestep": True,
|
||||
"require_clip_embedding": False,
|
||||
"require_vae_embedding": False,
|
||||
"fuse_vae_embedding_in_latents": True,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "5b013604280dd715f8457c6ed6d6a626":
|
||||
# Wan-AI/Wan2.2-I2V-A14B
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5":
|
||||
# Wan2.2-Fun-A14B-Control
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 52,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": True,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1":
|
||||
# Wan2.2-Fun-A14B-Control-Camera
|
||||
config = {
|
||||
"has_image_input": False,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"out_dim": 16,
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"eps": 1e-6,
|
||||
"has_ref_conv": False,
|
||||
"add_control_adapter": True,
|
||||
"in_dim_control_adapter": 24,
|
||||
"require_clip_embedding": False,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict, config
|
||||
|
||||
@@ -3,7 +3,6 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
from .utils import hash_state_dict_keys
|
||||
from .wan_video_dit import rearrange, precompute_freqs_cis_3d, DiTBlock, Head, CrossAttention, modulate, sinusoidal_embedding_1d
|
||||
|
||||
|
||||
@@ -593,33 +592,3 @@ class WanS2VModel(torch.nn.Module):
|
||||
# make compatible with wan video
|
||||
x = torch.cat([origin_ref_latents, x], dim=2)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VModelStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VModelStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
config = {}
|
||||
if hash_state_dict_keys(state_dict) == "966cffdcc52f9c46c391768b27637614":
|
||||
config = {
|
||||
"dim": 5120,
|
||||
"in_dim": 16,
|
||||
"ffn_dim": 13824,
|
||||
"out_dim": 16,
|
||||
"text_dim": 4096,
|
||||
"freq_dim": 256,
|
||||
"eps": 1e-06,
|
||||
"patch_size": (1, 2, 2),
|
||||
"num_heads": 40,
|
||||
"num_layers": 40,
|
||||
"cond_dim": 16,
|
||||
"audio_dim": 1024,
|
||||
"num_audio_token": 4,
|
||||
}
|
||||
return state_dict, config
|
||||
|
||||
@@ -874,29 +874,5 @@ class WanImageEncoder(torch.nn.Module):
|
||||
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
|
||||
|
||||
# forward
|
||||
dtype = next(iter(self.model.visual.parameters())).dtype
|
||||
videos = videos.to(dtype)
|
||||
out = self.model.visual(videos, use_31_block=True)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanImageEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanImageEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
if name.startswith("textual."):
|
||||
continue
|
||||
name = "model." + name
|
||||
state_dict_[name] = param
|
||||
return state_dict_
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock, SelfAttention, rope_apply, flash_attention, modulate, MLP
|
||||
from .utils import hash_state_dict_keys
|
||||
import einops
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -168,114 +167,3 @@ class MotWanModel(torch.nn.Module):
|
||||
block = self.blocks[self.mot_layers_mapping[block_id]]
|
||||
x, x_mot = block(wan_block, x, context, t_mod, freqs, x_mot, context_mot, t_mod_mot, freqs_mot)
|
||||
return x, x_mot
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return MotWanModelDictConverter()
|
||||
|
||||
|
||||
class MotWanModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
|
||||
rename_dict = {
|
||||
"blocks.0.attn1.norm_k.weight": "blocks.0.self_attn.norm_k.weight",
|
||||
"blocks.0.attn1.norm_q.weight": "blocks.0.self_attn.norm_q.weight",
|
||||
"blocks.0.attn1.to_k.bias": "blocks.0.self_attn.k.bias",
|
||||
"blocks.0.attn1.to_k.weight": "blocks.0.self_attn.k.weight",
|
||||
"blocks.0.attn1.to_out.0.bias": "blocks.0.self_attn.o.bias",
|
||||
"blocks.0.attn1.to_out.0.weight": "blocks.0.self_attn.o.weight",
|
||||
"blocks.0.attn1.to_q.bias": "blocks.0.self_attn.q.bias",
|
||||
"blocks.0.attn1.to_q.weight": "blocks.0.self_attn.q.weight",
|
||||
"blocks.0.attn1.to_v.bias": "blocks.0.self_attn.v.bias",
|
||||
"blocks.0.attn1.to_v.weight": "blocks.0.self_attn.v.weight",
|
||||
"blocks.0.attn2.norm_k.weight": "blocks.0.cross_attn.norm_k.weight",
|
||||
"blocks.0.attn2.norm_q.weight": "blocks.0.cross_attn.norm_q.weight",
|
||||
"blocks.0.attn2.to_k.bias": "blocks.0.cross_attn.k.bias",
|
||||
"blocks.0.attn2.to_k.weight": "blocks.0.cross_attn.k.weight",
|
||||
"blocks.0.attn2.to_out.0.bias": "blocks.0.cross_attn.o.bias",
|
||||
"blocks.0.attn2.to_out.0.weight": "blocks.0.cross_attn.o.weight",
|
||||
"blocks.0.attn2.to_q.bias": "blocks.0.cross_attn.q.bias",
|
||||
"blocks.0.attn2.to_q.weight": "blocks.0.cross_attn.q.weight",
|
||||
"blocks.0.attn2.to_v.bias": "blocks.0.cross_attn.v.bias",
|
||||
"blocks.0.attn2.to_v.weight": "blocks.0.cross_attn.v.weight",
|
||||
"blocks.0.attn2.add_k_proj.bias":"blocks.0.cross_attn.k_img.bias",
|
||||
"blocks.0.attn2.add_k_proj.weight":"blocks.0.cross_attn.k_img.weight",
|
||||
"blocks.0.attn2.add_v_proj.bias":"blocks.0.cross_attn.v_img.bias",
|
||||
"blocks.0.attn2.add_v_proj.weight":"blocks.0.cross_attn.v_img.weight",
|
||||
"blocks.0.attn2.norm_added_k.weight":"blocks.0.cross_attn.norm_k_img.weight",
|
||||
"blocks.0.ffn.net.0.proj.bias": "blocks.0.ffn.0.bias",
|
||||
"blocks.0.ffn.net.0.proj.weight": "blocks.0.ffn.0.weight",
|
||||
"blocks.0.ffn.net.2.bias": "blocks.0.ffn.2.bias",
|
||||
"blocks.0.ffn.net.2.weight": "blocks.0.ffn.2.weight",
|
||||
"blocks.0.norm2.bias": "blocks.0.norm3.bias",
|
||||
"blocks.0.norm2.weight": "blocks.0.norm3.weight",
|
||||
"blocks.0.scale_shift_table": "blocks.0.modulation",
|
||||
"condition_embedder.text_embedder.linear_1.bias": "text_embedding.0.bias",
|
||||
"condition_embedder.text_embedder.linear_1.weight": "text_embedding.0.weight",
|
||||
"condition_embedder.text_embedder.linear_2.bias": "text_embedding.2.bias",
|
||||
"condition_embedder.text_embedder.linear_2.weight": "text_embedding.2.weight",
|
||||
"condition_embedder.time_embedder.linear_1.bias": "time_embedding.0.bias",
|
||||
"condition_embedder.time_embedder.linear_1.weight": "time_embedding.0.weight",
|
||||
"condition_embedder.time_embedder.linear_2.bias": "time_embedding.2.bias",
|
||||
"condition_embedder.time_embedder.linear_2.weight": "time_embedding.2.weight",
|
||||
"condition_embedder.time_proj.bias": "time_projection.1.bias",
|
||||
"condition_embedder.time_proj.weight": "time_projection.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.bias":"img_emb.proj.1.bias",
|
||||
"condition_embedder.image_embedder.ff.net.0.proj.weight":"img_emb.proj.1.weight",
|
||||
"condition_embedder.image_embedder.ff.net.2.bias":"img_emb.proj.3.bias",
|
||||
"condition_embedder.image_embedder.ff.net.2.weight":"img_emb.proj.3.weight",
|
||||
"condition_embedder.image_embedder.norm1.bias":"img_emb.proj.0.bias",
|
||||
"condition_embedder.image_embedder.norm1.weight":"img_emb.proj.0.weight",
|
||||
"condition_embedder.image_embedder.norm2.bias":"img_emb.proj.4.bias",
|
||||
"condition_embedder.image_embedder.norm2.weight":"img_emb.proj.4.weight",
|
||||
"patch_embedding.bias": "patch_embedding.bias",
|
||||
"patch_embedding.weight": "patch_embedding.weight",
|
||||
"scale_shift_table": "head.modulation",
|
||||
"proj_out.bias": "head.head.bias",
|
||||
"proj_out.weight": "head.head.weight",
|
||||
}
|
||||
state_dict = {name: param for name, param in state_dict.items() if '_mot_ref' in name}
|
||||
if hash_state_dict_keys(state_dict) == '19debbdb7f4d5ba93b4ddb1cbe5788c7':
|
||||
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
||||
else:
|
||||
mot_layers = (0, 4, 8, 12, 16, 20, 24, 28, 32, 36)
|
||||
mot_layers_mapping = {i:n for n, i in enumerate(mot_layers)}
|
||||
|
||||
state_dict_ = {}
|
||||
|
||||
for name, param in state_dict.items():
|
||||
name = name.replace("_mot_ref", "")
|
||||
if name in rename_dict:
|
||||
state_dict_[rename_dict[name]] = param
|
||||
else:
|
||||
if name.split(".")[1].isdigit():
|
||||
block_id = int(name.split(".")[1])
|
||||
name = name.replace(str(block_id), str(mot_layers_mapping[block_id]))
|
||||
name_ = ".".join(name.split(".")[:1] + ["0"] + name.split(".")[2:])
|
||||
if name_ in rename_dict:
|
||||
name_ = rename_dict[name_]
|
||||
name_ = ".".join(name_.split(".")[:1] + [name.split(".")[1]] + name_.split(".")[2:])
|
||||
state_dict_[name_] = param
|
||||
|
||||
if hash_state_dict_keys(state_dict_) == '6507c8213a3c476df5958b01dcf302d0': # vap 14B
|
||||
config = {
|
||||
"mot_layers":(0, 4, 8, 12, 16, 20, 24, 28, 32, 36),
|
||||
"has_image_input": True,
|
||||
"patch_size": [1, 2, 2],
|
||||
"in_dim": 36,
|
||||
"dim": 5120,
|
||||
"ffn_dim": 13824,
|
||||
"freq_dim": 256,
|
||||
"text_dim": 4096,
|
||||
"num_heads": 40,
|
||||
"eps": 1e-6
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
|
||||
|
||||
@@ -25,20 +25,3 @@ class WanMotionControllerModel(torch.nn.Module):
|
||||
state_dict = self.linear[-1].state_dict()
|
||||
state_dict = {i: state_dict[i] * 0 for i in state_dict}
|
||||
self.linear[-1].load_state_dict(state_dict)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanMotionControllerModelDictConverter()
|
||||
|
||||
|
||||
|
||||
class WanMotionControllerModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
import ftfy
|
||||
import html
|
||||
import string
|
||||
import regex as re
|
||||
|
||||
def fp16_clamp(x):
|
||||
if x.dtype == torch.float16 and torch.isinf(x).any():
|
||||
@@ -252,18 +255,76 @@ class WanTextEncoder(torch.nn.Module):
|
||||
x = self.norm(x)
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanTextEncoderStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
return state_dict
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def canonicalize(text, keep_punctuation_exact_string=None):
|
||||
text = text.replace('_', ' ')
|
||||
if keep_punctuation_exact_string:
|
||||
text = keep_punctuation_exact_string.join(
|
||||
part.translate(str.maketrans('', '', string.punctuation))
|
||||
for part in text.split(keep_punctuation_exact_string))
|
||||
else:
|
||||
text = text.translate(str.maketrans('', '', string.punctuation))
|
||||
text = text.lower()
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
return text.strip()
|
||||
|
||||
|
||||
class HuggingfaceTokenizer:
|
||||
|
||||
def __init__(self, name, seq_len=None, clean=None, **kwargs):
|
||||
assert clean in (None, 'whitespace', 'lower', 'canonicalize')
|
||||
self.name = name
|
||||
self.seq_len = seq_len
|
||||
self.clean = clean
|
||||
|
||||
# init tokenizer
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
|
||||
self.vocab_size = self.tokenizer.vocab_size
|
||||
|
||||
def __call__(self, sequence, **kwargs):
|
||||
return_mask = kwargs.pop('return_mask', False)
|
||||
|
||||
# arguments
|
||||
_kwargs = {'return_tensors': 'pt'}
|
||||
if self.seq_len is not None:
|
||||
_kwargs.update({
|
||||
'padding': 'max_length',
|
||||
'truncation': True,
|
||||
'max_length': self.seq_len
|
||||
})
|
||||
_kwargs.update(**kwargs)
|
||||
|
||||
# tokenization
|
||||
if isinstance(sequence, str):
|
||||
sequence = [sequence]
|
||||
if self.clean:
|
||||
sequence = [self._clean(u) for u in sequence]
|
||||
ids = self.tokenizer(sequence, **_kwargs)
|
||||
|
||||
# output
|
||||
if return_mask:
|
||||
return ids.input_ids, ids.attention_mask
|
||||
else:
|
||||
return ids.input_ids
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return state_dict
|
||||
def _clean(self, text):
|
||||
if self.clean == 'whitespace':
|
||||
text = whitespace_clean(basic_clean(text))
|
||||
elif self.clean == 'lower':
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
elif self.clean == 'canonicalize':
|
||||
text = canonicalize(basic_clean(text))
|
||||
return text
|
||||
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
from .wan_video_dit import DiTBlock
|
||||
from .utils import hash_state_dict_keys
|
||||
|
||||
|
||||
class VaceWanAttentionBlock(DiTBlock):
|
||||
def __init__(self, has_image_input, dim, num_heads, ffn_dim, eps=1e-6, block_id=0):
|
||||
@@ -85,29 +85,3 @@ class VaceWanModel(torch.nn.Module):
|
||||
c = block(c, x, context, t_mod, freqs)
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return VaceWanModelDictConverter()
|
||||
|
||||
|
||||
class VaceWanModelDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("vace")}
|
||||
if hash_state_dict_keys(state_dict_) == '3b2726384e4f64837bdf216eea3f310d': # vace 14B
|
||||
config = {
|
||||
"vace_layers": (0, 5, 10, 15, 20, 25, 30, 35),
|
||||
"vace_in_dim": 96,
|
||||
"patch_size": (1, 2, 2),
|
||||
"has_image_input": False,
|
||||
"dim": 5120,
|
||||
"num_heads": 40,
|
||||
"ffn_dim": 13824,
|
||||
"eps": 1e-06,
|
||||
}
|
||||
else:
|
||||
config = {}
|
||||
return state_dict_, config
|
||||
|
||||
@@ -189,16 +189,3 @@ class WanS2VAudioEncoder(torch.nn.Module):
|
||||
audio_embed_bucket = audio_embed_bucket.unsqueeze(0).permute(0, 2, 3, 1).to(device, dtype)
|
||||
audio_embeds = [audio_embed_bucket[..., i * batch_frames:(i + 1) * batch_frames] for i in range(min_batch_num)]
|
||||
return audio_embeds
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return WanS2VAudioEncoderStateDictConverter()
|
||||
|
||||
|
||||
class WanS2VAudioEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
state_dict = {'model.' + k: v for k, v in state_dict.items()}
|
||||
return state_dict
|
||||
|
||||
621
diffsynth/models/z_image_dit.py
Normal file
621
diffsynth/models/z_image_dit.py
Normal file
@@ -0,0 +1,621 @@
|
||||
import math
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from torch.nn import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
if mid_size is None:
|
||||
mid_size = out_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(
|
||||
frequency_embedding_size,
|
||||
mid_size,
|
||||
bias=True,
|
||||
),
|
||||
nn.SiLU(),
|
||||
nn.Linear(
|
||||
mid_size,
|
||||
out_size,
|
||||
bias=True,
|
||||
),
|
||||
)
|
||||
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
t_emb = self.mlp(t_freq.to(torch.bfloat16))
|
||||
return t_emb
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
class Attention(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
|
||||
super().__init__()
|
||||
dim_inner = head_dim * num_heads
|
||||
kv_dim = kv_dim if kv_dim is not None else q_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
self.to_q = torch.nn.Linear(q_dim, dim_inner, bias=bias_q)
|
||||
self.to_k = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_v = torch.nn.Linear(kv_dim, dim_inner, bias=bias_kv)
|
||||
self.to_out = torch.nn.ModuleList([torch.nn.Linear(dim_inner, q_dim, bias=bias_out)])
|
||||
|
||||
self.norm_q = RMSNorm(head_dim, eps=1e-5)
|
||||
self.norm_k = RMSNorm(head_dim, eps=1e-5)
|
||||
|
||||
def forward(self, hidden_states, freqs_cis):
|
||||
query = self.to_q(hidden_states)
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (self.num_heads, -1))
|
||||
key = key.unflatten(-1, (self.num_heads, -1))
|
||||
value = value.unflatten(-1, (self.num_heads, -1))
|
||||
|
||||
# Apply Norms
|
||||
if self.norm_q is not None:
|
||||
query = self.norm_q(query)
|
||||
if self.norm_k is not None:
|
||||
key = self.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
q_pattern="b s n d", k_pattern="b s n d", v_pattern="b s n d", out_pattern="b s n d",
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
output = self.to_out[0](hidden_states)
|
||||
if len(self.to_out) > 1: # dropout
|
||||
output = self.to_out[1](output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
q_dim=dim,
|
||||
num_heads=n_heads,
|
||||
head_dim=dim // n_heads,
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa,
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x) * scale_mlp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x),
|
||||
freqs_cis=freqs_cis,
|
||||
)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
self.ffn_norm1(x),
|
||||
)
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaLN_modulation(c)
|
||||
x = self.norm_final(x) * scale.unsqueeze(1)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 256.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (64, 128, 128),
|
||||
):
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
class ZImageDiT(nn.Module):
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["ZImageTransformerBlock"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
in_channels=16,
|
||||
dim=3840,
|
||||
n_layers=30,
|
||||
n_refiner_layers=2,
|
||||
n_heads=30,
|
||||
n_kv_heads=30,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
cap_feat_dim=2560,
|
||||
rope_theta=256.0,
|
||||
t_scale=1000.0,
|
||||
axes_dims=[32, 48, 48],
|
||||
axes_lens=[1024, 512, 512],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.all_patch_size = all_patch_size
|
||||
self.all_f_patch_size = all_f_patch_size
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
self.rope_theta = rope_theta
|
||||
self.t_scale = t_scale
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
assert len(all_patch_size) == len(all_f_patch_size)
|
||||
|
||||
all_x_embedder = {}
|
||||
all_final_layer = {}
|
||||
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
|
||||
all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
|
||||
|
||||
self.all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
self.all_final_layer = nn.ModuleDict(all_final_layer)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=False,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
RMSNorm(cap_feat_dim, eps=norm_eps),
|
||||
nn.Linear(cap_feat_dim, dim, bias=True),
|
||||
)
|
||||
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
head_dim = dim // n_heads
|
||||
assert head_dim == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
|
||||
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
||||
|
||||
def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
bsz = len(x)
|
||||
assert len(size) == bsz
|
||||
for i in range(bsz):
|
||||
F, H, W = size[i]
|
||||
ori_len = (F // pF) * (H // pH) * (W // pW)
|
||||
# "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
|
||||
x[i] = (
|
||||
x[i][:ori_len]
|
||||
.view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
|
||||
.permute(6, 0, 3, 1, 4, 2, 5)
|
||||
.reshape(self.out_channels, F, H, W)
|
||||
)
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
def patchify_and_embed(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
all_cap_feats: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
all_cap_pos_ids = []
|
||||
all_cap_pad_mask = []
|
||||
all_cap_feats_out = []
|
||||
|
||||
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
||||
### Process Caption
|
||||
cap_ori_len = len(cap_feat)
|
||||
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
||||
# padded position ids
|
||||
cap_padded_pos_ids = self.create_coordinate_grid(
|
||||
size=(cap_ori_len + cap_padding_len, 1, 1),
|
||||
start=(1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
all_cap_pos_ids.append(cap_padded_pos_ids)
|
||||
# pad mask
|
||||
all_cap_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
cap_padded_feat = torch.cat(
|
||||
[cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
|
||||
dim=0,
|
||||
)
|
||||
all_cap_feats_out.append(cap_padded_feat)
|
||||
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_ori_len + cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_cap_feats_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_image_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
assert patch_size in self.all_patch_size
|
||||
assert f_patch_size in self.all_f_patch_size
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
adaln_input = t
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
|
||||
# x embed & refine
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.noise_refiner:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=x,
|
||||
attn_mask=x_attn_mask,
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=cap_feats,
|
||||
attn_mask=cap_attn_mask,
|
||||
freqs_cis=cap_freqs_cis,
|
||||
)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
for layer in self.layers:
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
x=unified,
|
||||
attn_mask=unified_attn_mask,
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
|
||||
unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
||||
unified = list(unified.unbind(dim=0))
|
||||
x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
|
||||
|
||||
return x, {}
|
||||
41
diffsynth/models/z_image_text_encoder.py
Normal file
41
diffsynth/models/z_image_text_encoder.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from transformers import Qwen3Model, Qwen3Config
|
||||
import torch
|
||||
|
||||
|
||||
class ZImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
config = Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
self.model = Qwen3Model(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self.model(*args, **kwargs)
|
||||
Reference in New Issue
Block a user