support HunyuanDiT

This commit is contained in:
Artiprocher
2024-06-05 13:00:39 +08:00
parent 83461d400c
commit 78f53a0754
23 changed files with 69988 additions and 185 deletions

View File

@@ -24,6 +24,9 @@ from .svd_vae_encoder import SVDVAEEncoder
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
class ModelManager:
def __init__(self, torch_dtype=torch.float16, device="cuda"):
@@ -83,6 +86,22 @@ class ModelManager:
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
return param_name in state_dict
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
return param_name in state_dict
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_hunyuan_dit(self, state_dict):
param_name = "final_layer.adaLN_modulation.1.weight"
return param_name in state_dict
def is_diffusers_vae(self, state_dict):
param_name = "quant_conv.weight"
return param_name in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
component_dict = {
"image_encoder": SVDImageEncoder,
@@ -223,6 +242,45 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_clip_text_encoder"
model = HunyuanDiTCLIPTextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
component = "hunyuan_dit_t5_text_encoder"
model = HunyuanDiTT5TextEncoder()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_hunyuan_dit(self, state_dict, file_path=""):
component = "hunyuan_dit"
model = HunyuanDiT()
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_diffusers_vae(self, state_dict, file_path=""):
# TODO: detect SD and SDXL
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -276,6 +334,14 @@ class ModelManager:
self.load_ipadapter_xl(state_dict, file_path=file_path)
elif self.is_ipadapter_xl_image_encoder(state_dict):
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
elif self.is_hunyuan_dit(state_dict):
self.load_hunyuan_dit(state_dict, file_path=file_path)
elif self.is_diffusers_vae(state_dict):
self.load_diffusers_vae(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:

View File

@@ -34,7 +34,7 @@ class Attention(torch.nn.Module):
hidden_states = hidden_states + scale * ip_hidden_states
return hidden_states
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
@@ -48,6 +48,9 @@ class Attention(torch.nn.Module):
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
if qkv_preprocessor is not None:
q, k, v = qkv_preprocessor(q, k, v)
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
if ipadapter_kwargs is not None:
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
@@ -82,5 +85,5 @@ class Attention(torch.nn.Module):
return hidden_states
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs)
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)

View File

