Merge pull request #316 from modelscope/teacache-hunyuanvideo

support teacache-hunyuanvideo
This commit is contained in:
Zhongjie Duan
2025-01-14 14:48:04 +08:00
committed by GitHub
3 changed files with 163 additions and 7 deletions

View File

@@ -8,6 +8,7 @@ import torch
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
@@ -94,6 +95,7 @@ class HunyuanVideoPipeline(BasePipeline):
embedded_guidance=6.0,
cfg_scale=1.0,
num_inference_steps=30,
tea_cache_l1_thresh=None,
tile_size=(17, 30, 30),
tile_stride=(12, 20, 20),
step_processor=None,
@@ -126,6 +128,9 @@ class HunyuanVideoPipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device([] if self.vram_management else ["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -134,9 +139,9 @@ class HunyuanVideoPipeline(BasePipeline):
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -165,3 +170,94 @@ class HunyuanVideoPipeline(BasePipeline):
frames = self.tensor2video(frames[0])
return frames
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: HunyuanVideoDiT, img, vec):
img_ = img.clone()
vec_ = vec.clone()
img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
normed_inp = dit.double_blocks[0].component_a.norm1(img_)
modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = img.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_hunyuan_video(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img

View File

@@ -4,13 +4,31 @@ TeaCache ([Timestep Embedding Aware Cache](https://github.com/ali-vilab/TeaCache
## Examples
We provide examples on FLUX.1-dev. See [./flux_teacache.py](./flux_teacache.py).
### FLUX
Script: [./flux_teacache.py](./flux_teacache.py)
Model: FLUX.1-dev
Steps: 50
GPU: A100
|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.4|tea_cache_l1_thresh=0.6|tea_cache_l1_thresh=0.8|
|-|-|-|-|-|
|23s|13s|9s|6s|5s|
|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 4](https://github.com/user-attachments/assets/4c57c50d-87cd-493b-8603-1da57ec3b70d)|![image_0 6](https://github.com/user-attachments/assets/1d95a3a9-71f9-4b1a-ad5f-a5ea8d52eca7)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1)
|TeaCache is disabled|tea_cache_l1_thresh=0.2|tea_cache_l1_thresh=0.8|
|-|-|-|
|23s|13s|5s|
|![image_None](https://github.com/user-attachments/assets/2bf5187a-9693-44d3-9ebb-6c33cd15443f)|![image_0 2](https://github.com/user-attachments/assets/5532ba94-c7e2-446e-a9ba-1c68c0f63350)|![image_0 8](https://github.com/user-attachments/assets/d8cfdd74-8b45-4048-b1b7-ce480aa23fa1)
### Hunyuan Video
Script: [./hunyuanvideo_teacache.py](./hunyuanvideo_teacache.py)
Model: Hunyuan Video
Steps: 30
GPU: A100
The following video was generated using TeaCache. It is nearly identical to [the video without TeaCache enabled](https://github.com/user-attachments/assets/48dd24bb-0cc6-40d2-88c3-10feed3267e9), but with double the speed.
https://github.com/user-attachments/assets/cd9801c5-88ce-4efc-b055-2c7737166f34

View File

@@ -0,0 +1,42 @@
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
download_models(["HunyuanVideo"])
model_manager = ModelManager()
# The DiT model is loaded in bfloat16.
model_manager.load_models(
[
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
device="cpu"
)
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideo/text_encoder_2",
"models/HunyuanVideo/vae/pytorch_model.pt",
],
torch_dtype=torch.float16,
device="cpu"
)
# We support LoRA inference. You can use the following code to load your LoRA model.
# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda"
)
# Enjoy!
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipe(prompt, seed=0, tea_cache_l1_thresh=0.15)
save_video(video, "video_girl.mp4", fps=30, quality=6)