mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support teacache in wan
This commit is contained in:
@@ -11,10 +11,11 @@ from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from ..models.wan_video_text_encoder import T5RelativeEmbedding, T5LayerNorm
|
||||
from ..models.wan_video_dit import RMSNorm
|
||||
from ..models.wan_video_dit import RMSNorm, sinusoidal_embedding_1d
|
||||
from ..models.wan_video_vae import RMS_norm, CausalConv3d, Upsample
|
||||
|
||||
|
||||
@@ -209,6 +210,8 @@ class WanVideoPipeline(BasePipeline):
|
||||
tiled=True,
|
||||
tile_size=(30, 52),
|
||||
tile_stride=(15, 26),
|
||||
tea_cache_l1_thresh=None,
|
||||
tea_cache_model_id="",
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
@@ -251,6 +254,10 @@ class WanVideoPipeline(BasePipeline):
|
||||
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents)
|
||||
|
||||
# TeaCache
|
||||
tea_cache_posi = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
tea_cache_nega = {"tea_cache": TeaCache(num_inference_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None else None}
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(["dit"])
|
||||
@@ -258,9 +265,9 @@ class WanVideoPipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.dit(latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input)
|
||||
noise_pred_posi = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_posi, **image_emb, **extra_input, **tea_cache_posi)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.dit(latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input)
|
||||
noise_pred_nega = model_fn_wan_video(self.dit, latents, timestep=timestep, **prompt_emb_nega, **image_emb, **extra_input, **tea_cache_nega)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -275,3 +282,104 @@ class WanVideoPipeline(BasePipeline):
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
return frames
|
||||
|
||||
|
||||
|
||||
class TeaCache:
|
||||
def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.step = 0
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = None
|
||||
self.rel_l1_thresh = rel_l1_thresh
|
||||
self.previous_residual = None
|
||||
self.previous_hidden_states = None
|
||||
|
||||
self.coefficients_dict = {
|
||||
"Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
|
||||
"Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
|
||||
"Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
|
||||
"Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
|
||||
}
|
||||
if model_id not in self.coefficients_dict:
|
||||
supported_model_ids = ", ".join([i for i in self.coefficients_dict])
|
||||
raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
|
||||
self.coefficients = self.coefficients_dict[model_id]
|
||||
|
||||
def check(self, dit: WanModel, x, t_mod):
|
||||
modulated_inp = t_mod.clone()
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
coefficients = self.coefficients
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
|
||||
should_calc = False
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step += 1
|
||||
if self.step == self.num_inference_steps:
|
||||
self.step = 0
|
||||
if should_calc:
|
||||
self.previous_hidden_states = x.clone()
|
||||
return not should_calc
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
|
||||
def update(self, hidden_states):
|
||||
hidden_states = hidden_states + self.previous_residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
||||
def model_fn_wan_video(
|
||||
dit: WanModel,
|
||||
x: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
clip_feature: Optional[torch.Tensor] = None,
|
||||
y: Optional[torch.Tensor] = None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs,
|
||||
):
|
||||
t = dit.time_embedding(sinusoidal_embedding_1d(dit.freq_dim, timestep))
|
||||
t_mod = dit.time_projection(t).unflatten(1, (6, dit.dim))
|
||||
context = dit.text_embedding(context)
|
||||
|
||||
if dit.has_image_input:
|
||||
x = torch.cat([x, y], dim=1) # (b, c_x + c_y, f, h, w)
|
||||
clip_embdding = dit.img_emb(clip_feature)
|
||||
context = torch.cat([clip_embdding, context], dim=1)
|
||||
|
||||
x, (f, h, w) = dit.patchify(x)
|
||||
|
||||
freqs = torch.cat([
|
||||
dit.freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
||||
dit.freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
||||
], dim=-1).reshape(f * h * w, 1, -1).to(x.device)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, x, t_mod)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if tea_cache_update:
|
||||
x = tea_cache.update(x)
|
||||
else:
|
||||
# blocks
|
||||
for block in dit.blocks:
|
||||
x = block(x, context, t_mod, freqs)
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(x)
|
||||
|
||||
x = dit.head(x, t)
|
||||
x = dit.unpatchify(x, (f, h, w))
|
||||
return x
|
||||
|
||||
@@ -31,6 +31,8 @@ Put sunglasses on the dog.
|
||||
|
||||
https://github.com/user-attachments/assets/272808d7-fbeb-4747-a6df-14a0860c75fb
|
||||
|
||||
[TeaCache](https://github.com/ali-vilab/TeaCache) is supported in both T2V and I2V models. It can significantly improve the efficiency. See [`./wan_1.3b_text_to_video_accelerate.py`](./wan_1.3b_text_to_video_accelerate.py).
|
||||
|
||||
### Wan-Video-14B-T2V
|
||||
|
||||
Wan-Video-14B-T2V is an enhanced version of Wan-Video-1.3B-T2V, offering greater size and power. To utilize this model, you need additional VRAM. We recommend that users adjust the `torch_dtype` and `num_persistent_param_in_dit` settings to find an optimal balance between speed and VRAM requirements. See [`./wan_14b_text_to_video.py`](./wan_14b_text_to_video.py).
|
||||
|
||||
34
examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
Normal file
34
examples/wanvideo/wan_1.3b_text_to_video_accelerate.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import torch
|
||||
from diffsynth import ModelManager, WanVideoPipeline, save_video, VideoData
|
||||
from modelscope import snapshot_download
|
||||
|
||||
|
||||
# Download models
|
||||
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B", local_dir="models/Wan-AI/Wan2.1-T2V-1.3B")
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(device="cpu")
|
||||
model_manager.load_models(
|
||||
[
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth",
|
||||
"models/Wan-AI/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth",
|
||||
],
|
||||
torch_dtype=torch.bfloat16, # You can set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization.
|
||||
)
|
||||
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
# Text-to-video
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
num_inference_steps=50,
|
||||
seed=0, tiled=True,
|
||||
# TeaCache parameters
|
||||
tea_cache_l1_thresh=0.05, # The larger this value is, the faster the speed, but the worse the visual quality.
|
||||
tea_cache_model_id="Wan2.1-T2V-1.3B", # Choose one in (Wan2.1-T2V-1.3B, Wan2.1-T2V-14B, Wan2.1-I2V-14B-480P, Wan2.1-I2V-14B-720P).
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
|
||||
# TeaCache doesn't support video-to-video
|
||||
Reference in New Issue
Block a user