@@ -0,0 +1,451 @@
from .attention import Attention
from .tiler import TileWorker
from einops import repeat, rearrange
import math
import torch
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
super().__init__()
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
self.rotary_emb_on_k = rotary_emb_on_k
self.k_cache, self.v_cache = [], []
def reshape_for_broadcast(self, freqs_cis, x):
ndim = x.ndim
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(self, x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(self, xq, xk, freqs_cis):
xk_out = None
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
# norm
q = self.q_norm(q)
k = self.k_norm(k)
# RoPE
if self.rotary_emb_on_k:
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
else:
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
if to_cache:
self.k_cache.append(k)
self.v_cache.append(v)
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
k = torch.concat([k] + self.k_cache, dim=2)
v = torch.concat([v] + self.v_cache, dim=2)
self.k_cache, self.v_cache = [], []
return q, k, v
class FP32_Layernorm(torch.nn.LayerNorm):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
class FP32_SiLU(torch.nn.SiLU):
def forward(self, inputs):
origin_dtype = inputs.dtype
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
class HunyuanDiTFinalLayer(torch.nn.Module):
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
super().__init__()
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = torch.nn.Sequential(
FP32_SiLU(),
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
)
def modulate(self, x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
def forward(self, hidden_states, condition_emb):
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
hidden_states = self.linear(hidden_states)
return hidden_states
class HunyuanDiTBlock(torch.nn.Module):
def __init__(
self,
hidden_dim=1408,
condition_dim=1408,
num_heads=16,
mlp_ratio=4.3637,
text_dim=1024,
skip_connection=False
):
super().__init__()
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
self.mlp = torch.nn.Sequential(
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
torch.nn.GELU(approximate="tanh"),
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
)
if skip_connection:
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
else:
self.skip_norm, self.skip_linear = None, None
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
# Long Skip Connection
if self.skip_norm is not None and self.skip_linear is not None:
hidden_states = torch.cat([hidden_states, residual], dim=-1)
hidden_states = self.skip_norm(hidden_states)
hidden_states = self.skip_linear(hidden_states)
# Self-Attention
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
attn_input = self.norm1(hidden_states) + shift_msa
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
# Cross-Attention
attn_input = self.norm3(hidden_states)
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
# FFN Layer
mlp_input = self.norm2(hidden_states)
hidden_states = hidden_states + self.mlp(mlp_input)
return hidden_states
class AttentionPool(torch.nn.Module):
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
super().__init__()
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
x, _ = torch.nn.functional.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class PatchEmbed(torch.nn.Module):
def __init__(
self,
patch_size=(2, 2),
in_chans=4,
embed_dim=1408,
bias=True,
):
super().__init__()
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
def forward(self, x):
x = self.proj(x)
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
return x
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
)
else:
embedding = repeat(t, "b -> b d", d=dim)
return embedding
class TimestepEmbedder(torch.nn.Module):
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
super().__init__()
self.mlp = torch.nn.Sequential(
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
class HunyuanDiT(torch.nn.Module):
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
super().__init__()
# Embedders
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
self.t5_embedder = torch.nn.Sequential(
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
FP32_SiLU(),
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
)
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
self.patch_embedder = PatchEmbed(in_chans=in_channels)
self.timestep_embedder = TimestepEmbedder()
self.extra_embedder = torch.nn.Sequential(
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
FP32_SiLU(),
torch.nn.Linear(hidden_dim * 4, hidden_dim),
)
# Transformer blocks
self.num_layers_down = num_layers_down
self.num_layers_up = num_layers_up
self.blocks = torch.nn.ModuleList(
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
)
# Output layers
self.final_layer = HunyuanDiTFinalLayer()
self.out_channels = out_channels
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
text_emb_mask = text_emb_mask.bool()
text_emb_mask_t5 = text_emb_mask_t5.bool()
text_emb_t5 = self.t5_embedder(text_emb_t5)
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
return text_emb
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
# Text embedding
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
# Timestep embedding
timestep_emb = self.timestep_embedder(timestep)
# Size embedding
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
size_emb = size_emb.view(-1, 6 * 256)
# Style embedding
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
# Concatenate all extra vectors
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
return condition_emb
def unpatchify(self, x, h, w):
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
def build_mask(self, data, is_bound):
_, _, H, W = data.shape
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
border_width = (H + W) // 4
pad = torch.ones_like(h) * border_width
mask = torch.stack([
pad if is_bound[0] else h + 1,
pad if is_bound[1] else H - h,
pad if is_bound[2] else w + 1,
pad if is_bound[3] else W - w
]).min(dim=0).values
mask = mask.clip(1, border_width)
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
mask = rearrange(mask, "H W -> 1 H W")
return mask
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
B, C, H, W = hidden_states.shape
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
# Split tasks
tasks = []
for h in range(0, H, tile_stride):
for w in range(0, W, tile_stride):
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
continue
h_, w_ = h + tile_size, w + tile_size
if h_ > H: h, h_ = H - tile_size, H
if w_ > W: w, w_ = W - tile_size, W
tasks.append((h, h_, w, w_))
# Run
for hl, hr, wl, wr in tasks:
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
if residual is not None:
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
else:
residual_batch = None
# Forward
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
weight[:, :, hl:hr, wl:wr] += mask
values /= weight
return values
def forward(
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
tiled=False, tile_size=64, tile_stride=32,
to_cache=False,
use_gradient_checkpointing=False,
):
# Embeddings
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
# Input
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
hidden_states = self.patch_embedder(hidden_states)
# Blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if tiled:
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
hidden_states = self.tiled_block_forward(
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
tile_size=tile_size, tile_stride=tile_stride
)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
else:
residuals = []
for block_id, block in enumerate(self.blocks):
residual = residuals.pop() if block_id >= self.num_layers_down else None
if self.training and use_gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
use_reentrant=False,
)
else:
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
if block_id < self.num_layers_down - 2:
residuals.append(hidden_states)
# Output
hidden_states = self.final_layer(hidden_states, condition_emb)
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
hidden_states, _ = hidden_states.chunk(2, dim=1)
return hidden_states
def state_dict_converter(self):
return HunyuanDiTStateDictConverter()
class HunyuanDiTStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {}
for name, param in state_dict.items():
name_ = name
name_ = name_.replace(".default_modulation.", ".modulation.")
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
name_ = name_.replace(".q_proj.", ".to_q.")
name_ = name_.replace(".out_proj.", ".to_out.")
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
name_ = name_.replace("pooler.", "t5_pooler.")
name_ = name_.replace("x_embedder.", "patch_embedder.")
name_ = name_.replace("t_embedder.", "timestep_embedder.")
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
name_ = name_.replace("style_embedder.weight", "style_embedder")
if ".kv_proj." in name_:
param_k = param[:param.shape[0]//2]
param_v = param[param.shape[0]//2:]
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
elif ".Wqkv." in name_:
param_q = param[:param.shape[0]//3]
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
param_v = param[param.shape[0]//3*2:]
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
elif "style_embedder" in name_:
state_dict_[name_] = param.squeeze()
else:
state_dict_[name_] = param
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -0,0 +1,161 @@
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
import torch
class HunyuanDiTCLIPTextEncoder(BertModel):
def __init__(self):
config = BertConfig(
_name_or_path = "",
architectures = ["BertModel"],
attention_probs_dropout_prob = 0.1,
bos_token_id = 0,
classifier_dropout = None,
directionality = "bidi",
eos_token_id = 2,
hidden_act = "gelu",
hidden_dropout_prob = 0.1,
hidden_size = 1024,
initializer_range = 0.02,
intermediate_size = 4096,
layer_norm_eps = 1e-12,
max_position_embeddings = 512,
model_type = "bert",
num_attention_heads = 16,
num_hidden_layers = 24,
output_past = True,
pad_token_id = 0,
pooler_fc_size = 768,
pooler_num_attention_heads = 12,
pooler_num_fc_layers = 3,
pooler_size_per_head = 128,
pooler_type = "first_token_transform",
position_embedding_type = "absolute",
torch_dtype = "float32",
transformers_version = "4.37.2",
type_vocab_size = 2,
use_cache = True,
vocab_size = 47020
)
super().__init__(config, add_pooling_layer=False)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
input_shape = input_ids.size()
batch_size, seq_length = input_shape
device = input_ids.device
past_key_values_length = 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=None,
token_type_ids=None,
inputs_embeds=None,
past_key_values_length=0,
)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=True,
return_dict=True,
)
all_hidden_states = encoder_outputs.hidden_states
prompt_emb = all_hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
return HunyuanDiTCLIPTextEncoderStateDictConverter()
class HunyuanDiTT5TextEncoder(T5EncoderModel):
def __init__(self):
config = T5Config(
_name_or_path = "../HunyuanDiT/t2i/mt5",
architectures = ["MT5ForConditionalGeneration"],
classifier_dropout = 0.0,
d_ff = 5120,
d_kv = 64,
d_model = 2048,
decoder_start_token_id = 0,
dense_act_fn = "gelu_new",
dropout_rate = 0.1,
eos_token_id = 1,
feed_forward_proj = "gated-gelu",
initializer_factor = 1.0,
is_encoder_decoder = True,
is_gated_act = True,
layer_norm_epsilon = 1e-06,
model_type = "t5",
num_decoder_layers = 24,
num_heads = 32,
num_layers = 24,
output_past = True,
pad_token_id = 0,
relative_attention_max_distance = 128,
relative_attention_num_buckets = 32,
tie_word_embeddings = False,
tokenizer_class = "T5Tokenizer",
transformers_version = "4.37.2",
use_cache = True,
vocab_size = 250112
)
super().__init__(config)
self.eval()
def forward(self, input_ids, attention_mask, clip_skip=1):
outputs = super().forward(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)
prompt_emb = outputs.hidden_states[-clip_skip]
if clip_skip > 1:
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
return prompt_emb
def state_dict_converter(self):
return HunyuanDiTT5TextEncoderStateDictConverter()
class HunyuanDiTCLIPTextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)
class HunyuanDiTT5TextEncoderStateDictConverter():
def __init__(self):
pass
def from_diffusers(self, state_dict):
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
state_dict_["shared.weight"] = state_dict["shared.weight"]
return state_dict_
def from_civitai(self, state_dict):
return self.from_diffusers(state_dict)

View File

@@ -2,4 +2,5 @@ from .stable_diffusion import SDImagePipeline
from .stable_diffusion_xl import SDXLImagePipeline
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
from .stable_diffusion_xl_video import SDXLVideoPipeline
from .stable_video_diffusion import SVDVideoPipeline
from .stable_video_diffusion import SVDVideoPipeline
from .hunyuan_dit import HunyuanDiTImagePipeline

View File

@@ -0,0 +1,298 @@
from ..models.hunyuan_dit import HunyuanDiT
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
from ..models import ModelManager
from ..prompts import HunyuanDiTPrompter
from ..schedulers import EnhancedDDIMScheduler
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class ImageSizeManager:
def __init__(self):
pass
def _to_tuple(self, x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(self, src, tgt):
th, tw = self._to_tuple(tgt)
h, w = self._to_tuple(src)
tr = th / tw # base 分辨率
r = h / w # 目标分辨率
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h)) # 根据base分辨率将目标分辨率resize下来
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_meshgrid(self, start, *args):
if len(args) == 0:
# start is grid_size
num = self._to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = self._to_tuple(start)
stop = self._to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = self._to_tuple(start) # 左上角 eg: 12,0
stop = self._to_tuple(args[0]) # 右下角 eg: 20,32
num = self._to_tuple(args[1]) # 目标大小 eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
return grid
def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
grid = self.get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致
pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
def calc_rope(self, height, width):
patch_size = 2
head_size = 88
th = height // 8 // patch_size
tw = width // 8 // patch_size
base_size = 512 // 8 // patch_size
start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
class HunyuanDiTImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
self.prompter = HunyuanDiTPrompter()
self.device = device
self.torch_dtype = torch_dtype
self.image_size_manager = ImageSizeManager()
# models
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
self.dit: HunyuanDiT = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
self.dit = model_manager.hunyuan_dit
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager):
pipe = HunyuanDiTImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
if tiled:
height, width = tile_size * 16, tile_size * 16
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
freqs_cis_img = self.image_size_manager.calc_rope(height, width)
image_meta_size = torch.stack([image_meta_size] * batch_size)
return {
"size_emb": image_meta_size,
"freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
"tiled": tiled,
"tile_size": tile_size,
"tile_stride": tile_stride
}
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=1,
clip_skip_2=1,
input_image=None,
reference_images=[],
reference_strengths=[0.4],
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = noise.clone()
# Prepare reference latents
reference_latents = []
for reference_image in reference_images:
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
# Encode prompts
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_t5,
prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=True,
device=self.device
)
if cfg_scale != 1.0:
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
self.text_encoder,
self.text_encoder_t5,
negative_prompt,
clip_skip=clip_skip,
clip_skip_2=clip_skip_2,
positive=False,
device=self.device
)
# Prepare positional id
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
# In-context reference
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
if progress_id < num_inference_steps * reference_strength:
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
self.dit(
noisy_reference_latents,
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
timestep,
**extra_input,
to_cache=True
)
# Positive side
noise_pred_posi = self.dit(
latents,
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
timestep,
**extra_input,
)
if cfg_scale != 1.0:
# Negative side
noise_pred_nega = self.dit(
latents,
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
timestep,
**extra_input
)
# Classifier-free guidance
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -1,176 +1,3 @@
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, ModelManager
import torch, os
def tokenize_long_prompt(tokenizer, prompt):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore tokenizer.model_max_length
tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids
class BeautifulPrompt:
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
def __call__(self, raw_prompt):
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return prompt
class Translator:
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
def __call__(self, prompt):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
output_ids = self.model.generate(input_ids)
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
return prompt
class Prompter:
def __init__(self):
self.tokenizer: CLIPTokenizer = None
self.keyword_dict = {}
self.translator: Translator = None
self.beautiful_prompt: BeautifulPrompt = None
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
def load_translator(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.translator = Translator(tokenizer_path=model_folder, model=model)
def load_from_model_manager(self, model_manager: ModelManager):
self.load_textual_inversion(model_manager.textual_inversion_dict)
if "translator" in model_manager.model:
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
if "beautiful_prompt" in model_manager.model:
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def process_prompt(self, prompt, positive=True):
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
if positive and self.translator is not None:
prompt = self.translator(prompt)
print(f"Your prompt is translated: \"{prompt}\"")
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
return prompt
class SDPrompter(Prompter):
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb
class SDXLPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/stable_diffusion/tokenizer",
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
def encode_prompt(
self,
text_encoder: SDXLTextEncoder,
text_encoder_2: SDXLTextEncoder2,
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
# 2
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
# Merge
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb
from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter

View File

@@ -0,0 +1,56 @@
from .utils import Prompter
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
import warnings
class HunyuanDiTPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/hunyuan_dit/tokenizer",
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
):
super().__init__()
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask.to(device)
prompt_embeds = text_encoder(
text_input_ids.to(device),
attention_mask=attention_mask,
clip_skip=clip_skip
)
return prompt_embeds, attention_mask
def encode_prompt(
self,
text_encoder: BertModel,
text_encoder_t5: T5EncoderModel,
prompt,
clip_skip=1,
clip_skip_2=1,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# CLIP
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
# T5
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5

View File

@@ -0,0 +1,17 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDTextEncoder
class SDPrompter(Prompter):
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
prompt = self.process_prompt(prompt, positive=positive)
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return prompt_emb

View File

@@ -0,0 +1,43 @@
from .utils import Prompter, tokenize_long_prompt
from transformers import CLIPTokenizer
from ..models import SDXLTextEncoder, SDXLTextEncoder2
import torch
class SDXLPrompter(Prompter):
def __init__(
self,
tokenizer_path="configs/stable_diffusion/tokenizer",
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
def encode_prompt(
self,
text_encoder: SDXLTextEncoder,
text_encoder_2: SDXLTextEncoder2,
prompt,
clip_skip=1,
clip_skip_2=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
# 1
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
# 2
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
# Merge
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
add_text_embeds = add_text_embeds[0:1]
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
return add_text_embeds, prompt_emb

123
diffsynth/prompts/utils.py Normal file
View File

@@ -0,0 +1,123 @@
from transformers import CLIPTokenizer, AutoTokenizer
from ..models import ModelManager
import os
def tokenize_long_prompt(tokenizer, prompt):
# Get model_max_length from self.tokenizer
length = tokenizer.model_max_length
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
tokenizer.model_max_length = 99999999
# Tokenize it!
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# Determine the real length.
max_length = (input_ids.shape[1] + length - 1) // length * length
# Restore tokenizer.model_max_length
tokenizer.model_max_length = length
# Tokenize it again with fixed length.
input_ids = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True
).input_ids
# Reshape input_ids to fit the text encoder.
num_sentence = input_ids.shape[1] // length
input_ids = input_ids.reshape((num_sentence, length))
return input_ids
class BeautifulPrompt:
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
def __call__(self, raw_prompt):
model_input = self.template.format(raw_prompt=raw_prompt)
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
outputs = self.model.generate(
input_ids,
max_new_tokens=384,
do_sample=True,
temperature=0.9,
top_k=50,
top_p=0.95,
repetition_penalty=1.1,
num_return_sequences=1
)
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
outputs[:, input_ids.size(1):],
skip_special_tokens=True
)[0].strip()
return prompt
class Translator:
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
self.model = model
def __call__(self, prompt):
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
output_ids = self.model.generate(input_ids)
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
return prompt
class Prompter:
def __init__(self):
self.tokenizer: CLIPTokenizer = None
self.keyword_dict = {}
self.translator: Translator = None
self.beautiful_prompt: BeautifulPrompt = None
def load_textual_inversion(self, textual_inversion_dict):
self.keyword_dict = {}
additional_tokens = []
for keyword in textual_inversion_dict:
tokens, _ = textual_inversion_dict[keyword]
additional_tokens += tokens
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
self.tokenizer.add_tokens(additional_tokens)
def load_beautiful_prompt(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
if model_folder.endswith("v2"):
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
def load_translator(self, model, model_path):
model_folder = os.path.dirname(model_path)
self.translator = Translator(tokenizer_path=model_folder, model=model)
def load_from_model_manager(self, model_manager: ModelManager):
self.load_textual_inversion(model_manager.textual_inversion_dict)
if "translator" in model_manager.model:
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
if "beautiful_prompt" in model_manager.model:
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
def process_prompt(self, prompt, positive=True):
for keyword in self.keyword_dict:
if keyword in prompt:
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
if positive and self.translator is not None:
prompt = self.translator(prompt)
print(f"Your prompt is translated: \"{prompt}\"")
if positive and self.beautiful_prompt is not None:
prompt = self.beautiful_prompt(prompt)
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
return prompt

View File

@@ -3,7 +3,7 @@ import torch, math
class EnhancedDDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
self.num_train_timesteps = num_train_timesteps
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
@@ -13,6 +13,7 @@ class EnhancedDDIMScheduler():
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
self.set_timesteps(10)
self.prediction_type = prediction_type
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
@@ -28,9 +29,16 @@ class EnhancedDDIMScheduler():
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
if self.prediction_type == "epsilon":
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
elif self.prediction_type == "v_prediction":
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
prev_sample = sample * weight_x + model_output * weight_e
else:
raise NotImplementedError(f"{self.prediction_type} is not implemented")
return prev_sample
@@ -57,4 +65,9 @@ class EnhancedDDIMScheduler():
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def training_target(self, sample, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target