mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 23:58:12 +00:00
support HunyuanDiT
This commit is contained in:
7
configs/hunyuan_dit/tokenizer/special_tokens_map.json
Normal file
7
configs/hunyuan_dit/tokenizer/special_tokens_map.json
Normal file
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"cls_token": "[CLS]",
|
||||
"mask_token": "[MASK]",
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"unk_token": "[UNK]"
|
||||
}
|
||||
16
configs/hunyuan_dit/tokenizer/tokenizer_config.json
Normal file
16
configs/hunyuan_dit/tokenizer/tokenizer_config.json
Normal file
@@ -0,0 +1,16 @@
|
||||
{
|
||||
"cls_token": "[CLS]",
|
||||
"do_basic_tokenize": true,
|
||||
"do_lower_case": true,
|
||||
"mask_token": "[MASK]",
|
||||
"name_or_path": "hfl/chinese-roberta-wwm-ext",
|
||||
"never_split": null,
|
||||
"pad_token": "[PAD]",
|
||||
"sep_token": "[SEP]",
|
||||
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
|
||||
"strip_accents": null,
|
||||
"tokenize_chinese_chars": true,
|
||||
"tokenizer_class": "BertTokenizer",
|
||||
"unk_token": "[UNK]",
|
||||
"model_max_length": 77
|
||||
}
|
||||
47020
configs/hunyuan_dit/tokenizer/vocab.txt
Normal file
47020
configs/hunyuan_dit/tokenizer/vocab.txt
Normal file
File diff suppressed because it is too large
Load Diff
21128
configs/hunyuan_dit/tokenizer/vocab_org.txt
Normal file
21128
configs/hunyuan_dit/tokenizer/vocab_org.txt
Normal file
File diff suppressed because it is too large
Load Diff
28
configs/hunyuan_dit/tokenizer_t5/config.json
Normal file
28
configs/hunyuan_dit/tokenizer_t5/config.json
Normal file
@@ -0,0 +1,28 @@
|
||||
{
|
||||
"_name_or_path": "/home/patrick/t5/mt5-xl",
|
||||
"architectures": [
|
||||
"MT5ForConditionalGeneration"
|
||||
],
|
||||
"d_ff": 5120,
|
||||
"d_kv": 64,
|
||||
"d_model": 2048,
|
||||
"decoder_start_token_id": 0,
|
||||
"dropout_rate": 0.1,
|
||||
"eos_token_id": 1,
|
||||
"feed_forward_proj": "gated-gelu",
|
||||
"initializer_factor": 1.0,
|
||||
"is_encoder_decoder": true,
|
||||
"layer_norm_epsilon": 1e-06,
|
||||
"model_type": "mt5",
|
||||
"num_decoder_layers": 24,
|
||||
"num_heads": 32,
|
||||
"num_layers": 24,
|
||||
"output_past": true,
|
||||
"pad_token_id": 0,
|
||||
"relative_attention_num_buckets": 32,
|
||||
"tie_word_embeddings": false,
|
||||
"tokenizer_class": "T5Tokenizer",
|
||||
"transformers_version": "4.10.0.dev0",
|
||||
"use_cache": true,
|
||||
"vocab_size": 250112
|
||||
}
|
||||
1
configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
Normal file
1
configs/hunyuan_dit/tokenizer_t5/special_tokens_map.json
Normal file
@@ -0,0 +1 @@
|
||||
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>"}
|
||||
BIN
configs/hunyuan_dit/tokenizer_t5/spiece.model
Normal file
BIN
configs/hunyuan_dit/tokenizer_t5/spiece.model
Normal file
Binary file not shown.
1
configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
Normal file
1
configs/hunyuan_dit/tokenizer_t5/tokenizer_config.json
Normal file
@@ -0,0 +1 @@
|
||||
{"eos_token": "</s>", "unk_token": "<unk>", "pad_token": "<pad>", "extra_ids": 0, "additional_special_tokens": null, "special_tokens_map_file": "", "tokenizer_file": null, "name_or_path": "google/mt5-small", "model_max_length": 256, "legacy": true}
|
||||
@@ -24,6 +24,9 @@ from .svd_vae_encoder import SVDVAEEncoder
|
||||
|
||||
from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
|
||||
|
||||
class ModelManager:
|
||||
def __init__(self, torch_dtype=torch.float16, device="cuda"):
|
||||
@@ -83,6 +86,22 @@ class ModelManager:
|
||||
param_name = "vision_model.encoder.layers.47.self_attn.v_proj.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit_clip_text_encoder(self, state_dict):
|
||||
param_name = "bert.encoder.layer.23.attention.output.dense.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit_t5_text_encoder(self, state_dict):
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_hunyuan_dit(self, state_dict):
|
||||
param_name = "final_layer.adaLN_modulation.1.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_diffusers_vae(self, state_dict):
|
||||
param_name = "quant_conv.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path=""):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
@@ -223,6 +242,45 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_clip_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_clip_text_encoder"
|
||||
model = HunyuanDiTCLIPTextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit_t5_text_encoder(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit_t5_text_encoder"
|
||||
model = HunyuanDiTT5TextEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_hunyuan_dit(self, state_dict, file_path=""):
|
||||
component = "hunyuan_dit"
|
||||
model = HunyuanDiT()
|
||||
model.load_state_dict(model.state_dict_converter().from_civitai(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_diffusers_vae(self, state_dict, file_path=""):
|
||||
# TODO: detect SD and SDXL
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -276,6 +334,14 @@ class ModelManager:
|
||||
self.load_ipadapter_xl(state_dict, file_path=file_path)
|
||||
elif self.is_ipadapter_xl_image_encoder(state_dict):
|
||||
self.load_ipadapter_xl_image_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_clip_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_clip_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit_t5_text_encoder(state_dict):
|
||||
self.load_hunyuan_dit_t5_text_encoder(state_dict, file_path=file_path)
|
||||
elif self.is_hunyuan_dit(state_dict):
|
||||
self.load_hunyuan_dit(state_dict, file_path=file_path)
|
||||
elif self.is_diffusers_vae(state_dict):
|
||||
self.load_diffusers_vae(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -34,7 +34,7 @@ class Attention(torch.nn.Module):
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
return hidden_states
|
||||
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
|
||||
def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
@@ -48,6 +48,9 @@ class Attention(torch.nn.Module):
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if qkv_preprocessor is not None:
|
||||
q, k, v = qkv_preprocessor(q, k, v)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
if ipadapter_kwargs is not None:
|
||||
hidden_states = self.interact_with_ipadapter(hidden_states, q, **ipadapter_kwargs)
|
||||
@@ -82,5 +85,5 @@ class Attention(torch.nn.Module):
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs)
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None, ipadapter_kwargs=None, qkv_preprocessor=None):
|
||||
return self.torch_forward(hidden_states, encoder_hidden_states=encoder_hidden_states, attn_mask=attn_mask, ipadapter_kwargs=ipadapter_kwargs, qkv_preprocessor=qkv_preprocessor)
|
||||
451
diffsynth/models/hunyuan_dit.py
Normal file
451
diffsynth/models/hunyuan_dit.py
Normal file
@@ -0,0 +1,451 @@
|
||||
from .attention import Attention
|
||||
from .tiler import TileWorker
|
||||
from einops import repeat, rearrange
|
||||
import math
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanDiTRotaryEmbedding(torch.nn.Module):
|
||||
|
||||
def __init__(self, q_norm_shape=88, k_norm_shape=88, rotary_emb_on_k=True):
|
||||
super().__init__()
|
||||
self.q_norm = torch.nn.LayerNorm((q_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.k_norm = torch.nn.LayerNorm((k_norm_shape,), elementwise_affine=True, eps=1e-06)
|
||||
self.rotary_emb_on_k = rotary_emb_on_k
|
||||
self.k_cache, self.v_cache = [], []
|
||||
|
||||
def reshape_for_broadcast(self, freqs_cis, x):
|
||||
ndim = x.ndim
|
||||
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
|
||||
def rotate_half(self, x):
|
||||
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
def apply_rotary_emb(self, xq, xk, freqs_cis):
|
||||
xk_out = None
|
||||
cos, sin = self.reshape_for_broadcast(freqs_cis, xq)
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
xq_out = (xq.float() * cos + self.rotate_half(xq.float()) * sin).type_as(xq)
|
||||
if xk is not None:
|
||||
xk_out = (xk.float() * cos + self.rotate_half(xk.float()) * sin).type_as(xk)
|
||||
return xq_out, xk_out
|
||||
|
||||
def forward(self, q, k, v, freqs_cis_img, to_cache=False):
|
||||
# norm
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# RoPE
|
||||
if self.rotary_emb_on_k:
|
||||
q, k = self.apply_rotary_emb(q, k, freqs_cis_img)
|
||||
else:
|
||||
q, _ = self.apply_rotary_emb(q, None, freqs_cis_img)
|
||||
|
||||
if to_cache:
|
||||
self.k_cache.append(k)
|
||||
self.v_cache.append(v)
|
||||
elif len(self.k_cache) > 0 and len(self.v_cache) > 0:
|
||||
k = torch.concat([k] + self.k_cache, dim=2)
|
||||
v = torch.concat([v] + self.v_cache, dim=2)
|
||||
self.k_cache, self.v_cache = [], []
|
||||
return q, k, v
|
||||
|
||||
|
||||
class FP32_Layernorm(torch.nn.LayerNorm):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).to(origin_dtype)
|
||||
|
||||
|
||||
class FP32_SiLU(torch.nn.SiLU):
|
||||
def forward(self, inputs):
|
||||
origin_dtype = inputs.dtype
|
||||
return torch.nn.functional.silu(inputs.float(), inplace=False).to(origin_dtype)
|
||||
|
||||
|
||||
class HunyuanDiTFinalLayer(torch.nn.Module):
|
||||
def __init__(self, final_hidden_size=1408, condition_dim=1408, patch_size=2, out_channels=8):
|
||||
super().__init__()
|
||||
self.norm_final = torch.nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(condition_dim, 2 * final_hidden_size, bias=True)
|
||||
)
|
||||
|
||||
def modulate(self, x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def forward(self, hidden_states, condition_emb):
|
||||
shift, scale = self.adaLN_modulation(condition_emb).chunk(2, dim=1)
|
||||
hidden_states = self.modulate(self.norm_final(hidden_states), shift, scale)
|
||||
hidden_states = self.linear(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanDiTBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim=1408,
|
||||
condition_dim=1408,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.3637,
|
||||
text_dim=1024,
|
||||
skip_connection=False
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota1 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads)
|
||||
self.attn1 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm2 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.rota2 = HunyuanDiTRotaryEmbedding(hidden_dim//num_heads, hidden_dim//num_heads, rotary_emb_on_k=False)
|
||||
self.attn2 = Attention(hidden_dim, num_heads, hidden_dim//num_heads, kv_dim=text_dim, bias_q=True, bias_kv=True, bias_out=True)
|
||||
self.norm3 = FP32_Layernorm((hidden_dim,), eps=1e-6, elementwise_affine=True)
|
||||
self.modulation = torch.nn.Sequential(FP32_SiLU(), torch.nn.Linear(condition_dim, hidden_dim, bias=True))
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_dim, int(hidden_dim*mlp_ratio), bias=True),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(int(hidden_dim*mlp_ratio), hidden_dim, bias=True)
|
||||
)
|
||||
if skip_connection:
|
||||
self.skip_norm = FP32_Layernorm((hidden_dim * 2,), eps=1e-6, elementwise_affine=True)
|
||||
self.skip_linear = torch.nn.Linear(hidden_dim * 2, hidden_dim, bias=True)
|
||||
else:
|
||||
self.skip_norm, self.skip_linear = None, None
|
||||
|
||||
def forward(self, hidden_states, condition_emb, text_emb, freq_cis_img, residual=None, to_cache=False):
|
||||
# Long Skip Connection
|
||||
if self.skip_norm is not None and self.skip_linear is not None:
|
||||
hidden_states = torch.cat([hidden_states, residual], dim=-1)
|
||||
hidden_states = self.skip_norm(hidden_states)
|
||||
hidden_states = self.skip_linear(hidden_states)
|
||||
|
||||
# Self-Attention
|
||||
shift_msa = self.modulation(condition_emb).unsqueeze(dim=1)
|
||||
attn_input = self.norm1(hidden_states) + shift_msa
|
||||
hidden_states = hidden_states + self.attn1(attn_input, qkv_preprocessor=lambda q, k, v: self.rota1(q, k, v, freq_cis_img, to_cache=to_cache))
|
||||
|
||||
# Cross-Attention
|
||||
attn_input = self.norm3(hidden_states)
|
||||
hidden_states = hidden_states + self.attn2(attn_input, text_emb, qkv_preprocessor=lambda q, k, v: self.rota2(q, k, v, freq_cis_img))
|
||||
|
||||
# FFN Layer
|
||||
mlp_input = self.norm2(hidden_states)
|
||||
hidden_states = hidden_states + self.mlp(mlp_input)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionPool(torch.nn.Module):
|
||||
def __init__(self, spacial_dim, embed_dim, num_heads, output_dim = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = torch.nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = torch.nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = torch.nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(1, 0, 2) # NLC -> LNC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC
|
||||
x, _ = torch.nn.functional.multi_head_attention_forward(
|
||||
query=x[:1], key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
return x.squeeze(0)
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size=(2, 2),
|
||||
in_chans=4,
|
||||
embed_dim=1408,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
return x
|
||||
|
||||
|
||||
def timestep_embedding(t, dim, max_period=10000, repeat_only=False):
|
||||
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half
|
||||
).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
||||
)
|
||||
else:
|
||||
embedding = repeat(t, "b -> b d", d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
class TimestepEmbedder(torch.nn.Module):
|
||||
def __init__(self, hidden_size=1408, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(frequency_embedding_size, hidden_size, bias=True),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size, bias=True),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
class HunyuanDiT(torch.nn.Module):
|
||||
def __init__(self, num_layers_down=21, num_layers_up=19, in_channels=4, out_channels=8, hidden_dim=1408, text_dim=1024, t5_dim=2048, text_length=77, t5_length=256):
|
||||
super().__init__()
|
||||
|
||||
# Embedders
|
||||
self.text_emb_padding = torch.nn.Parameter(torch.randn(text_length + t5_length, text_dim, dtype=torch.float32))
|
||||
self.t5_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(t5_dim, t5_dim * 4, bias=True),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(t5_dim * 4, text_dim, bias=True),
|
||||
)
|
||||
self.t5_pooler = AttentionPool(t5_length, t5_dim, num_heads=8, output_dim=1024)
|
||||
self.style_embedder = torch.nn.Parameter(torch.randn(hidden_dim))
|
||||
self.patch_embedder = PatchEmbed(in_chans=in_channels)
|
||||
self.timestep_embedder = TimestepEmbedder()
|
||||
self.extra_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(256 * 6 + 1024 + hidden_dim, hidden_dim * 4),
|
||||
FP32_SiLU(),
|
||||
torch.nn.Linear(hidden_dim * 4, hidden_dim),
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.num_layers_down = num_layers_down
|
||||
self.num_layers_up = num_layers_up
|
||||
self.blocks = torch.nn.ModuleList(
|
||||
[HunyuanDiTBlock(skip_connection=False) for _ in range(num_layers_down)] + \
|
||||
[HunyuanDiTBlock(skip_connection=True) for _ in range(num_layers_up)]
|
||||
)
|
||||
|
||||
# Output layers
|
||||
self.final_layer = HunyuanDiTFinalLayer()
|
||||
self.out_channels = out_channels
|
||||
|
||||
def prepare_text_emb(self, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5):
|
||||
text_emb_mask = text_emb_mask.bool()
|
||||
text_emb_mask_t5 = text_emb_mask_t5.bool()
|
||||
text_emb_t5 = self.t5_embedder(text_emb_t5)
|
||||
text_emb = torch.cat([text_emb, text_emb_t5], dim=1)
|
||||
text_emb_mask = torch.cat([text_emb_mask, text_emb_mask_t5], dim=-1)
|
||||
text_emb = torch.where(text_emb_mask.unsqueeze(2), text_emb, self.text_emb_padding.to(text_emb))
|
||||
return text_emb
|
||||
|
||||
def prepare_extra_emb(self, text_emb_t5, timestep, size_emb, dtype, batch_size):
|
||||
# Text embedding
|
||||
pooled_text_emb_t5 = self.t5_pooler(text_emb_t5)
|
||||
|
||||
# Timestep embedding
|
||||
timestep_emb = self.timestep_embedder(timestep)
|
||||
|
||||
# Size embedding
|
||||
size_emb = timestep_embedding(size_emb.view(-1), 256).to(dtype)
|
||||
size_emb = size_emb.view(-1, 6 * 256)
|
||||
|
||||
# Style embedding
|
||||
style_emb = repeat(self.style_embedder, "D -> B D", B=batch_size)
|
||||
|
||||
# Concatenate all extra vectors
|
||||
extra_emb = torch.cat([pooled_text_emb_t5, size_emb, style_emb], dim=1)
|
||||
condition_emb = timestep_emb + self.extra_embedder(extra_emb)
|
||||
|
||||
return condition_emb
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
return rearrange(x, "B (H W) (P Q C) -> B C (H P) (W Q)", H=h, W=w, P=2, Q=2)
|
||||
|
||||
def build_mask(self, data, is_bound):
|
||||
_, _, H, W = data.shape
|
||||
h = repeat(torch.arange(H), "H -> H W", H=H, W=W)
|
||||
w = repeat(torch.arange(W), "W -> H W", H=H, W=W)
|
||||
border_width = (H + W) // 4
|
||||
pad = torch.ones_like(h) * border_width
|
||||
mask = torch.stack([
|
||||
pad if is_bound[0] else h + 1,
|
||||
pad if is_bound[1] else H - h,
|
||||
pad if is_bound[2] else w + 1,
|
||||
pad if is_bound[3] else W - w
|
||||
]).min(dim=0).values
|
||||
mask = mask.clip(1, border_width)
|
||||
mask = (mask / border_width).to(dtype=data.dtype, device=data.device)
|
||||
mask = rearrange(mask, "H W -> 1 H W")
|
||||
return mask
|
||||
|
||||
def tiled_block_forward(self, block, hidden_states, condition_emb, text_emb, freq_cis_img, residual, torch_dtype, data_device, computation_device, tile_size, tile_stride):
|
||||
B, C, H, W = hidden_states.shape
|
||||
|
||||
weight = torch.zeros((1, 1, H, W), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, C, H, W), dtype=torch_dtype, device=data_device)
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for h in range(0, H, tile_stride):
|
||||
for w in range(0, W, tile_stride):
|
||||
if (h-tile_stride >= 0 and h-tile_stride+tile_size >= H) or (w-tile_stride >= 0 and w-tile_stride+tile_size >= W):
|
||||
continue
|
||||
h_, w_ = h + tile_size, w + tile_size
|
||||
if h_ > H: h, h_ = H - tile_size, H
|
||||
if w_ > W: w, w_ = W - tile_size, W
|
||||
tasks.append((h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
for hl, hr, wl, wr in tasks:
|
||||
hidden_states_batch = hidden_states[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B C H W -> B (H W) C")
|
||||
if residual is not None:
|
||||
residual_batch = residual[:, :, hl:hr, wl:wr].to(computation_device)
|
||||
residual_batch = rearrange(residual_batch, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residual_batch = None
|
||||
|
||||
# Forward
|
||||
hidden_states_batch = block(hidden_states_batch, condition_emb, text_emb, freq_cis_img, residual_batch).to(data_device)
|
||||
hidden_states_batch = rearrange(hidden_states_batch, "B (H W) C -> B C H W", H=hr-hl)
|
||||
|
||||
mask = self.build_mask(hidden_states_batch, is_bound=(hl==0, hr>=H, wl==0, wr>=W))
|
||||
values[:, :, hl:hr, wl:wr] += hidden_states_batch * mask
|
||||
weight[:, :, hl:hr, wl:wr] += mask
|
||||
values /= weight
|
||||
return values
|
||||
|
||||
def forward(
|
||||
self, hidden_states, text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5, timestep, size_emb, freq_cis_img,
|
||||
tiled=False, tile_size=64, tile_stride=32,
|
||||
to_cache=False,
|
||||
use_gradient_checkpointing=False,
|
||||
):
|
||||
# Embeddings
|
||||
text_emb = self.prepare_text_emb(text_emb, text_emb_t5, text_emb_mask, text_emb_mask_t5)
|
||||
condition_emb = self.prepare_extra_emb(text_emb_t5, timestep, size_emb, hidden_states.dtype, hidden_states.shape[0])
|
||||
|
||||
# Input
|
||||
height, width = hidden_states.shape[-2], hidden_states.shape[-1]
|
||||
hidden_states = self.patch_embedder(hidden_states)
|
||||
|
||||
# Blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
if tiled:
|
||||
hidden_states = rearrange(hidden_states, "B (H W) C -> B C H W", H=height//2)
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
hidden_states = self.tiled_block_forward(
|
||||
block, hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
torch_dtype=hidden_states.dtype, data_device=hidden_states.device, computation_device=hidden_states.device,
|
||||
tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "B C H W -> B (H W) C")
|
||||
else:
|
||||
residuals = []
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
residual = residuals.pop() if block_id >= self.num_layers_down else None
|
||||
if self.training and use_gradient_checkpointing:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, condition_emb, text_emb, freq_cis_img, residual,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(hidden_states, condition_emb, text_emb, freq_cis_img, residual, to_cache=to_cache)
|
||||
if block_id < self.num_layers_down - 2:
|
||||
residuals.append(hidden_states)
|
||||
|
||||
# Output
|
||||
hidden_states = self.final_layer(hidden_states, condition_emb)
|
||||
hidden_states = self.unpatchify(hidden_states, height//2, width//2)
|
||||
hidden_states, _ = hidden_states.chunk(2, dim=1)
|
||||
return hidden_states
|
||||
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
name_ = name
|
||||
name_ = name_.replace(".default_modulation.", ".modulation.")
|
||||
name_ = name_.replace(".mlp.fc1.", ".mlp.0.")
|
||||
name_ = name_.replace(".mlp.fc2.", ".mlp.2.")
|
||||
name_ = name_.replace(".attn1.q_norm.", ".rota1.q_norm.")
|
||||
name_ = name_.replace(".attn2.q_norm.", ".rota2.q_norm.")
|
||||
name_ = name_.replace(".attn1.k_norm.", ".rota1.k_norm.")
|
||||
name_ = name_.replace(".attn2.k_norm.", ".rota2.k_norm.")
|
||||
name_ = name_.replace(".q_proj.", ".to_q.")
|
||||
name_ = name_.replace(".out_proj.", ".to_out.")
|
||||
name_ = name_.replace("text_embedding_padding", "text_emb_padding")
|
||||
name_ = name_.replace("mlp_t5.0.", "t5_embedder.0.")
|
||||
name_ = name_.replace("mlp_t5.2.", "t5_embedder.2.")
|
||||
name_ = name_.replace("pooler.", "t5_pooler.")
|
||||
name_ = name_.replace("x_embedder.", "patch_embedder.")
|
||||
name_ = name_.replace("t_embedder.", "timestep_embedder.")
|
||||
name_ = name_.replace("t5_pooler.to_q.", "t5_pooler.q_proj.")
|
||||
name_ = name_.replace("style_embedder.weight", "style_embedder")
|
||||
if ".kv_proj." in name_:
|
||||
param_k = param[:param.shape[0]//2]
|
||||
param_v = param[param.shape[0]//2:]
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".kv_proj.", ".to_v.")] = param_v
|
||||
elif ".Wqkv." in name_:
|
||||
param_q = param[:param.shape[0]//3]
|
||||
param_k = param[param.shape[0]//3:param.shape[0]//3*2]
|
||||
param_v = param[param.shape[0]//3*2:]
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_q.")] = param_q
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_k.")] = param_k
|
||||
state_dict_[name_.replace(".Wqkv.", ".to_v.")] = param_v
|
||||
elif "style_embedder" in name_:
|
||||
state_dict_[name_] = param.squeeze()
|
||||
else:
|
||||
state_dict_[name_] = param
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
161
diffsynth/models/hunyuan_dit_text_encoder.py
Normal file
161
diffsynth/models/hunyuan_dit_text_encoder.py
Normal file
@@ -0,0 +1,161 @@
|
||||
from transformers import BertModel, BertConfig, T5EncoderModel, T5Config
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoder(BertModel):
|
||||
def __init__(self):
|
||||
config = BertConfig(
|
||||
_name_or_path = "",
|
||||
architectures = ["BertModel"],
|
||||
attention_probs_dropout_prob = 0.1,
|
||||
bos_token_id = 0,
|
||||
classifier_dropout = None,
|
||||
directionality = "bidi",
|
||||
eos_token_id = 2,
|
||||
hidden_act = "gelu",
|
||||
hidden_dropout_prob = 0.1,
|
||||
hidden_size = 1024,
|
||||
initializer_range = 0.02,
|
||||
intermediate_size = 4096,
|
||||
layer_norm_eps = 1e-12,
|
||||
max_position_embeddings = 512,
|
||||
model_type = "bert",
|
||||
num_attention_heads = 16,
|
||||
num_hidden_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
pooler_fc_size = 768,
|
||||
pooler_num_attention_heads = 12,
|
||||
pooler_num_fc_layers = 3,
|
||||
pooler_size_per_head = 128,
|
||||
pooler_type = "first_token_transform",
|
||||
position_embedding_type = "absolute",
|
||||
torch_dtype = "float32",
|
||||
transformers_version = "4.37.2",
|
||||
type_vocab_size = 2,
|
||||
use_cache = True,
|
||||
vocab_size = 47020
|
||||
)
|
||||
super().__init__(config, add_pooling_layer=False)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
input_shape = input_ids.size()
|
||||
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device
|
||||
|
||||
past_key_values_length = 0
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
||||
embedding_output = self.embeddings(
|
||||
input_ids=input_ids,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
past_key_values_length=0,
|
||||
)
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
past_key_values=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
all_hidden_states = encoder_outputs.hidden_states
|
||||
prompt_emb = all_hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = all_hidden_states[-1].mean(), all_hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTCLIPTextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoder(T5EncoderModel):
|
||||
def __init__(self):
|
||||
config = T5Config(
|
||||
_name_or_path = "../HunyuanDiT/t2i/mt5",
|
||||
architectures = ["MT5ForConditionalGeneration"],
|
||||
classifier_dropout = 0.0,
|
||||
d_ff = 5120,
|
||||
d_kv = 64,
|
||||
d_model = 2048,
|
||||
decoder_start_token_id = 0,
|
||||
dense_act_fn = "gelu_new",
|
||||
dropout_rate = 0.1,
|
||||
eos_token_id = 1,
|
||||
feed_forward_proj = "gated-gelu",
|
||||
initializer_factor = 1.0,
|
||||
is_encoder_decoder = True,
|
||||
is_gated_act = True,
|
||||
layer_norm_epsilon = 1e-06,
|
||||
model_type = "t5",
|
||||
num_decoder_layers = 24,
|
||||
num_heads = 32,
|
||||
num_layers = 24,
|
||||
output_past = True,
|
||||
pad_token_id = 0,
|
||||
relative_attention_max_distance = 128,
|
||||
relative_attention_num_buckets = 32,
|
||||
tie_word_embeddings = False,
|
||||
tokenizer_class = "T5Tokenizer",
|
||||
transformers_version = "4.37.2",
|
||||
use_cache = True,
|
||||
vocab_size = 250112
|
||||
)
|
||||
super().__init__(config)
|
||||
self.eval()
|
||||
|
||||
def forward(self, input_ids, attention_mask, clip_skip=1):
|
||||
outputs = super().forward(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
prompt_emb = outputs.hidden_states[-clip_skip]
|
||||
if clip_skip > 1:
|
||||
mean, std = outputs.hidden_states[-1].mean(), outputs.hidden_states[-1].std()
|
||||
prompt_emb = (prompt_emb - prompt_emb.mean()) / prompt_emb.std() * std + mean
|
||||
return prompt_emb
|
||||
|
||||
def state_dict_converter(self):
|
||||
return HunyuanDiTT5TextEncoderStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTCLIPTextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name[5:]: param for name, param in state_dict.items() if name.startswith("bert.")}
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
|
||||
|
||||
class HunyuanDiTT5TextEncoderStateDictConverter():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {name: param for name, param in state_dict.items() if name.startswith("encoder.")}
|
||||
state_dict_["shared.weight"] = state_dict["shared.weight"]
|
||||
return state_dict_
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
return self.from_diffusers(state_dict)
|
||||
@@ -2,4 +2,5 @@ from .stable_diffusion import SDImagePipeline
|
||||
from .stable_diffusion_xl import SDXLImagePipeline
|
||||
from .stable_diffusion_video import SDVideoPipeline, SDVideoPipelineRunner
|
||||
from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .hunyuan_dit import HunyuanDiTImagePipeline
|
||||
|
||||
298
diffsynth/pipelines/hunyuan_dit.py
Normal file
298
diffsynth/pipelines/hunyuan_dit.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from ..models.hunyuan_dit import HunyuanDiT
|
||||
from ..models.hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from ..models.sdxl_vae_encoder import SDXLVAEEncoder
|
||||
from ..models.sdxl_vae_decoder import SDXLVAEDecoder
|
||||
from ..models import ModelManager
|
||||
from ..prompts import HunyuanDiTPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
||||
class ImageSizeManager:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def _to_tuple(self, x):
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def get_fill_resize_and_crop(self, src, tgt):
|
||||
th, tw = self._to_tuple(tgt)
|
||||
h, w = self._to_tuple(src)
|
||||
|
||||
tr = th / tw # base 分辨率
|
||||
r = h / w # 目标分辨率
|
||||
|
||||
# resize
|
||||
if r > tr:
|
||||
resize_height = th
|
||||
resize_width = int(round(th / h * w))
|
||||
else:
|
||||
resize_width = tw
|
||||
resize_height = int(round(tw / w * h)) # 根据base分辨率,将目标分辨率resize下来
|
||||
|
||||
crop_top = int(round((th - resize_height) / 2.0))
|
||||
crop_left = int(round((tw - resize_width) / 2.0))
|
||||
|
||||
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
||||
|
||||
|
||||
def get_meshgrid(self, start, *args):
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = self._to_tuple(start)
|
||||
start = (0, 0)
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = self._to_tuple(start)
|
||||
stop = self._to_tuple(args[0])
|
||||
num = (stop[0] - start[0], stop[1] - start[1])
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = self._to_tuple(start) # 左上角 eg: 12,0
|
||||
stop = self._to_tuple(args[0]) # 右下角 eg: 20,32
|
||||
num = self._to_tuple(args[1]) # 目标大小 eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32) # 12-20 中间差值32份 0-32 中间差值124份
|
||||
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0) # [2, W, H]
|
||||
return grid
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed(self, embed_dim, start, *args, use_real=True):
|
||||
grid = self.get_meshgrid(start, *args) # [2, H, w]
|
||||
grid = grid.reshape([2, 1, *grid.shape[1:]]) # 返回一个采样矩阵 分辨率与目标分辨率一致
|
||||
pos_embed = self.get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_rotary_pos_embed_from_grid(self, embed_dim, grid, use_real=False):
|
||||
assert embed_dim % 4 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
emb_w = self.get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
|
||||
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(self, dim: int, pos, theta: float = 10000.0, use_real=False):
|
||||
if isinstance(pos, int):
|
||||
pos = np.arange(pos)
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
|
||||
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def calc_rope(self, height, width):
|
||||
patch_size = 2
|
||||
head_size = 88
|
||||
th = height // 8 // patch_size
|
||||
tw = width // 8 // patch_size
|
||||
base_size = 512 // 8 // patch_size
|
||||
start, stop = self.get_fill_resize_and_crop((th, tw), base_size)
|
||||
sub_args = [start, stop, (th, tw)]
|
||||
rope = self.get_2d_rotary_pos_embed(head_size, *sub_args)
|
||||
return rope
|
||||
|
||||
|
||||
|
||||
class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(prediction_type="v_prediction", beta_start=0.00085, beta_end=0.03)
|
||||
self.prompter = HunyuanDiTPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
self.image_size_manager = ImageSizeManager()
|
||||
# models
|
||||
self.text_encoder: HunyuanDiTCLIPTextEncoder = None
|
||||
self.text_encoder_t5: HunyuanDiTT5TextEncoder = None
|
||||
self.dit: HunyuanDiT = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.hunyuan_dit_clip_text_encoder
|
||||
self.text_encoder_t5 = model_manager.hunyuan_dit_t5_text_encoder
|
||||
self.dit = model_manager.hunyuan_dit
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager):
|
||||
pipe = HunyuanDiTImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
def prepare_extra_input(self, height=1024, width=1024, tiled=False, tile_size=64, tile_stride=32, batch_size=1):
|
||||
if tiled:
|
||||
height, width = tile_size * 16, tile_size * 16
|
||||
image_meta_size = torch.as_tensor([width, height, width, height, 0, 0]).to(device=self.device)
|
||||
freqs_cis_img = self.image_size_manager.calc_rope(height, width)
|
||||
image_meta_size = torch.stack([image_meta_size] * batch_size)
|
||||
return {
|
||||
"size_emb": image_meta_size,
|
||||
"freq_cis_img": (freqs_cis_img[0].to(dtype=self.torch_dtype, device=self.device), freqs_cis_img[1].to(dtype=self.torch_dtype, device=self.device)),
|
||||
"tiled": tiled,
|
||||
"tile_size": tile_size,
|
||||
"tile_stride": tile_stride
|
||||
}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
input_image=None,
|
||||
reference_images=[],
|
||||
reference_strengths=[0.4],
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise.clone()
|
||||
|
||||
# Prepare reference latents
|
||||
reference_latents = []
|
||||
for reference_image in reference_images:
|
||||
reference_image = self.preprocess_image(reference_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
reference_latents.append(self.vae_encoder(reference_image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype))
|
||||
|
||||
# Encode prompts
|
||||
prompt_emb_posi, attention_mask_posi, prompt_emb_t5_posi, attention_mask_t5_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=True,
|
||||
device=self.device
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega, attention_mask_nega, prompt_emb_t5_nega, attention_mask_t5_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
self.text_encoder_t5,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip,
|
||||
clip_skip_2=clip_skip_2,
|
||||
positive=False,
|
||||
device=self.device
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
extra_input = self.prepare_extra_input(height, width, tiled, tile_size)
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.tensor([timestep]).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# In-context reference
|
||||
for reference_latents_, reference_strength in zip(reference_latents, reference_strengths):
|
||||
if progress_id < num_inference_steps * reference_strength:
|
||||
noisy_reference_latents = self.scheduler.add_noise(reference_latents_, noise, self.scheduler.timesteps[progress_id])
|
||||
self.dit(
|
||||
noisy_reference_latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
to_cache=True
|
||||
)
|
||||
# Positive side
|
||||
noise_pred_posi = self.dit(
|
||||
latents,
|
||||
prompt_emb_posi, prompt_emb_t5_posi, attention_mask_posi, attention_mask_t5_posi,
|
||||
timestep,
|
||||
**extra_input,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
# Negative side
|
||||
noise_pred_nega = self.dit(
|
||||
latents,
|
||||
prompt_emb_nega, prompt_emb_t5_nega, attention_mask_nega, attention_mask_t5_nega,
|
||||
timestep,
|
||||
**extra_input
|
||||
)
|
||||
# Classifier-free guidance
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -1,176 +1,3 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import SDTextEncoder, SDXLTextEncoder, SDXLTextEncoder2, ModelManager
|
||||
import torch, os
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class Translator:
|
||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
def __call__(self, prompt):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
return prompt
|
||||
|
||||
|
||||
class Prompter:
|
||||
def __init__(self):
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.keyword_dict = {}
|
||||
self.translator: Translator = None
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
def load_translator(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.translator = Translator(tokenizer_path=model_folder, model=model)
|
||||
|
||||
def load_from_model_manager(self, model_manager: ModelManager):
|
||||
self.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
if "translator" in model_manager.model:
|
||||
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: SDXLTextEncoder,
|
||||
text_encoder_2: SDXLTextEncoder2,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
|
||||
56
diffsynth/prompts/hunyuan_dit_prompter.py
Normal file
56
diffsynth/prompts/hunyuan_dit_prompter.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from .utils import Prompter
|
||||
from transformers import BertModel, T5EncoderModel, BertTokenizer, AutoTokenizer
|
||||
import warnings
|
||||
|
||||
|
||||
class HunyuanDiTPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/hunyuan_dit/tokenizer",
|
||||
tokenizer_t5_path="configs/hunyuan_dit/tokenizer_t5"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
self.tokenizer_t5 = AutoTokenizer.from_pretrained(tokenizer_t5_path)
|
||||
|
||||
|
||||
def encode_prompt_using_signle_model(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
prompt_embeds = text_encoder(
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
clip_skip=clip_skip
|
||||
)
|
||||
return prompt_embeds, attention_mask
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: BertModel,
|
||||
text_encoder_t5: T5EncoderModel,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=1,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# CLIP
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_signle_model(prompt, text_encoder, self.tokenizer, self.tokenizer.model_max_length, clip_skip, device)
|
||||
|
||||
# T5
|
||||
prompt_emb_t5, attention_mask_t5 = self.encode_prompt_using_signle_model(prompt, text_encoder_t5, self.tokenizer_t5, self.tokenizer_t5.model_max_length, clip_skip_2, device)
|
||||
|
||||
return prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5
|
||||
17
diffsynth/prompts/sd_prompter.py
Normal file
17
diffsynth/prompts/sd_prompter.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDTextEncoder
|
||||
|
||||
|
||||
class SDPrompter(Prompter):
|
||||
def __init__(self, tokenizer_path="configs/stable_diffusion/tokenizer"):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
def encode_prompt(self, text_encoder: SDTextEncoder, prompt, clip_skip=1, device="cuda", positive=True):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
|
||||
return prompt_emb
|
||||
43
diffsynth/prompts/sdxl_prompter.py
Normal file
43
diffsynth/prompts/sdxl_prompter.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from .utils import Prompter, tokenize_long_prompt
|
||||
from transformers import CLIPTokenizer
|
||||
from ..models import SDXLTextEncoder, SDXLTextEncoder2
|
||||
import torch
|
||||
|
||||
|
||||
class SDXLPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path="configs/stable_diffusion/tokenizer",
|
||||
tokenizer_2_path="configs/stable_diffusion_xl/tokenizer_2"
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
|
||||
self.tokenizer_2 = CLIPTokenizer.from_pretrained(tokenizer_2_path)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: SDXLTextEncoder,
|
||||
text_encoder_2: SDXLTextEncoder2,
|
||||
prompt,
|
||||
clip_skip=1,
|
||||
clip_skip_2=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# 1
|
||||
input_ids = tokenize_long_prompt(self.tokenizer, prompt).to(device)
|
||||
prompt_emb_1 = text_encoder(input_ids, clip_skip=clip_skip)
|
||||
|
||||
# 2
|
||||
input_ids_2 = tokenize_long_prompt(self.tokenizer_2, prompt).to(device)
|
||||
add_text_embeds, prompt_emb_2 = text_encoder_2(input_ids_2, clip_skip=clip_skip_2)
|
||||
|
||||
# Merge
|
||||
prompt_emb = torch.concatenate([prompt_emb_1, prompt_emb_2], dim=-1)
|
||||
|
||||
# For very long prompt, we only use the first 77 tokens to compute `add_text_embeds`.
|
||||
add_text_embeds = add_text_embeds[0:1]
|
||||
prompt_emb = prompt_emb.reshape((1, prompt_emb.shape[0]*prompt_emb.shape[1], -1))
|
||||
return add_text_embeds, prompt_emb
|
||||
123
diffsynth/prompts/utils.py
Normal file
123
diffsynth/prompts/utils.py
Normal file
@@ -0,0 +1,123 @@
|
||||
from transformers import CLIPTokenizer, AutoTokenizer
|
||||
from ..models import ModelManager
|
||||
import os
|
||||
|
||||
|
||||
def tokenize_long_prompt(tokenizer, prompt):
|
||||
# Get model_max_length from self.tokenizer
|
||||
length = tokenizer.model_max_length
|
||||
|
||||
# To avoid the warning. set self.tokenizer.model_max_length to +oo.
|
||||
tokenizer.model_max_length = 99999999
|
||||
|
||||
# Tokenize it!
|
||||
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
|
||||
|
||||
# Determine the real length.
|
||||
max_length = (input_ids.shape[1] + length - 1) // length * length
|
||||
|
||||
# Restore tokenizer.model_max_length
|
||||
tokenizer.model_max_length = length
|
||||
|
||||
# Tokenize it again with fixed length.
|
||||
input_ids = tokenizer(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True
|
||||
).input_ids
|
||||
|
||||
# Reshape input_ids to fit the text encoder.
|
||||
num_sentence = input_ids.shape[1] // length
|
||||
input_ids = input_ids.reshape((num_sentence, length))
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class BeautifulPrompt:
|
||||
def __init__(self, tokenizer_path="configs/beautiful_prompt/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
self.template = 'Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:'
|
||||
|
||||
def __call__(self, raw_prompt):
|
||||
model_input = self.template.format(raw_prompt=raw_prompt)
|
||||
input_ids = self.tokenizer.encode(model_input, return_tensors='pt').to(self.model.device)
|
||||
outputs = self.model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=384,
|
||||
do_sample=True,
|
||||
temperature=0.9,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
repetition_penalty=1.1,
|
||||
num_return_sequences=1
|
||||
)
|
||||
prompt = raw_prompt + ", " + self.tokenizer.batch_decode(
|
||||
outputs[:, input_ids.size(1):],
|
||||
skip_special_tokens=True
|
||||
)[0].strip()
|
||||
return prompt
|
||||
|
||||
|
||||
class Translator:
|
||||
def __init__(self, tokenizer_path="configs/translator/tokenizer", model=None):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
self.model = model
|
||||
|
||||
def __call__(self, prompt):
|
||||
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.model.device)
|
||||
output_ids = self.model.generate(input_ids)
|
||||
prompt = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
||||
return prompt
|
||||
|
||||
|
||||
class Prompter:
|
||||
def __init__(self):
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.keyword_dict = {}
|
||||
self.translator: Translator = None
|
||||
self.beautiful_prompt: BeautifulPrompt = None
|
||||
|
||||
def load_textual_inversion(self, textual_inversion_dict):
|
||||
self.keyword_dict = {}
|
||||
additional_tokens = []
|
||||
for keyword in textual_inversion_dict:
|
||||
tokens, _ = textual_inversion_dict[keyword]
|
||||
additional_tokens += tokens
|
||||
self.keyword_dict[keyword] = " " + " ".join(tokens) + " "
|
||||
self.tokenizer.add_tokens(additional_tokens)
|
||||
|
||||
def load_beautiful_prompt(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.beautiful_prompt = BeautifulPrompt(tokenizer_path=model_folder, model=model)
|
||||
if model_folder.endswith("v2"):
|
||||
self.beautiful_prompt.template = """Converts a simple image description into a prompt. \
|
||||
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
|
||||
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
|
||||
but make sure there is a correlation between the input and output.\n\
|
||||
### Input: {raw_prompt}\n### Output:"""
|
||||
|
||||
def load_translator(self, model, model_path):
|
||||
model_folder = os.path.dirname(model_path)
|
||||
self.translator = Translator(tokenizer_path=model_folder, model=model)
|
||||
|
||||
def load_from_model_manager(self, model_manager: ModelManager):
|
||||
self.load_textual_inversion(model_manager.textual_inversion_dict)
|
||||
if "translator" in model_manager.model:
|
||||
self.load_translator(model_manager.model["translator"], model_manager.model_path["translator"])
|
||||
if "beautiful_prompt" in model_manager.model:
|
||||
self.load_beautiful_prompt(model_manager.model["beautiful_prompt"], model_manager.model_path["beautiful_prompt"])
|
||||
|
||||
def process_prompt(self, prompt, positive=True):
|
||||
for keyword in self.keyword_dict:
|
||||
if keyword in prompt:
|
||||
prompt = prompt.replace(keyword, self.keyword_dict[keyword])
|
||||
if positive and self.translator is not None:
|
||||
prompt = self.translator(prompt)
|
||||
print(f"Your prompt is translated: \"{prompt}\"")
|
||||
if positive and self.beautiful_prompt is not None:
|
||||
prompt = self.beautiful_prompt(prompt)
|
||||
print(f"Your prompt is refined by BeautifulPrompt: \"{prompt}\"")
|
||||
return prompt
|
||||
@@ -3,7 +3,7 @@ import torch, math
|
||||
|
||||
class EnhancedDDIMScheduler():
|
||||
|
||||
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
|
||||
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
if beta_schedule == "scaled_linear":
|
||||
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
|
||||
@@ -13,6 +13,7 @@ class EnhancedDDIMScheduler():
|
||||
raise NotImplementedError(f"{beta_schedule} is not implemented")
|
||||
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
|
||||
self.set_timesteps(10)
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
|
||||
@@ -28,9 +29,16 @@ class EnhancedDDIMScheduler():
|
||||
|
||||
|
||||
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
|
||||
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||
prev_sample = sample * weight_x + model_output * weight_e
|
||||
if self.prediction_type == "epsilon":
|
||||
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
|
||||
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
|
||||
prev_sample = sample * weight_x + model_output * weight_e
|
||||
elif self.prediction_type == "v_prediction":
|
||||
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
|
||||
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
|
||||
prev_sample = sample * weight_x + model_output * weight_e
|
||||
else:
|
||||
raise NotImplementedError(f"{self.prediction_type} is not implemented")
|
||||
return prev_sample
|
||||
|
||||
|
||||
@@ -57,4 +65,9 @@ class EnhancedDDIMScheduler():
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
|
||||
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
|
||||
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
|
||||
return target
|
||||
|
||||
245
examples/hunyuan_dit/README.md
Normal file
245
examples/hunyuan_dit/README.md
Normal file
@@ -0,0 +1,245 @@
|
||||
# Hunyuan DiT
|
||||
|
||||
Hunyuan DiT is an image generation model based on DiT. We provide training and inference support for Hunyuan DiT.
|
||||
|
||||
## Inference
|
||||
|
||||
### Text-to-image with highres-fix
|
||||
|
||||
The original resolution of Hunyuan DiT is 1024x1024. If you want to use larger resolutions, please use highres-fix.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||
])
|
||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Enjoy!
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感",
|
||||
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_1024.png")
|
||||
|
||||
# Highres fix
|
||||
image = pipe(
|
||||
prompt="少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感",
|
||||
negative_prompt="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,",
|
||||
input_image=image.resize((2048, 2048)),
|
||||
num_inference_steps=50, height=2048, width=2048,
|
||||
cfg_scale=3.0, denoising_strength=0.5, tiled=True,
|
||||
)
|
||||
image.save("image_2048.png")
|
||||
```
|
||||
|
||||
Prompt: 少女手捧鲜花,坐在公园的长椅上,夕阳的余晖洒在少女的脸庞,整个画面充满诗意的美感
|
||||
|
||||
|1024x1024|2048x2048 (highres-fix)|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### In-context reference (experimental)
|
||||
|
||||
This feature is similar to the "reference-only" mode in ControlNets. By extending the self-attention layer, the content in the reference image can be retained in the new image. Any number of reference images are supported, and the influence from each reference image can be controled by independent `reference_strengths` parameters.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
import torch
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||
])
|
||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Generate an image as reference
|
||||
torch.manual_seed(0)
|
||||
reference_image = pipe(
|
||||
prompt="梵高,星空,油画,明亮",
|
||||
negative_prompt="",
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
reference_image.save("image_reference.png")
|
||||
|
||||
# Generate a new image with reference
|
||||
image = pipe(
|
||||
prompt="层峦叠嶂的山脉,郁郁葱葱的森林,皎洁明亮的月光,夜色下的自然美景",
|
||||
negative_prompt="",
|
||||
reference_images=[reference_image], reference_strengths=[0.4],
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_with_reference.png")
|
||||
```
|
||||
|
||||
Prompt: 层峦叠嶂的山脉,郁郁葱葱的森林,皎洁明亮的月光,夜色下的自然美景
|
||||
|
||||
|Reference image|Generated new image|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
## Train
|
||||
|
||||
### Install training dependency
|
||||
|
||||
```
|
||||
pip install peft lightning pandas torchvision
|
||||
```
|
||||
|
||||
### Prepare your dataset
|
||||
|
||||
We provide an example dataset [here](https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/files). You need to manage the training images as follows:
|
||||
|
||||
```
|
||||
data/dog/
|
||||
└── train
|
||||
├── 00.jpg
|
||||
├── 01.jpg
|
||||
├── 02.jpg
|
||||
├── 03.jpg
|
||||
├── 04.jpg
|
||||
└── metadata.csv
|
||||
```
|
||||
|
||||
`metadata.csv`:
|
||||
|
||||
```
|
||||
file_name,text
|
||||
00.jpg,一只小狗
|
||||
01.jpg,一只小狗
|
||||
02.jpg,一只小狗
|
||||
03.jpg,一只小狗
|
||||
04.jpg,一只小狗
|
||||
```
|
||||
|
||||
### Train a LoRA model
|
||||
|
||||
We provide a training script `train_hunyuan_dit_lora.py`. Before you run this training script, please copy it to the root directory of this project.
|
||||
|
||||
If GPU memory >= 24GB, we recommmand to use the following settings.
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES="0" python train_hunyuan_dit_lora.py \
|
||||
--pretrained_path models/HunyuanDiT/t2i \
|
||||
--dataset_path data/dog \
|
||||
--output_path ./models \
|
||||
--max_epochs 1 \
|
||||
--center_crop
|
||||
```
|
||||
|
||||
If 12GB <= GPU memory <= 24GB, we recommand to enable gradient checkpointing.
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES="0" python train_hunyuan_dit_lora.py \
|
||||
--pretrained_path models/HunyuanDiT/t2i \
|
||||
--dataset_path data/dog \
|
||||
--output_path ./models \
|
||||
--max_epochs 1 \
|
||||
--center_crop \
|
||||
--use_gradient_checkpointing
|
||||
```
|
||||
|
||||
Optional arguments:
|
||||
```
|
||||
-h, --help show this help message and exit
|
||||
--pretrained_path PRETRAINED_PATH
|
||||
Path to pretrained model. For example, `./HunyuanDiT/t2i`.
|
||||
--dataset_path DATASET_PATH
|
||||
The path of the Dataset.
|
||||
--output_path OUTPUT_PATH
|
||||
Path to save the model.
|
||||
--steps_per_epoch STEPS_PER_EPOCH
|
||||
Number of steps per epoch.
|
||||
--height HEIGHT Image height.
|
||||
--width WIDTH Image width.
|
||||
--center_crop Whether to center crop the input images to the resolution. If not set, the images will be randomly cropped. The images will be resized to the resolution first before cropping.
|
||||
--random_flip Whether to randomly flip images horizontally
|
||||
--batch_size BATCH_SIZE
|
||||
Batch size (per device) for the training dataloader.
|
||||
--dataloader_num_workers DATALOADER_NUM_WORKERS
|
||||
Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
|
||||
--precision {32,16,16-mixed}
|
||||
Training precision
|
||||
--learning_rate LEARNING_RATE
|
||||
Learning rate.
|
||||
--lora_rank LORA_RANK
|
||||
The dimension of the LoRA update matrices.
|
||||
--lora_alpha LORA_ALPHA
|
||||
The weight of the LoRA update matrices.
|
||||
--use_gradient_checkpointing
|
||||
Whether to use gradient checkpointing.
|
||||
--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
|
||||
The number of batches in gradient accumulation.
|
||||
--training_strategy {auto,deepspeed_stage_1,deepspeed_stage_2,deepspeed_stage_3}
|
||||
Training strategy
|
||||
--max_epochs MAX_EPOCHS
|
||||
Number of epochs.
|
||||
```
|
||||
|
||||
### Inference with your own LoRA model
|
||||
|
||||
After training, you can use your own LoRA model to generate new images. Here are some examples.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torch
|
||||
|
||||
|
||||
def load_lora(dit, lora_rank, lora_alpha, lora_path):
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out"],
|
||||
)
|
||||
dit = inject_adapter_in_model(lora_config, dit)
|
||||
state_dict = torch.load(lora_path, map_location="cpu")
|
||||
dit.load_state_dict(state_dict, strict=False)
|
||||
return dit
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
|
||||
model_manager.load_models([
|
||||
"models/HunyuanDiT/t2i/clip_text_encoder/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/mt5/pytorch_model.bin",
|
||||
"models/HunyuanDiT/t2i/model/pytorch_model_ema.pt",
|
||||
"models/HunyuanDiT/t2i/sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"
|
||||
])
|
||||
pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Generate an image with lora
|
||||
pipe.dit = load_lora(
|
||||
pipe.dit, lora_rank=4, lora_alpha=4.0,
|
||||
lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
|
||||
negative_prompt="",
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_with_lora.png")
|
||||
```
|
||||
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
298
examples/hunyuan_dit/train_hunyuan_dit_lora.py
Normal file
298
examples/hunyuan_dit/train_hunyuan_dit_lora.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from diffsynth import ModelManager, HunyuanDiTImagePipeline
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import lightning as pl
|
||||
import pandas as pd
|
||||
import torch, os, argparse
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||
|
||||
|
||||
|
||||
class TextImageDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
|
||||
self.steps_per_epoch = steps_per_epoch
|
||||
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
|
||||
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
|
||||
self.text = metadata["text"].to_list()
|
||||
self.image_processor = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
|
||||
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
data_id = torch.randint(0, len(self.path), (1,))[0]
|
||||
data_id = (data_id + index) % len(self.path) # For fixed seed.
|
||||
text = self.text[data_id]
|
||||
image = Image.open(self.path[data_id]).convert("RGB")
|
||||
image = self.image_processor(image)
|
||||
return {"text": text, "image": image}
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
|
||||
def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
|
||||
super().__init__()
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
|
||||
model_manager.load_models(pretrained_weights)
|
||||
self.pipe = HunyuanDiTImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Freeze parameters
|
||||
self.pipe.text_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder_t5.requires_grad_(False)
|
||||
self.pipe.dit.requires_grad_(False)
|
||||
self.pipe.vae_decoder.requires_grad_(False)
|
||||
self.pipe.vae_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder.eval()
|
||||
self.pipe.text_encoder_t5.eval()
|
||||
self.pipe.dit.train()
|
||||
self.pipe.vae_decoder.eval()
|
||||
self.pipe.vae_encoder.eval()
|
||||
|
||||
# Add LoRA to DiT
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out"],
|
||||
)
|
||||
self.pipe.dit = inject_adapter_in_model(lora_config, self.pipe.dit)
|
||||
for param in self.pipe.dit.parameters():
|
||||
# Upcast LoRA parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(torch.float32)
|
||||
|
||||
# Set other parameters
|
||||
self.learning_rate = learning_rate
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
text, image = batch["text"], batch["image"]
|
||||
|
||||
# Prepare input parameters
|
||||
self.pipe.device = self.device
|
||||
prompt_emb, attention_mask, prompt_emb_t5, attention_mask_t5 = self.pipe.prompter.encode_prompt(
|
||||
self.pipe.text_encoder, self.pipe.text_encoder_t5, text, positive=True, device=self.device
|
||||
)
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep = torch.randint(0, 1000, (1,), device=self.device)
|
||||
extra_input = self.pipe.prepare_extra_input(image.shape[-2], image.shape[-1], batch_size=latents.shape[0])
|
||||
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
|
||||
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
|
||||
|
||||
# Compute loss
|
||||
noise_pred = self.pipe.dit(
|
||||
noisy_latents,
|
||||
prompt_emb, prompt_emb_t5, attention_mask, attention_mask_t5,
|
||||
timestep,
|
||||
**extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, training_target)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
return loss
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.dit.parameters())
|
||||
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
checkpoint.clear()
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.dit.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
state_dict = self.pipe.dit.state_dict()
|
||||
for name, param in state_dict.items():
|
||||
if name in trainable_param_names:
|
||||
checkpoint[name] = param
|
||||
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model. For example, `./HunyuanDiT/t2i`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The path of the Dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_path",
|
||||
type=str,
|
||||
default="./",
|
||||
help="Path to save the model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--steps_per_epoch",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps per epoch.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Image height.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="Image width.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision",
|
||||
type=str,
|
||||
default="16-mixed",
|
||||
choices=["32", "16", "16-mixed"],
|
||||
help="Training precision",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Learning rate.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help="The dimension of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=float,
|
||||
default=4.0,
|
||||
help="The weight of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Whether to use gradient checkpointing.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--accumulate_grad_batches",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of batches in gradient accumulation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--training_strategy",
|
||||
type=str,
|
||||
default="auto",
|
||||
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
|
||||
help="Training strategy",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_epochs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of epochs.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# args
|
||||
args = parse_args()
|
||||
|
||||
# dataset and data loader
|
||||
dataset = TextImageDataset(
|
||||
args.dataset_path,
|
||||
steps_per_epoch=args.steps_per_epoch,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
center_crop=args.center_crop,
|
||||
random_flip=args.random_flip
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
# model
|
||||
model = LightningModel(
|
||||
pretrained_weights=[
|
||||
os.path.join(args.pretrained_path, "clip_text_encoder/pytorch_model.bin"),
|
||||
os.path.join(args.pretrained_path, "mt5/pytorch_model.bin"),
|
||||
os.path.join(args.pretrained_path, "model/pytorch_model_ema.pt"),
|
||||
os.path.join(args.pretrained_path, "sdxl-vae-fp16-fix/diffusion_pytorch_model.bin"),
|
||||
],
|
||||
torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
|
||||
learning_rate=args.learning_rate,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_alpha=args.lora_alpha,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing
|
||||
)
|
||||
|
||||
# train
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=args.max_epochs,
|
||||
accelerator="gpu",
|
||||
devices="auto",
|
||||
precision=args.precision,
|
||||
strategy=args.training_strategy,
|
||||
default_root_dir=args.output_path,
|
||||
accumulate_grad_batches=args.accumulate_grad_batches,
|
||||
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
|
||||
)
|
||||
trainer.fit(model=model, train_dataloaders=train_loader)
|
||||
Reference in New Issue
Block a user