hunyuanvideo pipeline

Artiprocher
2024-12-18 11:42:43 +08:00
parent 7a45b7efa7
commit b048f1b1de
5 changed files with 279 additions and 31 deletions

diffsynth/models/hunyuan_video_dit.py

@@ -3,6 +3,193 @@ from .sd3_dit import TimestepEmbeddings, RMSNorm
from .utils import init_weights_on_device
from einops import rearrange, repeat
from tqdm import tqdm
from typing import Union, Tuple, List


def HunyuanVideoRope(latents):
    def _to_tuple(x, dim=2):
        if isinstance(x, int):
            return (x,) * dim
        elif len(x) == dim:
            return x
        else:
            raise ValueError(f"Expected length {dim} or int, but got {x}")

    def get_meshgrid_nd(start, *args, dim=2):
        """
        Get an n-D meshgrid with start, stop and num.

        Args:
            start (int or tuple): If len(args) == 0, start is num; if len(args) == 1, start is start, args[0] is stop,
                and step is 1; if len(args) == 2, start is start, args[0] is stop, and args[1] is num. For n dims,
                start/stop/num should be an int or an n-tuple. If an n-tuple is provided, the meshgrid is stacked
                following the dim order of the n-tuples.
            *args: See above.
            dim (int): Dimension of the meshgrid. Defaults to 2.

        Returns:
            grid (torch.Tensor): [dim, ...]
        """
        if len(args) == 0:
            # start is grid_size
            num = _to_tuple(start, dim=dim)
            start = (0,) * dim
            stop = num
        elif len(args) == 1:
            # start is start, args[0] is stop, step is 1
            start = _to_tuple(start, dim=dim)
            stop = _to_tuple(args[0], dim=dim)
            num = [stop[i] - start[i] for i in range(dim)]
        elif len(args) == 2:
            # start is start, args[0] is stop, args[1] is num
            start = _to_tuple(start, dim=dim)   # left-top, e.g. (12, 0)
            stop = _to_tuple(args[0], dim=dim)  # right-bottom, e.g. (20, 32)
            num = _to_tuple(args[1], dim=dim)   # target size, e.g. (32, 124)
        else:
            raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

        # PyTorch implementation of np.linspace(start[i], stop[i], num[i], endpoint=False)
        axis_grid = []
        for i in range(dim):
            a, b, n = start[i], stop[i], num[i]
            g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
            axis_grid.append(g)
        grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
        grid = torch.stack(grid, dim=0)  # [dim, W, H, D]
        return grid

    def get_1d_rotary_pos_embed(
        dim: int,
        pos: Union[torch.FloatTensor, int],
        theta: float = 10000.0,
        use_real: bool = False,
        theta_rescale_factor: float = 1.0,
        interpolation_factor: float = 1.0,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Precompute the frequency tensor for the complex exponential (cis) with the given dimension.
        (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

        This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
        and the position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor
        contains complex values of dtype complex64.

        Args:
            dim (int): Dimension of the frequency tensor.
            pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar.
            theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
            use_real (bool, optional): If True, return the real and imaginary parts separately.
                Otherwise, return complex numbers.
            theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

        Returns:
            freqs_cis: Precomputed frequency tensor with complex exponentials. [S, D/2]
            freqs_cos, freqs_sin: Precomputed frequency tensors with the real and imaginary parts separately. [S, D]
        """
        if isinstance(pos, int):
            pos = torch.arange(pos).float()

        # Proposed by reddit user bloc97 to rescale rotary embeddings to longer sequence lengths
        # without fine-tuning; has some connection to the NTK literature.
        if theta_rescale_factor != 1.0:
            theta *= theta_rescale_factor ** (dim / (dim - 2))

        freqs = 1.0 / (
            theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
        )  # [D/2]
        freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
        if use_real:
            freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
            freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
            return freqs_cos, freqs_sin
        else:
            freqs_cis = torch.polar(
                torch.ones_like(freqs), freqs
            )  # complex64, [S, D/2]
            return freqs_cis

    def get_nd_rotary_pos_embed(
        rope_dim_list,
        start,
        *args,
        theta=10000.0,
        use_real=False,
        theta_rescale_factor: Union[float, List[float]] = 1.0,
        interpolation_factor: Union[float, List[float]] = 1.0,
    ):
        """
        An n-d version of precompute_freqs_cis: RoPE for tokens with an n-d structure.

        Args:
            rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n, and
                sum(rope_dim_list) should equal the head_dim of the attention layer.
            start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is
                start, args[0] is stop, and step is 1; if len(args) == 2, start is start, args[0] is stop, and
                args[1] is num.
            *args: See above.
            theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
            use_real (bool): If True, return the real and imaginary parts separately. Otherwise, return complex
                numbers. Some libraries such as TensorRT do not support the complex64 data type, so returning
                the real and imaginary parts separately can be useful.
            theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

        Returns:
            pos_embed (torch.Tensor): [HW, D/2]
        """
        grid = get_meshgrid_nd(
            start, *args, dim=len(rope_dim_list)
        )  # [3, W, H, D] / [2, W, H]

        if isinstance(theta_rescale_factor, (int, float)):
            theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
        elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
            theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
        assert len(theta_rescale_factor) == len(
            rope_dim_list
        ), "len(theta_rescale_factor) should equal len(rope_dim_list)"

        if isinstance(interpolation_factor, (int, float)):
            interpolation_factor = [interpolation_factor] * len(rope_dim_list)
        elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
            interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
        assert len(interpolation_factor) == len(
            rope_dim_list
        ), "len(interpolation_factor) should equal len(rope_dim_list)"

        # Use 1/ndim of the head dimensions to encode each grid axis.
        embs = []
        for i in range(len(rope_dim_list)):
            emb = get_1d_rotary_pos_embed(
                rope_dim_list[i],
                grid[i].reshape(-1),
                theta,
                use_real=use_real,
                theta_rescale_factor=theta_rescale_factor[i],
                interpolation_factor=interpolation_factor[i],
            )  # 2 x [WHD, rope_dim_list[i]]
            embs.append(emb)

        if use_real:
            cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
            sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
            return cos, sin
        else:
            emb = torch.cat(embs, dim=1)  # (WHD, D/2)
            return emb

    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        [16, 56, 56],
        [latents.shape[2], latents.shape[3] // 2, latents.shape[4] // 2],
        theta=256,
        use_real=True,
        theta_rescale_factor=1,
    )
    return freqs_cos, freqs_sin
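A minimal shape check, not part of the commit: with 129 frames at 720x1280, the pipeline's 4x temporal and 8x spatial VAE compression gives latents of shape [1, 16, 33, 90, 160]; the rope dims 16 + 56 + 56 sum to the head dim of 128, and the grid is built at half spatial resolution (matching a 2x2 spatial patch size).

latents = torch.randn(1, 16, 33, 90, 160)
freqs_cos, freqs_sin = HunyuanVideoRope(latents)
print(freqs_cos.shape)  # torch.Size([118800, 128]); 118800 = 33 * 45 * 80
print(freqs_sin.shape)  # torch.Size([118800, 128])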
class PatchEmbed(torch.nn.Module):
@@ -406,13 +593,16 @@ class HunyuanVideoDiT(torch.nn.Module):
        model.to(device)
        torch.cuda.empty_cache()

+    def prepare_freqs(self, latents):
+        return HunyuanVideoRope(latents)
+
    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
-        text_states: torch.Tensor = None,
+        prompt_emb: torch.Tensor = None,
        text_mask: torch.Tensor = None,
-        text_states_2: torch.Tensor = None,
+        pooled_prompt_emb: torch.Tensor = None,
        freqs_cos: torch.Tensor = None,
        freqs_sin: torch.Tensor = None,
        guidance: torch.Tensor = None,
@@ -420,9 +610,9 @@ class HunyuanVideoDiT(torch.nn.Module):
    ):
        B, C, T, H, W = x.shape
-        vec = self.time_in(t, dtype=torch.float32) + self.vector_in(text_states_2) + self.guidance_in(guidance, dtype=torch.float32)
+        vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32)
        img = self.img_in(x)
-        txt = self.txt_in(text_states, t, text_mask)
+        txt = self.txt_in(prompt_emb, t, text_mask)
        for block in tqdm(self.double_blocks, desc="Double stream blocks"):
            img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
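A sketch, not part of the commit, of how rotary tables built with repeat_interleave(2, dim=1) are typically consumed inside the attention blocks: rotation acts on interleaved pairs (x0, x1), (x2, x3), ... of each head dimension. The helper names and the [batch, seq, heads, head_dim] layout are illustrative assumptions, not the blocks' actual API.

def rotate_pairs(x):
    # (x0, x1) -> (-x1, x0) for every consecutive pair along the last dim
    x = x.reshape(*x.shape[:-1], -1, 2)
    x = torch.stack((-x[..., 1], x[..., 0]), dim=-1)
    return x.reshape(*x.shape[:-2], -1)

def apply_rotary(q, freqs_cos, freqs_sin):
    # q: [B, S, num_heads, head_dim]; freqs_*: [S, head_dim]
    cos, sin = freqs_cos.unsqueeze(1), freqs_sin.unsqueeze(1)  # broadcast over heads
    return q * cos + rotate_pairs(q) * sin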

diffsynth/models/sd3_text_encoder.py

@@ -2,15 +2,17 @@ import torch
from transformers import T5EncoderModel, T5Config
from .sd_text_encoder import SDTextEncoder
from .sdxl_text_encoder import SDXLTextEncoder2, SDXLTextEncoder2StateDictConverter


class SD3TextEncoder1(SDTextEncoder):
    def __init__(self, vocab_size=49408):
        super().__init__(vocab_size=vocab_size)

-    def forward(self, input_ids, clip_skip=2):
+    def forward(self, input_ids, clip_skip=2, extra_mask=None):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
+        if extra_mask is not None:
+            attn_mask[:, extra_mask[0] == 0] = float("-inf")
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
diffsynth/pipelines/hunyuan_video.py

@@ -1,40 +1,56 @@
from ..models import ModelManager, SD3TextEncoder1
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from tqdm import tqdm


class HunyuanVideoPipeline(BasePipeline):

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        # Following DiffSynth's ordering, text_encoder_1 is the CLIP encoder and text_encoder_2 is
        # the LLM, which is the reverse of the original HunyuanVideo code.
        self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
        self.prompter = HunyuanVideoPrompter()
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.dit: HunyuanVideoDiT = None
        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit']

    def fetch_models(self, model_manager: ModelManager):
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
        self.dit = model_manager.fetch_model("hunyuan_video_dit")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)

-    @staticmethod
-    def from_model_manager(model_manager: ModelManager, device=None):
-        pipe = HunyuanVideoPipeline(
-            device=model_manager.device if device is None else device,
-            torch_dtype=model_manager.torch_dtype,
-        )
+    @staticmethod
+    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
+        if device is None: device = model_manager.device
+        if torch_dtype is None: torch_dtype = model_manager.torch_dtype
+        pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
+        # VRAM management is enabled by default.
+        if enable_vram_management:
+            pipe.enable_cpu_offload()
+            pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
        return pipe

    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
-        prompt_emb, pooled_prompt_emb = self.prompter.encode_prompt(
+        prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
        )
-        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb}
+        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}

    def prepare_extra_input(self, latents=None, guidance=1.0):
        freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
        guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
        return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}

    @torch.no_grad()
    def __call__(
@@ -42,10 +58,41 @@ class HunyuanVideoPipeline(BasePipeline):
        prompt,
        negative_prompt="",
        seed=None,
-        progress_bar_cmd=tqdm,
+        height=720,
+        width=1280,
+        num_frames=129,
+        embedded_guidance=6.0,
+        cfg_scale=1.0,
+        num_inference_steps=30,
+        progress_bar_cmd=lambda x: x,
        progress_bar_st=None,
    ):
-        pass
+        latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height // 8, width // 8), seed=seed, device=self.device, dtype=self.torch_dtype)
        self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
-        return prompt_emb_posi
+        if cfg_scale != 1.0:
+            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
+        extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
+        self.scheduler.set_timesteps(num_inference_steps)
+        self.load_models_to_device([])
+        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
+            timestep = timestep.unsqueeze(0).to(self.device)
+            with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
+                print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
+                noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
+                if cfg_scale != 1.0:
+                    noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
+                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
+                else:
+                    noise_pred = noise_pred_posi
+            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
+        # TODO: Add VAE decode here.
+        return latents
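A hypothetical usage sketch, not part of the commit; the module path and checkpoint paths are placeholder assumptions, and the call returns latents because VAE decoding is still a TODO.

import torch
from diffsynth import ModelManager
from diffsynth.pipelines.hunyuan_video import HunyuanVideoPipeline  # assumed module path

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models(["path/to/clip", "path/to/llm", "path/to/hunyuan_video_dit"])  # placeholders
pipe = HunyuanVideoPipeline.from_model_manager(model_manager)
latents = pipe(prompt="a cat running on the grass", seed=0, num_inference_steps=30)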

diffsynth/prompters/hunyuan_video_prompter.py

@@ -70,12 +70,17 @@ class HunyuanVideoPrompter(BasePrompter):
            raise TypeError(f"Unsupported prompt type: {type(text)}")

    def encode_prompt_using_clip(self, prompt, max_length, device):
-        input_ids = self.tokenizer_1(prompt,
-                                     return_tensors="pt",
-                                     padding="max_length",
-                                     max_length=max_length,
-                                     truncation=True).input_ids.to(device)
-        return self.text_encoder_1(input_ids=input_ids)[0]
+        tokenized_result = self.tokenizer_1(
+            prompt,
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+            truncation=True,
+            return_attention_mask=True
+        )
+        input_ids = tokenized_result.input_ids.to(device)
+        attention_mask = tokenized_result.attention_mask.to(device)
+        return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]

    def encode_prompt_using_llm(self,
                                prompt,
@@ -110,7 +115,7 @@ class HunyuanVideoPrompter(BasePrompter):
            last_hidden_state = last_hidden_state[:, crop_start:]
            attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)

-        return last_hidden_state
+        return last_hidden_state, attention_mask

    def encode_prompt(self,
                      prompt,
@@ -142,8 +147,8 @@ class HunyuanVideoPrompter(BasePrompter):
        pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)

        # LLM
-        prompt_emb = self.encode_prompt_using_llm(
+        prompt_emb, attention_mask = self.encode_prompt_using_llm(
            prompt_formated, llm_sequence_length, device, crop_start,
            hidden_state_skip_layer, apply_final_norm, use_attention_mask)

-        return prompt_emb, pooled_prompt_emb
+        return prompt_emb, pooled_prompt_emb, attention_mask
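A minimal sketch, not part of the commit, of why the LLM branch crops and now also returns the mask: the prompt is wrapped in an instruction template before tokenization, so the first crop_start positions belong to the template and are dropped from both the hidden states and the attention mask. The numbers below are illustrative, not the real template length.

import torch

B, L, hidden, crop_start = 1, 256, 4096, 36  # hypothetical values
last_hidden_state = torch.randn(B, L, hidden)
attention_mask = torch.ones(B, L, dtype=torch.long)
print(last_hidden_state[:, crop_start:].shape)  # torch.Size([1, 220, 4096])
print(attention_mask[:, crop_start:].shape)     # torch.Size([1, 220])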

diffsynth/schedulers/flow_match.py

@@ -4,18 +4,22 @@ import torch
class FlowMatchScheduler():

-    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False):
+    def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False):
        self.num_train_timesteps = num_train_timesteps
        self.shift = shift
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.inverse_timesteps = inverse_timesteps
+        self.extra_one_step = extra_one_step
        self.set_timesteps(num_inference_steps)

    def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False):
        sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength
-        self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
+        if self.extra_one_step:
+            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1]
+        else:
+            self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps)
        if self.inverse_timesteps:
            self.sigmas = torch.flip(self.sigmas, dims=[0])
        self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)
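A worked example, not part of the commit, of the two schedules and the shift. With extra_one_step=True the linspace is computed over num_inference_steps + 1 points and the last point is dropped, so sigma never reaches sigma_min exactly; this matters for HunyuanVideo, whose pipeline uses sigma_min=0.0. Every sigma is then warped by sigma' = shift * sigma / (1 + (shift - 1) * sigma), pushing the trajectory toward the high-noise region.

import torch

steps = 5
plain = torch.linspace(1.0, 0.0, steps)           # extra_one_step=False
extra = torch.linspace(1.0, 0.0, steps + 1)[:-1]  # extra_one_step=True
print(plain)  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000])
print(extra)  # tensor([1.0000, 0.8000, 0.6000, 0.4000, 0.2000])
shift = 7.0
print(shift * extra / (1 + (shift - 1) * extra))  # tensor([1.0000, 0.9655, 0.9130, 0.8235, 0.6364])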