From e425753f7995d90a3a13b2d69da6216990bed58e Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 14 Mar 2025 17:45:52 +0800 Subject: [PATCH] support teacache in wan --- diffsynth/pipelines/wan_video.py | 114 +++++++++++++++++- examples/wanvideo/README.md | 2 + .../wan_1.3b_text_to_video_accelerate.py | 34 ++++++ 3 files changed, 147 insertions(+), 3 deletions(-) create mode 100644 examples/wanvideo/wan_1.3b_text_to_video_accelerate.py diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index 2f19d42..76e1fa0 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -11,10 +11,11 @@ from einops import rearrange import numpy as np from PIL import Image from tqdm import tqdm +from typing import Optional from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm -from ..models.wan_video_dit import RMSNorm +from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample @@ -209,6 +210,8 @@ class WanVideoPipeline(BasePipeline): tiled=True, tile_size=(30, 52), tile_stride=(15, 26), + tea_cache_l1_thresh=None, + tea_cache_model_id="", progress_bar_cmd=tqdm, progress_bar_st=None, ): @@ -251,6 +254,10 @@ class WanVideoPipeline(BasePipeline): # Extra input extra_input = self.prepare_extra_input(latents) + + # TeaCache + tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} + tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None} # Denoise self.load_models_to_device(["dit"]) @@ -258,9 +265,9 @@ class WanVideoPipeline(BasePipeline): timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) 
class TeaCache:
    """Step-skipping cache for the Wan DiT block stack.

    One instance tracks one denoising trajectory (the pipeline creates a
    separate instance for the positive and the negative prompt branch).
    On every step ``check`` measures how much the time modulation ``t_mod``
    changed relative to the previous step, rescales that change with a
    model-specific polynomial and accumulates it.  While the accumulated
    value stays below ``rel_l1_thresh`` the transformer blocks are skipped
    and the residual recorded by ``store`` is re-applied by ``update``.

    Args:
        num_inference_steps: Number of calls to ``check`` that make up one
            full trajectory; the internal step counter wraps to 0 after it.
        rel_l1_thresh: Threshold on the accumulated rescaled distance.
            Larger values skip more steps.
        model_id: Key into ``coefficients_dict`` selecting the polynomial.

    Raises:
        ValueError: If ``model_id`` is not a key of ``coefficients_dict``.
    """

    # Polynomial coefficients, highest degree first (the order ``np.poly1d``
    # expects).  Kept at class level so the table is built once instead of
    # once per instance; it is still reachable as ``self.coefficients_dict``.
    # NOTE(review): values are presumably the per-model fits published by the
    # TeaCache authors -- confirm against upstream before editing.
    coefficients_dict = {
        "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
        "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
        "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
        "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
    }

    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")

        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None

        self.coefficients = self.coefficients_dict[model_id]
        # The polynomial never changes for a given model, so build it once
        # here rather than on every call to ``check``.
        self.rescale_func = np.poly1d(self.coefficients)

    def check(self, dit: "WanModel", x, t_mod):
        """Decide whether the cached residual can be reused for this step.

        Args:
            dit: The DiT model.  Unused; kept for interface compatibility.
            x: Hidden states about to enter the block stack.  A copy is kept
                when the blocks will be run, so ``store`` can form a residual.
            t_mod: Time modulation tensor for the current step.

        Returns:
            ``True`` if the blocks may be skipped (call ``update``), ``False``
            if they must be run (then call ``store`` with the result).
        """
        modulated_inp = t_mod.clone()
        is_first_or_last = self.step == 0 or self.step == self.num_inference_steps - 1
        if is_first_or_last:
            # The first and last steps of a trajectory are always computed.
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            previous = self.previous_modulated_input
            rel_l1 = ((modulated_inp - previous).abs().mean() / previous.abs().mean()).item()
            self.accumulated_rel_l1_distance += self.rescale_func(rel_l1)
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0

        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            # Wrap around so the same instance can serve another trajectory.
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        """Record the residual produced by a full pass through the blocks.

        Args:
            hidden_states: Output of the block stack for the input that was
                passed to the preceding ``check`` call.

        Raises:
            RuntimeError: If no input was saved by a preceding ``check``.
        """
        if self.previous_hidden_states is None:
            raise RuntimeError("TeaCache.store() requires a preceding check() that returned False.")
        self.previous_residual = hidden_states - self.previous_hidden_states
        # The saved input is only needed to form the residual; release it.
        self.previous_hidden_states = None

    def update(self, hidden_states):
        """Approximate the block stack output by adding the cached residual.

        Args:
            hidden_states: Hidden states that would have entered the blocks.

        Returns:
            ``hidden_states + previous_residual``.

        Raises:
            RuntimeError: If ``store`` has not recorded a residual yet.
        """
        if self.previous_residual is None:
            raise RuntimeError("TeaCache.update() called before any residual was stored.")
        return hidden_states + self.previous_residual


