mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
@@ -43,9 +43,13 @@ from ..models.cog_dit import CogDiT
|
||||
|
||||
from ..models.omnigen import OmniGenTransformer
|
||||
|
||||
from ..models.hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from ..extensions.RIFE import IFNet
|
||||
from ..extensions.ESRGAN import RRDBNet
|
||||
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
|
||||
model_loader_configs = [
|
||||
@@ -93,6 +97,10 @@ model_loader_configs = [
|
||||
(None, "51aed3d27d482fceb5e0739b03060e8f", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "98cc34ccc5b54ae0e56bdea8688dcd5a", ["sd3_text_encoder_2"], [SD3TextEncoder2], "civitai"),
|
||||
(None, "77ff18050dbc23f50382e45d51a779fe", ["sd3_dit", "sd3_vae_encoder", "sd3_vae_decoder"], [SD3DiT, SD3VAEEncoder, SD3VAEDecoder], "civitai"),
|
||||
(None, "5da81baee73198a7c19e6d2fe8b5148e", ["sd3_text_encoder_1"], [SD3TextEncoder1], "diffusers"),
|
||||
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
|
||||
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
|
||||
]
|
||||
huggingface_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -101,10 +109,11 @@ huggingface_model_loader_configs = [
|
||||
("MarianMTModel", "transformers.models.marian.modeling_marian", "translator", None),
|
||||
("BloomForCausalLM", "transformers.models.bloom.modeling_bloom", "beautiful_prompt", None),
|
||||
("Qwen2ForCausalLM", "transformers.models.qwen2.modeling_qwen2", "qwen_prompt", None),
|
||||
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
# ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "omost_prompt", None),
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel")
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
@@ -627,6 +636,25 @@ preset_models_on_modelscope = {
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/clip_g.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
("AI-ModelScope/stable-diffusion-3.5-large", "text_encoders/t5xxl_fp16.safetensors", "models/stable_diffusion_3/text_encoders"),
|
||||
],
|
||||
"HunyuanVideo":{
|
||||
"file_list": [
|
||||
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00001-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00002-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00003-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model-00004-of-00004.safetensors", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "config.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("DiffSynth-Studio/HunyuanVideo_MLLM_text_encoder", "model.safetensors.index.json", "models/HunyuanVideo/text_encoder_2"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/vae/pytorch_model.pt", "models/HunyuanVideo/vae"),
|
||||
("AI-ModelScope/HunyuanVideo", "hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideo/transformers")
|
||||
],
|
||||
"load_path": [
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
},
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
@@ -682,4 +710,5 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"Annotators:Openpose",
|
||||
"StableDiffusion3.5-large",
|
||||
"StableDiffusion3.5-medium",
|
||||
"HunyuanVideo",
|
||||
]
|
||||
|
||||
885
diffsynth/models/hunyuan_video_dit.py
Normal file
885
diffsynth/models/hunyuan_video_dit.py
Normal file
@@ -0,0 +1,885 @@
|
||||
import torch
|
||||
from .sd3_dit import TimestepEmbeddings, RMSNorm
|
||||
from .utils import init_weights_on_device
|
||||
from einops import rearrange, repeat
|
||||
from tqdm import tqdm
|
||||
from typing import Union, Tuple, List
|
||||
|
||||
|
||||
def HunyuanVideoRope(latents):
|
||||
def _to_tuple(x, dim=2):
|
||||
if isinstance(x, int):
|
||||
return (x,) * dim
|
||||
elif len(x) == dim:
|
||||
return x
|
||||
else:
|
||||
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
||||
|
||||
|
||||
def get_meshgrid_nd(start, *args, dim=2):
|
||||
"""
|
||||
Get n-D meshgrid with start, stop and num.
|
||||
|
||||
Args:
|
||||
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
|
||||
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
|
||||
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
|
||||
n-tuples.
|
||||
*args: See above.
|
||||
dim (int): Dimension of the meshgrid. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
grid (np.ndarray): [dim, ...]
|
||||
"""
|
||||
if len(args) == 0:
|
||||
# start is grid_size
|
||||
num = _to_tuple(start, dim=dim)
|
||||
start = (0,) * dim
|
||||
stop = num
|
||||
elif len(args) == 1:
|
||||
# start is start, args[0] is stop, step is 1
|
||||
start = _to_tuple(start, dim=dim)
|
||||
stop = _to_tuple(args[0], dim=dim)
|
||||
num = [stop[i] - start[i] for i in range(dim)]
|
||||
elif len(args) == 2:
|
||||
# start is start, args[0] is stop, args[1] is num
|
||||
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
|
||||
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
|
||||
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
|
||||
else:
|
||||
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
|
||||
|
||||
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
|
||||
axis_grid = []
|
||||
for i in range(dim):
|
||||
a, b, n = start[i], stop[i], num[i]
|
||||
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
|
||||
axis_grid.append(g)
|
||||
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
|
||||
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def get_1d_rotary_pos_embed(
|
||||
dim: int,
|
||||
pos: Union[torch.FloatTensor, int],
|
||||
theta: float = 10000.0,
|
||||
use_real: bool = False,
|
||||
theta_rescale_factor: float = 1.0,
|
||||
interpolation_factor: float = 1.0,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
|
||||
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
|
||||
|
||||
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
|
||||
and the end index 'end'. The 'theta' parameter scales the frequencies.
|
||||
The returned tensor contains complex values in complex64 data type.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension of the frequency tensor.
|
||||
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
|
||||
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool, optional): If True, return real part and imaginary part separately.
|
||||
Otherwise, return complex numbers.
|
||||
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
|
||||
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
|
||||
"""
|
||||
if isinstance(pos, int):
|
||||
pos = torch.arange(pos).float()
|
||||
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
if theta_rescale_factor != 1.0:
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
|
||||
freqs = 1.0 / (
|
||||
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
||||
) # [D/2]
|
||||
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
|
||||
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
|
||||
if use_real:
|
||||
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
|
||||
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
|
||||
return freqs_cos, freqs_sin
|
||||
else:
|
||||
freqs_cis = torch.polar(
|
||||
torch.ones_like(freqs), freqs
|
||||
) # complex64 # [S, D/2]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
def get_nd_rotary_pos_embed(
|
||||
rope_dim_list,
|
||||
start,
|
||||
*args,
|
||||
theta=10000.0,
|
||||
use_real=False,
|
||||
theta_rescale_factor: Union[float, List[float]] = 1.0,
|
||||
interpolation_factor: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
"""
|
||||
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
|
||||
|
||||
Args:
|
||||
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
|
||||
sum(rope_dim_list) should equal to head_dim of attention layer.
|
||||
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
|
||||
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
|
||||
*args: See above.
|
||||
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
|
||||
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
|
||||
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
|
||||
part and an imaginary part separately.
|
||||
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
|
||||
|
||||
Returns:
|
||||
pos_embed (torch.Tensor): [HW, D/2]
|
||||
"""
|
||||
|
||||
grid = get_meshgrid_nd(
|
||||
start, *args, dim=len(rope_dim_list)
|
||||
) # [3, W, H, D] / [2, W, H]
|
||||
|
||||
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
|
||||
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
|
||||
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
|
||||
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
|
||||
assert len(theta_rescale_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
|
||||
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
|
||||
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
|
||||
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
|
||||
assert len(interpolation_factor) == len(
|
||||
rope_dim_list
|
||||
), "len(interpolation_factor) should equal to len(rope_dim_list)"
|
||||
|
||||
# use 1/ndim of dimensions to encode grid_axis
|
||||
embs = []
|
||||
for i in range(len(rope_dim_list)):
|
||||
emb = get_1d_rotary_pos_embed(
|
||||
rope_dim_list[i],
|
||||
grid[i].reshape(-1),
|
||||
theta,
|
||||
use_real=use_real,
|
||||
theta_rescale_factor=theta_rescale_factor[i],
|
||||
interpolation_factor=interpolation_factor[i],
|
||||
) # 2 x [WHD, rope_dim_list[i]]
|
||||
embs.append(emb)
|
||||
|
||||
if use_real:
|
||||
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
|
||||
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
|
||||
return cos, sin
|
||||
else:
|
||||
emb = torch.cat(embs, dim=1) # (WHD, D/2)
|
||||
return emb
|
||||
|
||||
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
|
||||
[16, 56, 56],
|
||||
[latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
|
||||
theta=256,
|
||||
use_real=True,
|
||||
theta_rescale_factor=1,
|
||||
)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class PatchEmbed(torch.nn.Module):
|
||||
def __init__(self, patch_size=(1, 2, 2), in_channels=16, embed_dim=3072):
|
||||
super().__init__()
|
||||
self.proj = torch.nn.Conv3d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefinerBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, num_heads=24):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.self_attn_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.self_attn_proj = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
|
||||
self.mlp = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * 4),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size * 4, hidden_size)
|
||||
)
|
||||
self.adaLN_modulation = torch.nn.Sequential(
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size * 2, device="cuda", dtype=torch.bfloat16),
|
||||
)
|
||||
|
||||
def forward(self, x, c, attn_mask=None):
|
||||
gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn_qkv(norm_x)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
|
||||
attn = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
attn = rearrange(attn, "B H L D -> B L (H D)")
|
||||
|
||||
x = x + self.self_attn_proj(attn) * gate_msa.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * gate_mlp.unsqueeze(1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SingleTokenRefiner(torch.nn.Module):
|
||||
def __init__(self, in_channels=4096, hidden_size=3072, depth=2):
|
||||
super().__init__()
|
||||
self.input_embedder = torch.nn.Linear(in_channels, hidden_size, bias=True)
|
||||
self.t_embedder = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.c_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(in_channels, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.blocks = torch.nn.ModuleList([IndividualTokenRefinerBlock(hidden_size=hidden_size) for _ in range(depth)])
|
||||
|
||||
def forward(self, x, t, mask=None):
|
||||
timestep_aware_representations = self.t_embedder(t, dtype=torch.float32)
|
||||
|
||||
mask_float = mask.float().unsqueeze(-1)
|
||||
context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
|
||||
context_aware_representations = self.c_embedder(context_aware_representations)
|
||||
c = timestep_aware_representations + context_aware_representations
|
||||
|
||||
x = self.input_embedder(x)
|
||||
|
||||
mask = mask.to(device=x.device, dtype=torch.bool)
|
||||
mask = repeat(mask, "B L -> B 1 D L", D=mask.shape[-1])
|
||||
mask = mask & mask.transpose(2, 3)
|
||||
mask[:, :, :, 0] = True
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, mask)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ModulateDiT(torch.nn.Module):
|
||||
def __init__(self, hidden_size, factor=6):
|
||||
super().__init__()
|
||||
self.act = torch.nn.SiLU()
|
||||
self.linear = torch.nn.Linear(hidden_size, factor * hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(self.act(x))
|
||||
|
||||
|
||||
def modulate(x, shift=None, scale=None):
|
||||
if scale is None and shift is None:
|
||||
return x
|
||||
elif shift is None:
|
||||
return x * (1 + scale.unsqueeze(1))
|
||||
elif scale is None:
|
||||
return x + shift.unsqueeze(1)
|
||||
else:
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
def reshape_for_broadcast(
|
||||
freqs_cis,
|
||||
x: torch.Tensor,
|
||||
head_first=False,
|
||||
):
|
||||
ndim = x.ndim
|
||||
assert 0 <= 1 < ndim
|
||||
|
||||
if isinstance(freqs_cis, tuple):
|
||||
# freqs_cis: (cos, sin) in real space
|
||||
if head_first:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis[0].shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
|
||||
else:
|
||||
# freqs_cis: values in complex space
|
||||
if head_first:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[-2],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [
|
||||
d if i == ndim - 2 or i == ndim - 1 else 1
|
||||
for i, d in enumerate(x.shape)
|
||||
]
|
||||
else:
|
||||
assert freqs_cis.shape == (
|
||||
x.shape[1],
|
||||
x.shape[-1],
|
||||
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
|
||||
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
||||
return freqs_cis.view(*shape)
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x_real, x_imag = (
|
||||
x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
|
||||
) # [B, S, H, D//2]
|
||||
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
||||
|
||||
|
||||
def apply_rotary_emb(
|
||||
xq: torch.Tensor,
|
||||
xk: torch.Tensor,
|
||||
freqs_cis,
|
||||
head_first: bool = False,
|
||||
):
|
||||
xk_out = None
|
||||
if isinstance(freqs_cis, tuple):
|
||||
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
|
||||
cos, sin = cos.to(xq.device), sin.to(xq.device)
|
||||
# real * cos - imag * sin
|
||||
# imag * cos + real * sin
|
||||
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
|
||||
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
|
||||
else:
|
||||
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
|
||||
xq_ = torch.view_as_complex(
|
||||
xq.float().reshape(*xq.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
|
||||
xq.device
|
||||
) # [S, D//2] --> [1, S, 1, D//2]
|
||||
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
|
||||
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
|
||||
xk_ = torch.view_as_complex(
|
||||
xk.float().reshape(*xk.shape[:-1], -1, 2)
|
||||
) # [B, S, H, D//2]
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
|
||||
|
||||
return xq_out, xk_out
|
||||
|
||||
|
||||
def attention(q, k, v):
|
||||
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = x.transpose(1, 2).flatten(2, 3)
|
||||
return x
|
||||
|
||||
|
||||
class MMDoubleStreamBlockComponent(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size)
|
||||
self.norm1 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.norm2 = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None):
|
||||
mod1_shift, mod1_scale, mod1_gate, mod2_shift, mod2_scale, mod2_gate = self.mod(conditioning).chunk(6, dim=-1)
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod1_shift, scale=mod1_scale)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
if freqs_cis is not None:
|
||||
q, k = apply_rotary_emb(q, k, freqs_cis, head_first=False)
|
||||
|
||||
return (q, k, v), (mod1_gate, mod2_shift, mod2_scale, mod2_gate)
|
||||
|
||||
def process_ff(self, hidden_states, attn_output, mod):
|
||||
mod1_gate, mod2_shift, mod2_scale, mod2_gate = mod
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod1_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(modulate(self.norm2(hidden_states), shift=mod2_shift, scale=mod2_scale)) * mod2_gate.unsqueeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MMDoubleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.component_a = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
self.component_b = MMDoubleStreamBlockComponent(hidden_size, heads_num, mlp_width_ratio)
|
||||
|
||||
def forward(self, hidden_states_a, hidden_states_b, conditioning, freqs_cis):
|
||||
(q_a, k_a, v_a), mod_a = self.component_a(hidden_states_a, conditioning, freqs_cis)
|
||||
(q_b, k_b, v_b), mod_b = self.component_b(hidden_states_b, conditioning, freqs_cis=None)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = torch.concat([v_a, v_b[:, :71]], dim=1), v_b[:, 71:].contiguous()
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output_a, attn_output_b = attn_output_a[:, :-71].contiguous(), torch.concat([attn_output_a[:, -71:], attn_output_b], dim=1)
|
||||
|
||||
hidden_states_a = self.component_a.process_ff(hidden_states_a, attn_output_a, mod_a)
|
||||
hidden_states_b = self.component_b.process_ff(hidden_states_b, attn_output_b, mod_b)
|
||||
return hidden_states_a, hidden_states_b
|
||||
|
||||
|
||||
class MMSingleStreamBlockOriginal(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.heads_num = heads_num
|
||||
self.mlp_hidden_dim = hidden_size * mlp_width_ratio
|
||||
|
||||
self.linear1 = torch.nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
||||
self.linear2 = torch.nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
||||
|
||||
self.q_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.k_norm = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.mlp_act = torch.nn.GELU(approximate="tanh")
|
||||
self.modulation = ModulateDiT(hidden_size, factor=3)
|
||||
|
||||
def forward(self, x, vec, freqs_cis=None, txt_len=256):
|
||||
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
|
||||
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
|
||||
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
q = torch.cat((q_a, q_b), dim=1)
|
||||
k = torch.cat((k_a, k_b), dim=1)
|
||||
|
||||
attn_output_a = attention(q[:, :-185].contiguous(), k[:, :-185].contiguous(), v[:, :-185].contiguous())
|
||||
attn_output_b = attention(q[:, -185:].contiguous(), k[:, -185:].contiguous(), v[:, -185:].contiguous())
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
output = self.linear2(torch.cat((attn_output, self.mlp_act(mlp)), 2))
|
||||
return x + output * mod_gate.unsqueeze(1)
|
||||
|
||||
|
||||
class MMSingleStreamBlock(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, heads_num=24, mlp_width_ratio=4):
|
||||
super().__init__()
|
||||
self.heads_num = heads_num
|
||||
|
||||
self.mod = ModulateDiT(hidden_size, factor=3)
|
||||
self.norm = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
self.to_qkv = torch.nn.Linear(hidden_size, hidden_size * 3)
|
||||
self.norm_q = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.norm_k = RMSNorm(dim=hidden_size // heads_num, eps=1e-6)
|
||||
self.to_out = torch.nn.Linear(hidden_size, hidden_size)
|
||||
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(hidden_size, hidden_size * mlp_width_ratio),
|
||||
torch.nn.GELU(approximate="tanh"),
|
||||
torch.nn.Linear(hidden_size * mlp_width_ratio, hidden_size, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, hidden_states, conditioning, freqs_cis=None, txt_len=256):
|
||||
mod_shift, mod_scale, mod_gate = self.mod(conditioning).chunk(3, dim=-1)
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = modulate(norm_hidden_states, shift=mod_shift, scale=mod_scale)
|
||||
qkv = self.to_qkv(norm_hidden_states)
|
||||
|
||||
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
q_a, q_b = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
|
||||
k_a, k_b = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
|
||||
q_a, k_a = apply_rotary_emb(q_a, k_a, freqs_cis, head_first=False)
|
||||
|
||||
q_a, q_b = torch.concat([q_a, q_b[:, :71]], dim=1), q_b[:, 71:].contiguous()
|
||||
k_a, k_b = torch.concat([k_a, k_b[:, :71]], dim=1), k_b[:, 71:].contiguous()
|
||||
v_a, v_b = v[:, :-185].contiguous(), v[:, -185:].contiguous()
|
||||
|
||||
attn_output_a = attention(q_a, k_a, v_a)
|
||||
attn_output_b = attention(q_b, k_b, v_b)
|
||||
attn_output = torch.concat([attn_output_a, attn_output_b], dim=1)
|
||||
|
||||
hidden_states = hidden_states + self.to_out(attn_output) * mod_gate.unsqueeze(1)
|
||||
hidden_states = hidden_states + self.ff(norm_hidden_states) * mod_gate.unsqueeze(1)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FinalLayer(torch.nn.Module):
|
||||
def __init__(self, hidden_size=3072, patch_size=(1, 2, 2), out_channels=16):
|
||||
super().__init__()
|
||||
|
||||
self.norm_final = torch.nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = torch.nn.Linear(hidden_size, patch_size[0] * patch_size[1] * patch_size[2] * out_channels)
|
||||
|
||||
self.adaLN_modulation = torch.nn.Sequential(torch.nn.SiLU(), torch.nn.Linear(hidden_size, 2 * hidden_size))
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift=shift, scale=scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class HunyuanVideoDiT(torch.nn.Module):
|
||||
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40):
|
||||
super().__init__()
|
||||
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
|
||||
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
|
||||
self.time_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.vector_in = torch.nn.Sequential(
|
||||
torch.nn.Linear(768, hidden_size),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(hidden_size, hidden_size)
|
||||
)
|
||||
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu")
|
||||
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
|
||||
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
|
||||
self.final_layer = FinalLayer(hidden_size)
|
||||
|
||||
# TODO: remove these parameters
|
||||
self.dtype = torch.bfloat16
|
||||
self.patch_size = [1, 2, 2]
|
||||
self.hidden_size = 3072
|
||||
self.heads_num = 24
|
||||
self.rope_dim_list = [16, 56, 56]
|
||||
|
||||
def unpatchify(self, x, T, H, W):
|
||||
x = rearrange(x, "B (T H W) (C pT pH pW) -> B C (T pT) (H pH) (W pW)", H=H, W=W, pT=1, pH=2, pW=2)
|
||||
return x
|
||||
|
||||
def enable_block_wise_offload(self, warm_device="cuda", cold_device="cpu"):
|
||||
self.warm_device = warm_device
|
||||
self.cold_device = cold_device
|
||||
self.to(self.cold_device)
|
||||
|
||||
def load_models_to_device(self, loadmodel_names=[], device="cpu"):
|
||||
for model_name in loadmodel_names:
|
||||
model = getattr(self, model_name)
|
||||
if model is not None:
|
||||
model.to(device)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def prepare_freqs(self, latents):
|
||||
return HunyuanVideoRope(latents)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
|
||||
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
|
||||
img = self.img_in(x)
|
||||
txt = self.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
for block in tqdm(self.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(self.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img = x[:, :-256]
|
||||
img = self.final_layer(img, vec)
|
||||
img = self.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Linear):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def block_forward_(self, x, i, j, dtype, device):
|
||||
weight_ = cast_to(
|
||||
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size],
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
if self.bias is None or i > 0:
|
||||
bias_ = None
|
||||
else:
|
||||
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device)
|
||||
x_ = x[..., i * self.block_size: (i + 1) * self.block_size]
|
||||
y_ = torch.nn.functional.linear(x_, weight_, bias_)
|
||||
del x_, weight_, bias_
|
||||
torch.cuda.empty_cache()
|
||||
return y_
|
||||
|
||||
def block_forward(self, x, **kwargs):
|
||||
# This feature can only reduce 2GB VRAM, so we disable it.
|
||||
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device)
|
||||
for i in range((self.in_features + self.block_size - 1) // self.block_size):
|
||||
for j in range((self.out_features + self.block_size - 1) // self.block_size):
|
||||
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device)
|
||||
return y
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.linear(x, weight, bias)
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, hidden_states, **kwargs):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
if self.module.weight is not None:
|
||||
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda")
|
||||
hidden_states = hidden_states * weight
|
||||
return hidden_states
|
||||
|
||||
class Conv3d(torch.nn.Conv3d):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
def forward(self, x):
|
||||
if self.weight is not None and self.bias is not None:
|
||||
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device)
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)
|
||||
else:
|
||||
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def replace_layer(model, dtype=torch.bfloat16, device="cuda"):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Linear(
|
||||
module.in_features, module.out_features, bias=module.bias is not None,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.Conv3d):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.Conv3d(
|
||||
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
new_layer = quantized_layer.RMSNorm(
|
||||
module,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, torch.nn.LayerNorm):
|
||||
with init_weights_on_device():
|
||||
new_layer = quantized_layer.LayerNorm(
|
||||
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps,
|
||||
dtype=dtype, device=device
|
||||
)
|
||||
new_layer.load_state_dict(module.state_dict(), assign=True)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module, dtype=dtype, device=device)
|
||||
|
||||
replace_layer(self, dtype=dtype, device=device)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoDiTStateDictConverter()
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoDiTStateDictConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_civitai(self, state_dict):
|
||||
if "module" in state_dict:
|
||||
state_dict = state_dict["module"]
|
||||
direct_dict = {
|
||||
"img_in.proj": "img_in.proj",
|
||||
"time_in.mlp.0": "time_in.timestep_embedder.0",
|
||||
"time_in.mlp.2": "time_in.timestep_embedder.2",
|
||||
"vector_in.in_layer": "vector_in.0",
|
||||
"vector_in.out_layer": "vector_in.2",
|
||||
"guidance_in.mlp.0": "guidance_in.timestep_embedder.0",
|
||||
"guidance_in.mlp.2": "guidance_in.timestep_embedder.2",
|
||||
"txt_in.input_embedder": "txt_in.input_embedder",
|
||||
"txt_in.t_embedder.mlp.0": "txt_in.t_embedder.timestep_embedder.0",
|
||||
"txt_in.t_embedder.mlp.2": "txt_in.t_embedder.timestep_embedder.2",
|
||||
"txt_in.c_embedder.linear_1": "txt_in.c_embedder.0",
|
||||
"txt_in.c_embedder.linear_2": "txt_in.c_embedder.2",
|
||||
"final_layer.linear": "final_layer.linear",
|
||||
"final_layer.adaLN_modulation.1": "final_layer.adaLN_modulation.1",
|
||||
}
|
||||
txt_suffix_dict = {
|
||||
"norm1": "norm1",
|
||||
"self_attn_qkv": "self_attn_qkv",
|
||||
"self_attn_proj": "self_attn_proj",
|
||||
"norm2": "norm2",
|
||||
"mlp.fc1": "mlp.0",
|
||||
"mlp.fc2": "mlp.2",
|
||||
"adaLN_modulation.1": "adaLN_modulation.1",
|
||||
}
|
||||
double_suffix_dict = {
|
||||
"img_mod.linear": "component_a.mod.linear",
|
||||
"img_attn_qkv": "component_a.to_qkv",
|
||||
"img_attn_q_norm": "component_a.norm_q",
|
||||
"img_attn_k_norm": "component_a.norm_k",
|
||||
"img_attn_proj": "component_a.to_out",
|
||||
"img_mlp.fc1": "component_a.ff.0",
|
||||
"img_mlp.fc2": "component_a.ff.2",
|
||||
"txt_mod.linear": "component_b.mod.linear",
|
||||
"txt_attn_qkv": "component_b.to_qkv",
|
||||
"txt_attn_q_norm": "component_b.norm_q",
|
||||
"txt_attn_k_norm": "component_b.norm_k",
|
||||
"txt_attn_proj": "component_b.to_out",
|
||||
"txt_mlp.fc1": "component_b.ff.0",
|
||||
"txt_mlp.fc2": "component_b.ff.2",
|
||||
}
|
||||
single_suffix_dict = {
|
||||
"linear1": ["to_qkv", "ff.0"],
|
||||
"linear2": ["to_out", "ff.2"],
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"modulation.linear": "mod.linear",
|
||||
}
|
||||
# single_suffix_dict = {
|
||||
# "linear1": "linear1",
|
||||
# "linear2": "linear2",
|
||||
# "q_norm": "q_norm",
|
||||
# "k_norm": "k_norm",
|
||||
# "modulation.linear": "modulation.linear",
|
||||
# }
|
||||
state_dict_ = {}
|
||||
for name, param in state_dict.items():
|
||||
names = name.split(".")
|
||||
direct_name = ".".join(names[:-1])
|
||||
if direct_name in direct_dict:
|
||||
name_ = direct_dict[direct_name] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "double_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
name_ = prefix + "." + double_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "single_blocks":
|
||||
prefix = ".".join(names[:2])
|
||||
suffix = ".".join(names[2:-1])
|
||||
if isinstance(single_suffix_dict[suffix], list):
|
||||
if suffix == "linear1":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*3, 3072*4), dim=0)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
elif suffix == "linear2":
|
||||
if names[-1] == "weight":
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
param_a, param_b = torch.split(param, (3072*1, 3072*4), dim=-1)
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param_a
|
||||
state_dict_[prefix + "." + name_b + "." + names[-1]] = param_b
|
||||
else:
|
||||
name_a, name_b = single_suffix_dict[suffix]
|
||||
state_dict_[prefix + "." + name_a + "." + names[-1]] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
name_ = prefix + "." + single_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
elif names[0] == "txt_in":
|
||||
prefix = ".".join(names[:4]).replace(".individual_token_refiner.", ".")
|
||||
suffix = ".".join(names[4:-1])
|
||||
name_ = prefix + "." + txt_suffix_dict[suffix] + "." + names[-1]
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
return state_dict_
|
||||
55
diffsynth/models/hunyuan_video_text_encoder.py
Normal file
55
diffsynth/models/hunyuan_video_text_encoder.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super().__init__(config)
|
||||
self.auto_offload = False
|
||||
|
||||
|
||||
def enable_auto_offload(self, **kwargs):
|
||||
self.auto_offload = True
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
hidden_state_skip_layer=2
|
||||
):
|
||||
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
|
||||
inputs_embeds = embed_tokens(input_ids)
|
||||
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
|
||||
position_embeddings = rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
for layer_id, decoder_layer in enumerate(self.layers):
|
||||
if self.auto_offload:
|
||||
decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
|
||||
break
|
||||
|
||||
return hidden_states
|
||||
507
diffsynth/models/hunyuan_video_vae_decoder.py
Normal file
507
diffsynth/models/hunyuan_video_vae_decoder.py
Normal file
@@ -0,0 +1,507 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class CausalConv3d(nn.Module):
|
||||
|
||||
def __init__(self, in_channel, out_channel, kernel_size, stride=1, dilation=1, pad_mode='replicate', **kwargs):
|
||||
super().__init__()
|
||||
self.pad_mode = pad_mode
|
||||
self.time_causal_padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0
|
||||
) # W, H, T
|
||||
self.conv = nn.Conv3d(in_channel, out_channel, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class UpsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, use_conv=False, out_channels=None, kernel_size=None, bias=True, upsample_factor=(2, 2, 2)):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.upsample_factor = upsample_factor
|
||||
self.conv = None
|
||||
if use_conv:
|
||||
kernel_size = 3 if kernel_size is None else kernel_size
|
||||
self.conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
|
||||
dtype = hidden_states.dtype
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
|
||||
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
|
||||
if hidden_states.shape[0] >= 64:
|
||||
hidden_states = hidden_states.contiguous()
|
||||
|
||||
# interpolate
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
first_h, other_h = hidden_states.split((1, T - 1), dim=2)
|
||||
if T > 1:
|
||||
other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
|
||||
first_h = F.interpolate(first_h.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest").unsqueeze(2)
|
||||
hidden_states = torch.cat((first_h, other_h), dim=2) if T > 1 else first_h
|
||||
|
||||
# If the input is bfloat16, we cast back to bfloat16
|
||||
if dtype == torch.bfloat16:
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
if self.conv:
|
||||
hidden_states = self.conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResnetBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels=None, dropout=0.0, groups=32, eps=1e-6, conv_shortcut_bias=True):
|
||||
super().__init__()
|
||||
self.pre_norm = True
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3, stride=1)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.nonlinearity = nn.SiLU()
|
||||
|
||||
self.conv_shortcut = None
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, bias=conv_shortcut_bias)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
# conv1
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
|
||||
# conv2
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
# shortcut
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = (self.conv_shortcut(input_tensor))
|
||||
# shortcut and scale
|
||||
output_tensor = input_tensor + hidden_states
|
||||
|
||||
return output_tensor
|
||||
|
||||
|
||||
def prepare_causal_attention_mask(n_frame, n_hw, dtype, device, batch_size=None):
|
||||
seq_len = n_frame * n_hw
|
||||
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
|
||||
for i in range(seq_len):
|
||||
i_frame = i // n_hw
|
||||
mask[i, :(i_frame + 1) * n_hw] = 0
|
||||
if batch_size is not None:
|
||||
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
return mask
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
num_heads,
|
||||
head_dim,
|
||||
num_groups=32,
|
||||
dropout=0.0,
|
||||
eps=1e-6,
|
||||
bias=True,
|
||||
residual_connection=True):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
self.residual_connection = residual_connection
|
||||
dim_inner = head_dim * num_heads
|
||||
self.group_norm = nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.to_q = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_k = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_v = nn.Linear(in_channels, dim_inner, bias=bias)
|
||||
self.to_out = nn.Sequential(nn.Linear(dim_inner, in_channels, bias=bias), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, input_tensor, attn_mask=None):
|
||||
hidden_states = self.group_norm(input_tensor.transpose(1, 2)).transpose(1, 2)
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
q = self.to_q(hidden_states)
|
||||
k = self.to_k(hidden_states)
|
||||
v = self.to_v(hidden_states)
|
||||
|
||||
q = q.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
k = k.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = attn_mask.view(batch_size, self.num_heads, -1, attn_mask.shape[-1])
|
||||
hidden_states = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
|
||||
hidden_states = self.to_out(hidden_states)
|
||||
if self.residual_connection:
|
||||
output_tensor = input_tensor + hidden_states
|
||||
return output_tensor
|
||||
|
||||
|
||||
class UNetMidBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, dropout=0.0, num_layers=1, eps=1e-6, num_groups=32, attention_head_dim=None):
|
||||
super().__init__()
|
||||
resnets = [
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
attention_head_dim = attention_head_dim or in_channels
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
Attention(
|
||||
in_channels,
|
||||
num_heads=in_channels // attention_head_dim,
|
||||
head_dim=attention_head_dim,
|
||||
num_groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
bias=True,
|
||||
residual_connection=True,
|
||||
))
|
||||
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
dropout=dropout,
|
||||
groups=num_groups,
|
||||
eps=eps,
|
||||
))
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.resnets[0](hidden_states)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
|
||||
attn_mask = prepare_causal_attention_mask(T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B)
|
||||
hidden_states = attn(hidden_states, attn_mask=attn_mask)
|
||||
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UpDecoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_upsample=True,
|
||||
upsample_scale_factor=(2, 2, 2),
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.upsamplers = None
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList([
|
||||
UpsampleCausal3D(
|
||||
out_channels,
|
||||
use_conv=True,
|
||||
out_channels=out_channels,
|
||||
upsample_factor=upsample_scale_factor,
|
||||
)
|
||||
])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DecoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_upsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_upsample = bool(i < num_spatial_upsample_layers)
|
||||
add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
|
||||
|
||||
upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
|
||||
upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
|
||||
upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
|
||||
|
||||
up_block = UpDecoderBlockCausal3D(
|
||||
in_channels=prev_output_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block + 1,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_upsample=bool(add_spatial_upsample or add_time_upsample),
|
||||
upsample_scale_factor=upsample_scale_factor,
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
hidden_states = up_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=16,
|
||||
out_channels=3,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder = DecoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.post_quant_conv = nn.Conv3d(in_channels, in_channels, kernel_size=1)
|
||||
self.scaling_factor = 0.476986
|
||||
|
||||
|
||||
def forward(self, latents):
|
||||
latents = latents / self.scaling_factor
|
||||
latents = self.post_quant_conv(latents)
|
||||
dec = self.decoder(latents)
|
||||
return dec
|
||||
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
|
||||
x = torch.ones((length,))
|
||||
if not left_bound:
|
||||
x[:border_width] = (torch.arange(border_width) + 1) / border_width
|
||||
if not right_bound:
|
||||
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
|
||||
return x
|
||||
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
|
||||
_, _, T, H, W = data.shape
|
||||
t = self.build_1d_mask(T, is_bound[0], is_bound[1], border_width[0])
|
||||
h = self.build_1d_mask(H, is_bound[2], is_bound[3], border_width[1])
|
||||
w = self.build_1d_mask(W, is_bound[4], is_bound[5], border_width[2])
|
||||
|
||||
t = repeat(t, "T -> T H W", T=T, H=H, W=W)
|
||||
h = repeat(h, "H -> T H W", T=T, H=H, W=W)
|
||||
w = repeat(w, "W -> T H W", T=T, H=H, W=W)
|
||||
|
||||
mask = torch.stack([t, h, w]).min(dim=0).values
|
||||
mask = rearrange(mask, "T H W -> 1 1 T H W")
|
||||
return mask
|
||||
|
||||
|
||||
def tile_forward(self, hidden_states, tile_size, tile_stride):
|
||||
B, C, T, H, W = hidden_states.shape
|
||||
size_t, size_h, size_w = tile_size
|
||||
stride_t, stride_h, stride_w = tile_stride
|
||||
|
||||
# Split tasks
|
||||
tasks = []
|
||||
for t in range(0, T, stride_t):
|
||||
if (t-stride_t >= 0 and t-stride_t+size_t >= T): continue
|
||||
for h in range(0, H, stride_h):
|
||||
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
|
||||
for w in range(0, W, stride_w):
|
||||
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
|
||||
t_, h_, w_ = t + size_t, h + size_h, w + size_w
|
||||
tasks.append((t, t_, h, h_, w, w_))
|
||||
|
||||
# Run
|
||||
torch_dtype = self.post_quant_conv.weight.dtype
|
||||
data_device = hidden_states.device
|
||||
computation_device = self.post_quant_conv.weight.device
|
||||
|
||||
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
|
||||
|
||||
for t, t_, h, h_, w, w_ in tqdm(tasks):
|
||||
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
|
||||
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
|
||||
if t > 0:
|
||||
hidden_states_batch = hidden_states_batch[:, :, 1:]
|
||||
|
||||
mask = self.build_mask(
|
||||
hidden_states_batch,
|
||||
is_bound=(t==0, t_>=T, h==0, h_>=H, w==0, w_>=W),
|
||||
border_width=((size_t - stride_t) * 4, (size_h - stride_h) * 8, (size_w - stride_w) * 8)
|
||||
).to(dtype=torch_dtype, device=data_device)
|
||||
|
||||
target_t = 0 if t==0 else t * 4 + 1
|
||||
target_h = h * 8
|
||||
target_w = w * 8
|
||||
values[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += hidden_states_batch * mask
|
||||
weight[
|
||||
:,
|
||||
:,
|
||||
target_t: target_t + hidden_states_batch.shape[2],
|
||||
target_h: target_h + hidden_states_batch.shape[3],
|
||||
target_w: target_w + hidden_states_batch.shape[4],
|
||||
] += mask
|
||||
return values / weight
|
||||
|
||||
|
||||
def decode_video(self, latents, tile_size=(17, 32, 32), tile_stride=(12, 24, 24)):
|
||||
latents = latents.to(self.post_quant_conv.weight.dtype)
|
||||
return self.tile_forward(latents, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEDecoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEDecoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('decoder.') or name.startswith('post_quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
217
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
217
diffsynth/models/hunyuan_video_vae_encoder.py
Normal file
@@ -0,0 +1,217 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from .hunyuan_video_vae_decoder import CausalConv3d, ResnetBlockCausal3D, UNetMidBlockCausal3D
|
||||
|
||||
|
||||
class DownsampleCausal3D(nn.Module):
|
||||
|
||||
def __init__(self, channels, out_channels, kernel_size=3, bias=True, stride=2):
|
||||
super().__init__()
|
||||
self.conv = CausalConv3d(channels, out_channels, kernel_size, stride=stride, bias=bias)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoderBlockCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
dropout=0.0,
|
||||
num_layers=1,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
add_downsample=True,
|
||||
downsample_stride=2,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
resnets = []
|
||||
for i in range(num_layers):
|
||||
cur_in_channel = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlockCausal3D(
|
||||
in_channels=cur_in_channel,
|
||||
out_channels=out_channels,
|
||||
groups=num_groups,
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
))
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.downsamplers = None
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList([DownsampleCausal3D(
|
||||
out_channels,
|
||||
out_channels,
|
||||
stride=downsample_stride,
|
||||
)])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for resnet in self.resnets:
|
||||
hidden_states = resnet(hidden_states)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class EncoderCausal3D(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 3,
|
||||
out_channels: int = 16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio: int = 4,
|
||||
spatial_compression_ratio: int = 8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
for i in range(len(block_out_channels)):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
|
||||
num_time_downsample_layers = int(np.log2(time_compression_ratio))
|
||||
|
||||
add_spatial_downsample = bool(i < num_spatial_downsample_layers)
|
||||
add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
|
||||
|
||||
downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
|
||||
downsample_stride_T = (2,) if add_time_downsample else (1,)
|
||||
downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
|
||||
down_block = DownEncoderBlockCausal3D(
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
dropout=dropout,
|
||||
num_layers=layers_per_block,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
add_downsample=bool(add_spatial_downsample or add_time_downsample),
|
||||
downsample_stride=downsample_stride,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockCausal3D(
|
||||
in_channels=block_out_channels[-1],
|
||||
dropout=dropout,
|
||||
eps=eps,
|
||||
num_groups=num_groups,
|
||||
attention_head_dim=block_out_channels[-1],
|
||||
)
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=num_groups, eps=eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = CausalConv3d(block_out_channels[-1], 2 * out_channels, kernel_size=3)
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(down_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
# middle
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.mid_block),
|
||||
hidden_states,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
# middle
|
||||
hidden_states = self.mid_block(hidden_states)
|
||||
# post-process
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
hidden_states = self.conv_act(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
eps=1e-6,
|
||||
dropout=0.0,
|
||||
block_out_channels=[128, 256, 512, 512],
|
||||
layers_per_block=2,
|
||||
num_groups=32,
|
||||
time_compression_ratio=4,
|
||||
spatial_compression_ratio=8,
|
||||
gradient_checkpointing=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.encoder = EncoderCausal3D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
eps=eps,
|
||||
dropout=dropout,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
num_groups=num_groups,
|
||||
time_compression_ratio=time_compression_ratio,
|
||||
spatial_compression_ratio=spatial_compression_ratio,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
)
|
||||
self.quant_conv = nn.Conv3d(2 * out_channels, 2 * out_channels, kernel_size=1)
|
||||
|
||||
def forward(self, images):
|
||||
latents = self.encoder(images)
|
||||
latents = self.quant_conv(latents)
|
||||
# latents: (B C T H W)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
def state_dict_converter():
|
||||
return HunyuanVideoVAEEncoderStateDictConverter()
|
||||
|
||||
|
||||
class HunyuanVideoVAEEncoderStateDictConverter:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def from_diffusers(self, state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith('encoder.') or name.startswith('quant_conv.'):
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -7,6 +7,7 @@ from .sd3_dit import SD3DiT
|
||||
from .flux_dit import FluxDiT
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .cog_dit import CogDiT
|
||||
from .hunyuan_video_dit import HunyuanVideoDiT
|
||||
|
||||
|
||||
|
||||
@@ -259,6 +260,14 @@ class GeneralLoRAFromPeft:
|
||||
return None
|
||||
|
||||
|
||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.supported_model_classes = [HunyuanVideoDiT]
|
||||
self.lora_prefix = ["diffusion_model."]
|
||||
self.special_keys = {}
|
||||
|
||||
|
||||
class FluxLoRAConverter:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -355,4 +364,4 @@ class FluxLoRAConverter:
|
||||
|
||||
|
||||
def get_lora_loaders():
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]
|
||||
|
||||
@@ -35,6 +35,8 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .hunyuan_video_vae_decoder import HunyuanVideoVAEDecoder
|
||||
from .hunyuan_video_vae_encoder import HunyuanVideoVAEEncoder
|
||||
|
||||
from .flux_dit import FluxDiT
|
||||
from .flux_text_encoder import FluxTextEncoder2
|
||||
|
||||
@@ -52,9 +52,9 @@ class PatchEmbed(torch.nn.Module):
|
||||
|
||||
|
||||
class TimestepEmbeddings(torch.nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
def __init__(self, dim_in, dim_out, computation_device=None):
|
||||
super().__init__()
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_proj = TemporalTimesteps(num_channels=dim_in, flip_sin_to_cos=True, downscale_freq_shift=0, computation_device=computation_device)
|
||||
self.timestep_embedder = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim_in, dim_out), torch.nn.SiLU(), torch.nn.Linear(dim_out, dim_out)
|
||||
)
|
||||
|
||||
@@ -2,15 +2,17 @@ import torch
|
||||
from transformers import T5EncoderModel, T5Config
|
||||
from .sd_text_encoder import SDTextEncoder
|
||||
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter
|
||||
|
||||
|
||||
|
||||
class SD3TextEncoder1(SDTextEncoder):
|
||||
def __init__(self, vocab_size=49408):
|
||||
super().__init__(vocab_size=vocab_size)
|
||||
|
||||
def forward(self, input_ids, clip_skip=2):
|
||||
def forward(self, input_ids, clip_skip=2, extra_mask=None):
|
||||
embeds = self.token_embedding(input_ids) + self.position_embeds
|
||||
attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
|
||||
if extra_mask is not None:
|
||||
attn_mask[:, extra_mask[0]==0] = float("-inf")
|
||||
for encoder_id, encoder in enumerate(self.encoders):
|
||||
embeds = encoder(embeds, attn_mask=attn_mask)
|
||||
if encoder_id + clip_skip == len(self.encoders):
|
||||
|
||||
@@ -44,6 +44,7 @@ def get_timestep_embedding(
|
||||
downscale_freq_shift: float = 1,
|
||||
scale: float = 1,
|
||||
max_period: int = 10000,
|
||||
computation_device = None,
|
||||
):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
||||
@@ -57,11 +58,11 @@ def get_timestep_embedding(
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
exponent = -math.log(max_period) * torch.arange(
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
||||
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device if computation_device is None else computation_device
|
||||
)
|
||||
exponent = exponent / (half_dim - downscale_freq_shift)
|
||||
|
||||
emb = torch.exp(exponent)
|
||||
emb = torch.exp(exponent).to(timesteps.device)
|
||||
emb = timesteps[:, None].float() * emb[None, :]
|
||||
|
||||
# scale embeddings
|
||||
@@ -81,11 +82,12 @@ def get_timestep_embedding(
|
||||
|
||||
|
||||
class TemporalTimesteps(torch.nn.Module):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
|
||||
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, computation_device = None):
|
||||
super().__init__()
|
||||
self.num_channels = num_channels
|
||||
self.flip_sin_to_cos = flip_sin_to_cos
|
||||
self.downscale_freq_shift = downscale_freq_shift
|
||||
self.computation_device = computation_device
|
||||
|
||||
def forward(self, timesteps):
|
||||
t_emb = get_timestep_embedding(
|
||||
@@ -93,6 +95,7 @@ class TemporalTimesteps(torch.nn.Module):
|
||||
self.num_channels,
|
||||
flip_sin_to_cos=self.flip_sin_to_cos,
|
||||
downscale_freq_shift=self.downscale_freq_shift,
|
||||
computation_device=self.computation_device,
|
||||
)
|
||||
return t_emb
|
||||
|
||||
|
||||
@@ -80,7 +80,7 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None):
|
||||
|
||||
|
||||
def load_state_dict_from_bin(file_path, torch_dtype=None):
|
||||
state_dict = torch.load(file_path, map_location="cpu")
|
||||
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
|
||||
if torch_dtype is not None:
|
||||
for i in state_dict:
|
||||
if isinstance(state_dict[i], torch.Tensor):
|
||||
|
||||
@@ -9,4 +9,5 @@ from .flux_image import FluxImagePipeline
|
||||
from .cog_video import CogVideoPipeline
|
||||
from .omnigen_image import OmnigenImagePipeline
|
||||
from .pipeline_runner import SDVideoPipelineRunner
|
||||
from .hunyuan_video import HunyuanVideoPipeline
|
||||
KolorsImagePipeline = SDXLImagePipeline
|
||||
|
||||
139
diffsynth/pipelines/hunyuan_video.py
Normal file
139
diffsynth/pipelines/hunyuan_video.py
Normal file
@@ -0,0 +1,139 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
|
||||
self.prompter = HunyuanVideoPrompter()
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.vae_encoder: HunyuanVideoVAEEncoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder', 'vae_encoder']
|
||||
self.vram_management = False
|
||||
|
||||
|
||||
def enable_vram_management(self):
|
||||
self.vram_management = True
|
||||
self.enable_cpu_offload()
|
||||
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
|
||||
def fetch_models(self, model_manager: ModelManager):
|
||||
self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
|
||||
self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
|
||||
self.dit = model_manager.fetch_model("hunyuan_video_dit")
|
||||
self.vae_decoder = model_manager.fetch_model("hunyuan_video_vae_decoder")
|
||||
self.vae_encoder = model_manager.fetch_model("hunyuan_video_vae_encoder")
|
||||
self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
|
||||
if device is None: device = model_manager.device
|
||||
if torch_dtype is None: torch_dtype = model_manager.torch_dtype
|
||||
pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
pipe.fetch_models(model_manager)
|
||||
if enable_vram_management:
|
||||
pipe.enable_vram_management()
|
||||
return pipe
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||
prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
|
||||
prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
|
||||
)
|
||||
return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||
|
||||
|
||||
def prepare_extra_input(self, latents=None, guidance=1.0):
|
||||
freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
|
||||
guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
|
||||
return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
def encode_video(self, frames):
|
||||
# frames : (B, C, T, H, W)
|
||||
latents = self.vae_encoder(frames)
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
seed=None,
|
||||
height=720,
|
||||
width=1280,
|
||||
num_frames=129,
|
||||
embedded_guidance=6.0,
|
||||
cfg_scale=1.0,
|
||||
num_inference_steps=30,
|
||||
tile_size=(17, 30, 30),
|
||||
tile_stride=(12, 20, 20),
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Initialize noise
|
||||
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
# Inference
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
frames = self.vae_decoder.decode_video(latents, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
|
||||
@@ -7,3 +7,4 @@ from .kolors_prompter import KolorsPrompter
|
||||
from .flux_prompter import FluxPrompter
|
||||
from .omost import OmostPromter
|
||||
from .cog_prompter import CogPrompter
|
||||
from .hunyuan_video_prompter import HunyuanVideoPrompter
|
||||
|
||||
143
diffsynth/prompters/hunyuan_video_prompter.py
Normal file
143
diffsynth/prompters/hunyuan_video_prompter.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
||||
import os, torch
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE = {
|
||||
"dit-llm-encode": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE,
|
||||
"crop_start": 36,
|
||||
},
|
||||
"dit-llm-encode-video": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
||||
"crop_start": 95,
|
||||
},
|
||||
}
|
||||
|
||||
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
||||
|
||||
|
||||
class HunyuanVideoPrompter(BasePrompter):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_1_path=None,
|
||||
tokenizer_2_path=None,
|
||||
):
|
||||
if tokenizer_1_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_1_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_video/tokenizer_1")
|
||||
if tokenizer_2_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(
|
||||
base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
|
||||
super().__init__()
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
def apply_text_to_template(self, text, template):
|
||||
assert isinstance(template, str)
|
||||
if isinstance(text, list):
|
||||
return [self.apply_text_to_template(text_) for text_ in text]
|
||||
elif isinstance(text, str):
|
||||
# Will send string to tokenizer. Used for llm
|
||||
return template.format(text)
|
||||
else:
|
||||
raise TypeError(f"Unsupported prompt type: {type(text)}")
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, max_length, device):
|
||||
tokenized_result = self.tokenizer_1(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True
|
||||
)
|
||||
input_ids = tokenized_result.input_ids.to(device)
|
||||
attention_mask = tokenized_result.attention_mask.to(device)
|
||||
return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
|
||||
|
||||
def encode_prompt_using_llm(self,
|
||||
prompt,
|
||||
max_length,
|
||||
device,
|
||||
crop_start,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True):
|
||||
max_length += crop_start
|
||||
inputs = self.tokenizer_2(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True)
|
||||
input_ids = inputs.input_ids.to(device)
|
||||
attention_mask = inputs.attention_mask.to(device)
|
||||
last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
|
||||
|
||||
# crop out
|
||||
if crop_start > 0:
|
||||
last_hidden_state = last_hidden_state[:, crop_start:]
|
||||
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
clip_sequence_length=77,
|
||||
llm_sequence_length=256,
|
||||
data_type='video',
|
||||
use_template=True,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
|
||||
# apply template
|
||||
if use_template:
|
||||
template = self.prompt_template_video if data_type == 'video' else self.prompt_template
|
||||
prompt_formated = self.apply_text_to_template(prompt, template['template'])
|
||||
else:
|
||||
prompt_formated = prompt
|
||||
# Text encoder
|
||||
if data_type == 'video':
|
||||
crop_start = self.prompt_template_video.get("crop_start", 0)
|
||||
else:
|
||||
crop_start = self.prompt_template.get("crop_start", 0)
|
||||
|
||||
# CLIP
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
@@ -4,18 +4,22 @@ import torch
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False):
|
||||
def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False):
|
||||
self.num_train_timesteps = num_train_timesteps
|
||||
self.shift = shift
|
||||
self.sigma_max = sigma_max
|
||||
self.sigma_min = sigma_min
|
||||
self.inverse_timesteps = inverse_timesteps
|
||||
self.extra_one_step = extra_one_step
|
||||
self.set_timesteps(num_inference_steps)
|
||||
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
|
||||
sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.extra_one_step:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
|
||||
else:
|
||||
self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
|
||||
if self.inverse_timesteps:
|
||||
self.sigmas = torch.flip(self.sigmas, dims=[0])
|
||||
self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
|
||||
|
||||
48895
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt
Normal file
48895
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/merges.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"add_prefix_space": false,
|
||||
"added_tokens_decoder": {
|
||||
"49406": {
|
||||
"content": "<|startoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": true,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"49407": {
|
||||
"content": "<|endoftext|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"bos_token": "<|startoftext|>",
|
||||
"clean_up_tokenization_spaces": true,
|
||||
"do_lower_case": true,
|
||||
"eos_token": "<|endoftext|>",
|
||||
"errors": "replace",
|
||||
"model_max_length": 77,
|
||||
"pad_token": "<|endoftext|>",
|
||||
"tokenizer_class": "CLIPTokenizer",
|
||||
"unk_token": "<|endoftext|>"
|
||||
}
|
||||
49410
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json
Normal file
49410
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_1/vocab.json
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"bos_token": {
|
||||
"content": "<|begin_of_text|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"eos_token": {
|
||||
"content": "<|end_of_text|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"pad_token": {
|
||||
"content": "<pad>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
},
|
||||
"unk_token": {
|
||||
"content": "<unk>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false
|
||||
}
|
||||
}
|
||||
1251020
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json
Normal file
1251020
diffsynth/tokenizer_configs/hunyuan_video/tokenizer_2/tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
19
examples/HunyuanVideo/README.md
Normal file
19
examples/HunyuanVideo/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# HunyuanVideo
|
||||
|
||||
HunyuanVideo is a video generation model trained by Tencent. We provide advanced VRAM management for this model, including three stages:
|
||||
|
||||
|VRAM required|Example script|Frames|Resolution|Note|
|
||||
|-|-|-|-|-|
|
||||
|80G|[hunyuanvideo_80G.py](hunyuanvideo_80G.py)|129|720*1280|No VRAM management.|
|
||||
|24G|[hunyuanvideo_24G.py](hunyuanvideo_24G.py)|129|720*1280|The video is consistent with the original implementation, but it requires 5%~10% more time than [hunyuanvideo_80G.py](hunyuanvideo_80G.py)|
|
||||
|6G|[hunyuanvideo_6G.py](hunyuanvideo_6G.py)|129|512*384|The base model doesn't support low resolutions. We recommend users to use some LoRA ([example](https://civitai.com/models/1032126/walking-animation-hunyuan-video)) trained using low resolutions.|
|
||||
|
||||
## Gallery
|
||||
|
||||
Video generated by [hunyuanvideo_80G.py](hunyuanvideo_80G.py) and [hunyuanvideo_24G.py](hunyuanvideo_24G.py):
|
||||
|
||||
https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9
|
||||
|
||||
Video generated by [hunyuanvideo_6G.py](hunyuanvideo_6G.py) using [this LoRA](https://civitai.com/models/1032126/walking-animation-hunyuan-video):
|
||||
|
||||
https://github.com/user-attachments/assets/2997f107-d02d-4ecb-89bb-5ce1a7f93817
|
||||
42
examples/HunyuanVideo/hunyuanvideo_24G.py
Normal file
42
examples/HunyuanVideo/hunyuanvideo_24G.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
|
||||
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# Enjoy!
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
video = pipe(prompt, seed=0)
|
||||
save_video(video, "video_girl.mp4", fps=30, quality=6)
|
||||
47
examples/HunyuanVideo/hunyuanvideo_6G.py
Normal file
47
examples/HunyuanVideo/hunyuanvideo_6G.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video, FlowMatchScheduler
|
||||
|
||||
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cpu"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# Example LoRA: https://civitai.com/models/1032126/walking-animation-hunyuan-video
|
||||
model_manager.load_lora("models/lora/kxsr_walking_anim_v1-5.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
# This LoRA requires shift=9.0.
|
||||
pipe.scheduler = FlowMatchScheduler(shift=9.0, sigma_min=0.0, extra_one_step=True)
|
||||
|
||||
# Enjoy!
|
||||
for clothes_up in ["white t-shirt", "black t-shirt", "orange t-shirt"]:
|
||||
for clothes_down in ["blue sports skirt", "red sports skirt", "white sports skirt"]:
|
||||
prompt = f"kxsr, full body, no crop, A 3D-rendered CG animation video featuring a Gorgeous, mature, curvaceous, fair-skinned female girl with long silver hair and blue eyes. She wears a {clothes_up} and a {clothes_down}, walking offering a sense of fluid movement and vivid animation."
|
||||
video = pipe(prompt, seed=0, height=512, width=384, num_frames=129, num_inference_steps=18, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
|
||||
save_video(video, f"video-{clothes_up}-{clothes_down}.mp4", fps=30, quality=6)
|
||||
45
examples/HunyuanVideo/hunyuanvideo_80G.py
Normal file
45
examples/HunyuanVideo/hunyuanvideo_80G.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import torch
|
||||
torch.cuda.set_per_process_memory_fraction(1.0, 0)
|
||||
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
|
||||
|
||||
|
||||
download_models(["HunyuanVideo"])
|
||||
model_manager = ModelManager()
|
||||
|
||||
# The DiT model is loaded in bfloat16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
|
||||
],
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# The other modules are loaded in float16.
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/HunyuanVideo/text_encoder/model.safetensors",
|
||||
"models/HunyuanVideo/text_encoder_2",
|
||||
"models/HunyuanVideo/vae/pytorch_model.pt",
|
||||
],
|
||||
torch_dtype=torch.float16,
|
||||
device="cuda"
|
||||
)
|
||||
|
||||
# We support LoRA inference. You can use the following code to load your LoRA model.
|
||||
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
|
||||
|
||||
# The computation device is "cuda".
|
||||
pipe = HunyuanVideoPipeline.from_model_manager(
|
||||
model_manager,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
enable_vram_management=False
|
||||
)
|
||||
# Although you have enough VRAM, we still recommend you to enable offload.
|
||||
pipe.enable_cpu_offload()
|
||||
|
||||
# Enjoy!
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
video = pipe(prompt, seed=0)
|
||||
save_video(video, "video.mp4", fps=30, quality=6)
|
||||
Reference in New Issue
Block a user