From 9f8112ec344e1fc17aa26a7b22893d2c87b3d47c Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Tue, 14 Jan 2025 14:46:35 +0800 Subject: [PATCH] support teacache-hunyuanvideo --- diffsynth/pipelines/hunyuan_video.py | 100 ++++++++++++++++++++- examples/TeaCache/README.md | 28 ++++-- examples/TeaCache/hunyuanvideo_teacache.py | 42 +++++++++ 3 files changed, 163 insertions(+), 7 deletions(-) create mode 100644 examples/TeaCache/hunyuanvideo_teacache.py diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py index 2beba5a..3cf2d8e 100644 --- a/diffsynth/pipelines/hunyuan_video.py +++ b/diffsynth/pipelines/hunyuan_video.py @@ -8,6 +8,7 @@ import torch from einops import rearrange import numpy as np from PIL import Image +from tqdm import tqdm @@ -94,6 +95,7 @@ class HunyuanVideoPipeline(BasePipeline): embedded_guidance=6.0, cfg_scale=1.0, num_inference_steps=30, + tea_cache_l1_thresh=None, tile_size=(17, 30, 30), tile_stride=(12, 20, 20), step_processor=None, @@ -126,6 +128,9 @@ class HunyuanVideoPipeline(BasePipeline): # Extra input extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance) + # TeaCache + tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None} + # Denoise self.load_models_to_device([] if self.vram_management else ["dit"]) for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -134,9 +139,9 @@ class HunyuanVideoPipeline(BasePipeline): # Inference with torch.autocast(device_type=self.device, dtype=self.torch_dtype): - noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input) + noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs) if cfg_scale != 1.0: - noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input) + noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input) noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) else: noise_pred = noise_pred_posi @@ -165,3 +170,94 @@ class HunyuanVideoPipeline(BasePipeline): frames = self.tensor2video(frames[0]) return frames + + + +class TeaCache: + def __init__(self, num_inference_steps, rel_l1_thresh): + self.num_inference_steps = num_inference_steps + self.step = 0 + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = None + self.rel_l1_thresh = rel_l1_thresh + self.previous_residual = None + self.previous_hidden_states = None + + def check(self, dit: HunyuanVideoDiT, img, vec): + img_ = img.clone() + vec_ = vec.clone() + img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1) + normed_inp = dit.double_blocks[0].component_a.norm1(img_) + modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1) + if self.step == 0 or self.step == self.num_inference_steps - 1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = modulated_inp + self.step += 1 + if self.step == self.num_inference_steps: + self.step = 0 + if should_calc: + self.previous_hidden_states = img.clone() + return not should_calc + + def store(self, hidden_states): + self.previous_residual = hidden_states - self.previous_hidden_states + self.previous_hidden_states = None + + def update(self, hidden_states): + hidden_states = hidden_states + self.previous_residual + return hidden_states + + + +def lets_dance_hunyuan_video( + dit: HunyuanVideoDiT, + x: torch.Tensor, + t: torch.Tensor, + prompt_emb: torch.Tensor = None, + text_mask: torch.Tensor = None, + pooled_prompt_emb: torch.Tensor = None, + freqs_cos: torch.Tensor = None, + freqs_sin: torch.Tensor = None, + guidance: torch.Tensor = None, + tea_cache: TeaCache = None, + **kwargs +): + B, C, T, H, W = x.shape + + vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32) + img = dit.img_in(x) + txt = dit.txt_in(prompt_emb, t, text_mask) + + # TeaCache + if tea_cache is not None: + tea_cache_update = tea_cache.check(dit, img, vec) + else: + tea_cache_update = False + + if tea_cache_update: + print("TeaCache skip forward.") + img = tea_cache.update(img) + else: + for block in tqdm(dit.double_blocks, desc="Double stream blocks"): + img, txt = block(img, txt, vec, (freqs_cos, freqs_sin)) + + x = torch.concat([img, txt], dim=1) + for block in tqdm(dit.single_blocks, desc="Single stream blocks"): + x = block(x, vec, (freqs_cos, freqs_sin)) + img = x[:, :-256] + + if tea_cache is not None: + tea_cache.store(img) + img = dit.final_layer(img, vec) + img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2) + return img diff --git a/examples/TeaCache/README.md b/examples/TeaCache/README.md index 6f15c30..2757d3b 100644 --- a/examples/TeaCache/README.md +++ b/examples/TeaCache/README.md @@ -4,13 +4,31 @@ TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache ## Examples -We provide examples on FLUX.1-dev. See [./flux_teacache.py](./flux_teacache.py). +### FLUX + +Script: [./flux_teacache.py](./flux_teacache.py) + +Model: FLUX.1-dev Steps: 50 GPU: A100 -|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.4|tea_cache_l1_thresh=0.6|tea_cache_l1_thresh=0.8| -|-|-|-|-|-| -|23s|13s|9s|6s|5s| -|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 4](https://github.com/user-attachments/assets/4c57c50d-87cd-493b-8603-1da57ec3b70d)|![image_0 6](https://github.com/user-attachments/assets/1d95a3a9-71f9-4b1a-ad5f-a5ea8d52eca7)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1) +|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.8| +|-|-|-| +|23s|13s|5s| +|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1) + +### Hunyuan Video + +Script: [./hunyuanvideo_teacache.py](./hunyuanvideo_teacache.py) + +Model: Hunyuan Video + +Steps: 30 + +GPU: A100 + +The following video was generated using TeaCache. It is nearly identical to [the video without TeaCache enabled](https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9), but with double the speed. + +https://github.com/user-attachments/assets/cd9801c5-88ce-4efc-b055-2c7737166f34 diff --git a/examples/TeaCache/hunyuanvideo_teacache.py b/examples/TeaCache/hunyuanvideo_teacache.py new file mode 100644 index 0000000..e29602d --- /dev/null +++ b/examples/TeaCache/hunyuanvideo_teacache.py @@ -0,0 +1,42 @@ +import torch +torch.cuda.set_per_process_memory_fraction(1.0, 0) +from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video + + +download_models(["HunyuanVideo"]) +model_manager = ModelManager() + +# The DiT model is loaded in bfloat16. +model_manager.load_models( + [ + "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" + ], + torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization. + device="cpu" +) + +# The other modules are loaded in float16. +model_manager.load_models( + [ + "models/HunyuanVideo/text_encoder/model.safetensors", + "models/HunyuanVideo/text_encoder_2", + "models/HunyuanVideo/vae/pytorch_model.pt", + ], + torch_dtype=torch.float16, + device="cpu" +) + +# We support LoRA inference. You can use the following code to load your LoRA model. +# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0) + +# The computation device is "cuda". +pipe = HunyuanVideoPipeline.from_model_manager( + model_manager, + torch_dtype=torch.bfloat16, + device="cuda" +) + +# Enjoy! +prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her." +video = pipe(prompt, seed=0, tea_cache_l1_thresh=0.15) +save_video(video, "video_girl.mp4", fps=30, quality=6)