def model_fn_wan_video(
    dit: "WanModel",
    x: torch.Tensor,
    timestep: torch.Tensor,
    context: torch.Tensor,
    clip_feature: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    tea_cache: Optional[TeaCache] = None,
    **kwargs,
):
    """Run one forward pass of the Wan DiT, optionally short-circuited by TeaCache.

    Args:
        dit: The Wan DiT whose sub-modules are invoked here.
        x: Latents to denoise.
        timestep: Current timestep tensor.
        context: Text embedding fed to ``dit.text_embedding``.
        clip_feature: Image feature fed to ``dit.img_emb``; required when
            ``dit.has_image_input`` is true.
        y: Extra latent channels concatenated to ``x`` along dim 1; required
            when ``dit.has_image_input`` is true.
        tea_cache: Cache state for this branch, or ``None`` to always run
            every block.
        **kwargs: Ignored; accepted so callers can pass extra inputs.

    Returns:
        The output of ``dit.unpatchify`` applied to the head output.

    Raises:
        ValueError: If the DiT takes image input but ``y`` or
            ``clip_feature`` is missing.
    """
    t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
    t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
    context = dit.text_embedding(context)

    if dit.has_image_input:
        if y is None or clip_feature is None:
            raise ValueError("This DiT expects image inputs: both `y` and `clip_feature` must be provided.")
        x = torch.cat([x, y], dim=1)  # (b, c_x + c_y, f, h, w)
        clip_embedding = dit.img_emb(clip_feature)
        context = torch.cat([clip_embedding, context], dim=1)

    x, (f, h, w) = dit.patchify(x)

    # Broadcast the three per-axis frequency tables over the (f, h, w) grid
    # and flatten the grid to one row per patch.
    freqs = torch.cat([
        dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ], dim=-1).reshape(f * h * w, 1, -1).to(x.device)

    # TeaCache: either reuse the cached residual or run the blocks and
    # refresh the cache.
    reuse_cached = tea_cache is not None and tea_cache.check(dit, x, t_mod)
    if reuse_cached:
        x = tea_cache.update(x)
    else:
        for block in dit.blocks:
            x = block(x, context, t_mod, freqs)
        if tea_cache is not None:
            tea_cache.store(x)

    x = dit.head(x, t)
    x = dit.unpatchify(x, (f, h, w))
    return x
import torch
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
from modelscope import snapshot_download


# Fetch the Wan2.1-T2V-1.3B checkpoint into the local models directory.
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")

# Load the DiT, text encoder and VAE weights on the CPU first.
manager = ModelManager(device="cpu")
checkpoint_files = [
    "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
    "models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
    "models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
]
# Passing `torch_dtype=torch.float8_e4m3fn` here enables FP8 quantization instead.
manager.load_models(checkpoint_files, torch_dtype=torch.bfloat16)

pipe = WanVideoPipeline.from_model_manager(manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management(num_persistent_param_in_dit=None)

# Text-to-video prompts.
positive_prompt = "纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。"
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

# TeaCache parameters.
tea_cache_options = {
    # A larger threshold runs faster but lowers the visual quality.
    "tea_cache_l1_thresh": 0.05,
    # One of: Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P.
    "tea_cache_model_id": "Wan2.1-T2V-1.3B",
}

frames = pipe(
    prompt=positive_prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    seed=0,
    tiled=True,
    **tea_cache_options,
)
save_video(frames, "video1.mp4", fps=15, quality=5)

# TeaCache doesn't support video-to-video