support teacache-hunyuanvideo

This commit is contained in:
Artiprocher
2025-01-14 14:46:35 +08:00
parent d9fad821b2
commit 9f8112ec34
3 changed files with 163 additions and 7 deletions

View File

@@ -8,6 +8,7 @@ import torch
from einops import rearrange
import numpy as np
from PIL import Image
from tqdm import tqdm
@@ -94,6 +95,7 @@ class HunyuanVideoPipeline(BasePipeline):
embedded_guidance=6.0,
cfg_scale=1.0,
num_inference_steps=30,
tea_cache_l1_thresh=None,
tile_size=(17, 30, 30),
tile_stride=(12, 20, 20),
step_processor=None,
@@ -126,6 +128,9 @@ class HunyuanVideoPipeline(BasePipeline):
# Extra input
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
# TeaCache
tea_cache_kwargs = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh) if tea_cache_l1_thresh is not None else None}
# Denoise
self.load_models_to_device([] if self.vram_management else ["dit"])
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -134,9 +139,9 @@ class HunyuanVideoPipeline(BasePipeline):
# Inference
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
noise_pred_posi = self.dit(latents, timestep, **prompt_emb_posi, **extra_input)
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
if cfg_scale != 1.0:
noise_pred_nega = self.dit(latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
@@ -165,3 +170,94 @@ class HunyuanVideoPipeline(BasePipeline):
frames = self.tensor2video(frames[0])
return frames
class TeaCache:
def __init__(self, num_inference_steps, rel_l1_thresh):
self.num_inference_steps = num_inference_steps
self.step = 0
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = None
self.rel_l1_thresh = rel_l1_thresh
self.previous_residual = None
self.previous_hidden_states = None
def check(self, dit: HunyuanVideoDiT, img, vec):
img_ = img.clone()
vec_ = vec.clone()
img_mod1_shift, img_mod1_scale, _, _, _, _ = dit.double_blocks[0].component_a.mod(vec_).chunk(6, dim=-1)
normed_inp = dit.double_blocks[0].component_a.norm1(img_)
modulated_inp = normed_inp * (1 + img_mod1_scale.unsqueeze(1)) + img_mod1_shift.unsqueeze(1)
if self.step == 0 or self.step == self.num_inference_steps - 1:
should_calc = True
self.accumulated_rel_l1_distance = 0
else:
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
rescale_func = np.poly1d(coefficients)
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance = 0
self.previous_modulated_input = modulated_inp
self.step += 1
if self.step == self.num_inference_steps:
self.step = 0
if should_calc:
self.previous_hidden_states = img.clone()
return not should_calc
def store(self, hidden_states):
self.previous_residual = hidden_states - self.previous_hidden_states
self.previous_hidden_states = None
def update(self, hidden_states):
hidden_states = hidden_states + self.previous_residual
return hidden_states
def lets_dance_hunyuan_video(
dit: HunyuanVideoDiT,
x: torch.Tensor,
t: torch.Tensor,
prompt_emb: torch.Tensor = None,
text_mask: torch.Tensor = None,
pooled_prompt_emb: torch.Tensor = None,
freqs_cos: torch.Tensor = None,
freqs_sin: torch.Tensor = None,
guidance: torch.Tensor = None,
tea_cache: TeaCache = None,
**kwargs
):
B, C, T, H, W = x.shape
vec = dit.time_in(t, dtype=torch.float32) + dit.vector_in(pooled_prompt_emb) + dit.guidance_in(guidance * 1000, dtype=torch.float32)
img = dit.img_in(x)
txt = dit.txt_in(prompt_emb, t, text_mask)
# TeaCache
if tea_cache is not None:
tea_cache_update = tea_cache.check(dit, img, vec)
else:
tea_cache_update = False
if tea_cache_update:
print("TeaCache skip forward.")
img = tea_cache.update(img)
else:
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
x = torch.concat([img, txt], dim=1)
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
x = block(x, vec, (freqs_cos, freqs_sin))
img = x[:, :-256]
if tea_cache is not None:
tea_cache.store(img)
img = dit.final_layer(img, vec)
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
return img