def _rope_to_tuple(x, dim=2):
    """Return ``x`` as a sequence of length ``dim``.

    An ``int`` is repeated ``dim`` times; a sequence that already has length
    ``dim`` is returned unchanged.

    Raises:
        ValueError: If ``x`` is a sequence whose length is not ``dim``.
    """
    if isinstance(x, int):
        return (x,) * dim
    if len(x) == dim:
        return x
    raise ValueError(f"Expected length {dim} or int, but got {x}")


def _rope_per_axis(value, num_axes, name):
    """Expand a scalar or one-element list to one value per axis.

    Raises:
        ValueError: If a longer sequence does not have exactly ``num_axes`` items.
    """
    if isinstance(value, (int, float)):
        return [value] * num_axes
    if isinstance(value, list) and len(value) == 1:
        return [value[0]] * num_axes
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if len(value) != num_axes:
        raise ValueError(f"len({name}) should equal to len(rope_dim_list)")
    return value


def _rope_meshgrid_nd(start, *args, dim=2):
    """Build an n-D grid of positions.

    Args:
        start (int or sequence): With no extra args this is the grid size and
            the grid starts at 0. With one extra arg it is the start and
            ``args[0]`` the stop (step 1). With two extra args ``args[1]`` is
            the number of samples per axis. Ints are broadcast to all axes.
        *args: See above.
        dim (int): Number of grid axes. Defaults to 2.

    Returns:
        torch.Tensor: float32 tensor of shape ``[dim, num[0], ..., num[dim - 1]]``.

    Raises:
        ValueError: If more than two extra positional args are given, or a
            sequence argument does not have length ``dim``.
    """
    if len(args) == 0:
        # ``start`` is the grid size.
        num = _rope_to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # ``start`` is the start, ``args[0]`` the stop, step is 1.
        start = _rope_to_tuple(start, dim=dim)
        stop = _rope_to_tuple(args[0], dim=dim)
        num = [hi - lo for lo, hi in zip(start, stop)]
    elif len(args) == 2:
        # ``start`` is the start, ``args[0]`` the stop, ``args[1]`` the sample count.
        start = _rope_to_tuple(start, dim=dim)
        stop = _rope_to_tuple(args[0], dim=dim)
        num = _rope_to_tuple(args[1], dim=dim)
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # Same as np.linspace(lo, hi, n, endpoint=False): sample n + 1 points, drop the last.
    axes = [
        torch.linspace(lo, hi, n + 1, dtype=torch.float32)[:n]
        for lo, hi, n in zip(start, stop, num)
    ]
    return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=0)


def _rope_1d_rotary_pos_embed(
    dim,
    pos,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor=1.0,
    interpolation_factor=1.0,
):
    """Compute 1-D rotary position embedding frequencies.

    Args:
        dim (int): Embedding width assigned to this axis; ``dim // 2`` frequencies are used.
        pos (int or torch.Tensor): Positions of shape ``[S]``; an ``int`` means ``arange(pos)``.
        theta (float): Base of the geometric frequency progression.
        use_real (bool): If True return ``(cos, sin)``; otherwise a complex tensor.
        theta_rescale_factor (float): When not 1.0, ``theta`` is multiplied by
            ``theta_rescale_factor ** (dim / (dim - 2))`` (so ``dim`` must not be 2 then).
        interpolation_factor (float): Multiplier applied to ``pos`` before the outer product.

    Returns:
        If ``use_real``: ``(cos, sin)``, each ``[S, 2 * (dim // 2)]`` with every
        frequency repeated twice along the last axis. Otherwise a complex tensor
        ``[S, dim // 2]`` holding ``cos + i * sin``.
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    if theta_rescale_factor != 1.0:
        theta = theta * theta_rescale_factor ** (dim / (dim - 2))

    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))  # [dim // 2]
    angles = torch.outer(pos * interpolation_factor, inv_freq)  # [S, dim // 2]
    if use_real:
        cos = angles.cos().repeat_interleave(2, dim=1)
        sin = angles.sin().repeat_interleave(2, dim=1)
        return cos, sin
    return torch.polar(torch.ones_like(angles), angles)


def _rope_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor=1.0,
    interpolation_factor=1.0,
):
    """Compute rotary position embeddings for tokens laid out on an n-D grid.

    Args:
        rope_dim_list (sequence of int): Embedding width per grid axis.
        start, *args: Grid specification, forwarded to ``_rope_meshgrid_nd``.
        theta (float): Frequency base shared by all axes.
        use_real (bool): If True return ``(cos, sin)``; otherwise a complex tensor.
        theta_rescale_factor (float or list of float): Per-axis theta rescale;
            a scalar or one-element list is broadcast to every axis.
        interpolation_factor (float or list of float): Per-axis position scale,
            broadcast the same way.

    Returns:
        If ``use_real``: ``(cos, sin)``, each ``[N, sum(rope_dim_list)]`` where
        ``N`` is the number of grid points (for even per-axis widths).
        Otherwise a complex tensor ``[N, sum(d // 2 for d in rope_dim_list)]``.

    Raises:
        ValueError: If a per-axis factor list does not match ``len(rope_dim_list)``.
    """
    num_axes = len(rope_dim_list)
    grid = _rope_meshgrid_nd(start, *args, dim=num_axes)  # [num_axes, ...grid shape]

    theta_rescale_factor = _rope_per_axis(theta_rescale_factor, num_axes, "theta_rescale_factor")
    interpolation_factor = _rope_per_axis(interpolation_factor, num_axes, "interpolation_factor")

    # Each axis gets its own slice of the embedding width; slices are concatenated below.
    embs = [
        _rope_1d_rotary_pos_embed(
            axis_dim,
            grid[axis].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[axis],
            interpolation_factor=interpolation_factor[axis],
        )
        for axis, axis_dim in enumerate(rope_dim_list)
    ]

    if use_real:
        cos = torch.cat([axis_cos for axis_cos, _ in embs], dim=1)
        sin = torch.cat([axis_sin for _, axis_sin in embs], dim=1)
        return cos, sin
    return torch.cat(embs, dim=1)


def HunyuanVideoRope(latents, *, rope_dim_list=(16, 56, 56), patch_size=(1, 2, 2), theta=256.0):
    """Compute the real-valued 3-D rotary position tables for a latent video.

    Only ``latents.shape`` is read. Axes 2, 3 and 4 are treated as frames,
    height and width; each is floor-divided by the matching ``patch_size``
    entry to get the token grid. Tokens are ordered frame-major, then height,
    then width. The tables are created as float32 on the default device; they
    are not moved to ``latents.device`` or cast to ``latents.dtype``.

    Args:
        latents (torch.Tensor): 5-D tensor, presumably ``[B, C, T, H, W]``.
        rope_dim_list (sequence of int): Embedding width per (frame, height,
            width) axis; every entry must be a positive even number.
            NOTE(review): the default sums to 128, presumably the attention
            head dim — confirm against the DiT config.
        patch_size (int or sequence of int): Patch size per (frame, height, width) axis.
        theta (float): Frequency base.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(freqs_cos, freqs_sin)``, each of
        shape ``[T' * H' * W', sum(rope_dim_list)]``.

    Raises:
        ValueError: If ``latents`` is not 5-D, ``rope_dim_list`` / ``patch_size``
            do not have three entries, or a rope width is not a positive even number.
    """
    if len(latents.shape) != 5:
        raise ValueError(f"Expected 5-D latents, but got shape {tuple(latents.shape)}")
    rope_dim_list = list(rope_dim_list)
    if len(rope_dim_list) != 3:
        raise ValueError(f"Expected 3 rope dims (frame, height, width), but got {rope_dim_list}")
    if any(d <= 0 or d % 2 != 0 for d in rope_dim_list):
        # An odd width would silently lose one column (only dim // 2 frequencies exist).
        raise ValueError(f"rope dims must be positive even numbers, but got {rope_dim_list}")

    patch_t, patch_h, patch_w = _rope_to_tuple(patch_size, dim=3)
    _, _, frames, height, width = latents.shape
    grid_size = [frames // patch_t, height // patch_h, width // patch_w]

    freqs_cos, freqs_sin = _rope_nd_rotary_pos_embed(
        rope_dim_list,
        grid_size,
        theta=theta,
        use_real=True,
        theta_rescale_factor=1,
    )
    return freqs_cos, freqs_sin
class HunyuanVideoPipeline(BasePipeline):
    """Text-to-video pipeline around ``HunyuanVideoDiT``.

    Encodes the prompt with two text encoders, then denoises a random latent
    video with a flow-matching scheduler. VAE decoding is not implemented yet,
    so ``__call__`` returns latents.
    """

    def __init__(self, device="cuda", torch_dtype=torch.float16):
        super().__init__(device=device, torch_dtype=torch_dtype)
        self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
        self.prompter = HunyuanVideoPrompter()
        # Following DiffSynth's ordering, text_encoder_1 is CLIP and text_encoder_2 is the
        # LLM, which is the reverse of the original HunyuanVideo code.
        self.text_encoder_1: SD3TextEncoder1 = None
        self.text_encoder_2: LlamaModel = None
        self.dit: HunyuanVideoDiT = None
        self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit']


    def fetch_models(self, model_manager: ModelManager):
        """Fetch both text encoders and the DiT from ``model_manager`` and hand the encoders to the prompter."""
        self.text_encoder_1 = model_manager.fetch_model("sd3_text_encoder_1")
        self.text_encoder_2 = model_manager.fetch_model("hunyuan_video_text_encoder_2")
        self.dit = model_manager.fetch_model("hunyuan_video_dit")
        self.prompter.fetch_models(self.text_encoder_1, self.text_encoder_2)


    @staticmethod
    def from_model_manager(model_manager: ModelManager, torch_dtype=None, device=None, enable_vram_management=True):
        """Build a pipeline from an already-populated ``ModelManager``.

        Args:
            model_manager: Source of the models; also supplies the default device and dtype.
            torch_dtype: Computation dtype; defaults to ``model_manager.torch_dtype``.
            device: Target device; defaults to ``model_manager.device``.
            enable_vram_management: If True, enable CPU offload on the pipeline and
                auto offload on the DiT.

        Returns:
            HunyuanVideoPipeline: The configured pipeline.
        """
        if device is None:
            device = model_manager.device
        if torch_dtype is None:
            torch_dtype = model_manager.torch_dtype
        pipe = HunyuanVideoPipeline(device=device, torch_dtype=torch_dtype)
        pipe.fetch_models(model_manager)
        # VRAM management is enabled by default.
        if enable_vram_management:
            pipe.enable_cpu_offload()
            pipe.dit.enable_auto_offload(dtype=torch_dtype, device=device)
        return pipe


    def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
        """Encode ``prompt`` and return a dict with keys matching the DiT forward's keyword arguments."""
        prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
            prompt, device=self.device, positive=positive, clip_sequence_length=clip_sequence_length, llm_sequence_length=llm_sequence_length
        )
        return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}


    def prepare_extra_input(self, latents=None, guidance=1.0):
        """Build the rotary tables and the per-sample embedded-guidance tensor for ``latents``."""
        freqs_cos, freqs_sin = self.dit.prepare_freqs(latents)
        # One guidance value per batch element, on the latents' device and dtype.
        guidance = torch.Tensor([guidance] * latents.shape[0]).to(device=latents.device, dtype=latents.dtype)
        return {"freqs_cos": freqs_cos, "freqs_sin": freqs_sin, "guidance": guidance}


    @torch.no_grad()
    def __call__(
        self,
        prompt,
        negative_prompt="",
        seed=None,
        height=720,
        width=1280,
        num_frames=129,
        embedded_guidance=6.0,
        cfg_scale=1.0,
        num_inference_steps=30,
        progress_bar_cmd=lambda x: x,
        progress_bar_st=None,
    ):
        """Generate video latents for ``prompt``.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative prompt; only encoded when ``cfg_scale != 1.0``.
            seed: Seed forwarded to ``generate_noise``.
            height, width: Output size in pixels; latents use ``height // 8`` by ``width // 8``.
            num_frames: Number of frames; latents use ``(num_frames - 1) // 4 + 1`` frames.
            embedded_guidance: Guidance value passed to the DiT as ``guidance``.
            cfg_scale: Classifier-free guidance scale; 1.0 skips the negative pass.
            num_inference_steps: Number of scheduler steps.
            progress_bar_cmd: Wrapper applied to the timestep iterable (e.g. ``tqdm``).
            progress_bar_st: Unused here.

        Returns:
            The denoised latents (no VAE decode yet).
        """
        latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)

        self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
        prompt_emb_posi = self.encode_prompt(prompt, positive=True)
        if cfg_scale != 1.0:
            prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)

        extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)

        self.scheduler.set_timesteps(num_inference_steps)

        # torch.autocast only accepts a bare device type ("cuda", "cpu", ...); an indexed
        # device string such as "cuda:0" or a torch.device object would be rejected.
        autocast_device_type = torch.device(self.device).type

        self.load_models_to_device([])
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(self.device)
            with torch.autocast(device_type=autocast_device_type, dtype=self.torch_dtype):
                print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")

                noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
                if cfg_scale != 1.0:
                    noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
                    noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
                else:
                    noise_pred = noise_pred_posi

            latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)

        # TODO: Add VAE decode here.

        return latents
a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -4,18 +4,22 @@ import torch class FlowMatchScheduler(): - def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False): + def __init__(self, num_inference_steps=100, num_train_timesteps=1000, shift=3.0, sigma_max=1.0, sigma_min=0.003/1.002, inverse_timesteps=False, extra_one_step=False): self.num_train_timesteps = num_train_timesteps self.shift = shift self.sigma_max = sigma_max self.sigma_min = sigma_min self.inverse_timesteps = inverse_timesteps + self.extra_one_step = extra_one_step self.set_timesteps(num_inference_steps) def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False): sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength - self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) + if self.extra_one_step: + self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps + 1)[:-1] + else: + self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) if self.inverse_timesteps: self.sigmas = torch.flip(self.sigmas, dims=[0]) self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas)