mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
optimize stepvideo vae
This commit is contained in:
@@ -135,8 +135,8 @@ class VideoData:
|
||||
frame.save(os.path.join(folder, f"{i}.png"))
|
||||
|
||||
|
||||
def save_video(frames, save_path, fps, quality=9):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality)
|
||||
def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
|
||||
writer = imageio.get_writer(save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params)
|
||||
for frame in tqdm(frames, desc="Saving video"):
|
||||
frame = np.array(frame)
|
||||
writer.append_data(frame)
|
||||
|
||||
@@ -14,6 +14,19 @@ import torch
|
||||
from einops import rearrange
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from tqdm import tqdm
|
||||
from einops import repeat
|
||||
|
||||
|
||||
class BaseGroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` whose forward pass is routed through the module-level
    ``base_group_norm`` / ``base_group_norm_with_zero_pad`` helpers.

    The layer passes itself to the helper as the ``norm_layer`` argument.
    """

    def __init__(self, num_groups, num_channels):
        super().__init__(num_groups=num_groups, num_channels=num_channels)

    def forward(self, x, zero_pad=False, **kwargs):
        """Normalize ``x`` with the helper selected by ``zero_pad``.

        Args:
            x: Input tensor handed to the helper unchanged.
            zero_pad: When true, use ``base_group_norm_with_zero_pad``;
                otherwise use ``base_group_norm``.
            **kwargs: Forwarded verbatim to the chosen helper (callers in this
                file pass e.g. ``act_silu``, ``channel_last``, ``pad_size``).

        Returns:
            Whatever the selected helper returns.
        """
        norm_fn = base_group_norm_with_zero_pad if zero_pad else base_group_norm
        return norm_fn(x, self, **kwargs)
|
||||
|
||||
|
||||
def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
|
||||
@@ -456,14 +469,14 @@ class AttnBlock(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
|
||||
self.norm = BaseGroupNorm(num_groups=32, num_channels=in_channels)
|
||||
self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
|
||||
self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
|
||||
self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
|
||||
self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
|
||||
|
||||
def attention(self, x, is_init=True):
|
||||
x = base_group_norm(x, self.norm, act_silu=False, channel_last=True)
|
||||
x = self.norm(x, act_silu=False, channel_last=True)
|
||||
q = self.q(x, is_init)
|
||||
k = self.k(x, is_init)
|
||||
v = self.v(x, is_init)
|
||||
@@ -495,12 +508,12 @@ class Resnet3DBlock(nn.Module):
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
|
||||
self.norm1 = BaseGroupNorm(num_groups=32, num_channels=in_channels)
|
||||
self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
|
||||
self.norm2 = BaseGroupNorm(num_groups=32, num_channels=out_channels)
|
||||
self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)
|
||||
|
||||
assert conv_shortcut is False
|
||||
@@ -514,14 +527,14 @@ class Resnet3DBlock(nn.Module):
|
||||
def forward(self, x, temb=None, is_init=True):
|
||||
x = x.permute(0,2,3,4,1).contiguous()
|
||||
|
||||
h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2)
|
||||
h = self.norm1(x, zero_pad=True, act_silu=True, pad_size=2)
|
||||
h = self.conv1(h)
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
|
||||
|
||||
x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
|
||||
|
||||
h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2)
|
||||
h = self.norm2(h, zero_pad=True, act_silu=True, pad_size=2)
|
||||
x = self.conv2(h, residual=x)
|
||||
|
||||
x = x.permute(0,4,1,2,3)
|
||||
@@ -675,10 +688,10 @@ class Res3DBlockUpsample(nn.Module):
|
||||
self.act_ = nn.SiLU(inplace=True)
|
||||
|
||||
self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
|
||||
self.norm1 = nn.GroupNorm(32, num_filters)
|
||||
self.norm1 = BaseGroupNorm(32, num_filters)
|
||||
|
||||
self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
|
||||
self.norm2 = nn.GroupNorm(32, num_filters)
|
||||
self.norm2 = BaseGroupNorm(32, num_filters)
|
||||
|
||||
self.down_sampling = down_sampling
|
||||
if down_sampling:
|
||||
@@ -688,7 +701,7 @@ class Res3DBlockUpsample(nn.Module):
|
||||
|
||||
if num_filters != input_filters or down_sampling:
|
||||
self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride)
|
||||
self.norm3 = nn.GroupNorm(32, num_filters)
|
||||
self.norm3 = BaseGroupNorm(32, num_filters)
|
||||
|
||||
def forward(self, x, is_init=False):
|
||||
x = x.permute(0,2,3,4,1).contiguous()
|
||||
@@ -696,14 +709,14 @@ class Res3DBlockUpsample(nn.Module):
|
||||
residual = x
|
||||
|
||||
h = self.conv1(x, is_init)
|
||||
h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True)
|
||||
h = self.norm1(h, act_silu=True, channel_last=True)
|
||||
|
||||
h = self.conv2(h, is_init)
|
||||
h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True)
|
||||
h = self.norm2(h, act_silu=False, channel_last=True)
|
||||
|
||||
if self.down_sampling or self.num_filters != self.input_filters:
|
||||
x = self.conv3(x, is_init)
|
||||
x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True)
|
||||
x = self.norm3(x, act_silu=False, channel_last=True)
|
||||
|
||||
h.add_(x)
|
||||
h = self.act_(h)
|
||||
@@ -973,7 +986,7 @@ class StepVideoVAE(nn.Module):
|
||||
return dec
|
||||
|
||||
@torch.inference_mode()
|
||||
def decode(self, z):
|
||||
def decode_original(self, z):
|
||||
# b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w
|
||||
chunks = list(z.split(self.latent_len, dim=1))
|
||||
|
||||
@@ -998,15 +1011,104 @@ class StepVideoVAE(nn.Module):
|
||||
x = self.mix(x)
|
||||
return x
|
||||
|
||||
def mix(self, x):
|
||||
remain_scale = 0.6
|
||||
def mix(self, x, smooth_scale = 0.6):
|
||||
remain_scale = smooth_scale
|
||||
mix_scale = 1. - remain_scale
|
||||
front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
|
||||
back = slice(self.frame_len, x.size(1), self.frame_len)
|
||||
x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale
|
||||
x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale
|
||||
x[:, front], x[:, back] = (
|
||||
x[:, front] * remain_scale + x[:, back] * mix_scale,
|
||||
x[:, back] * remain_scale + x[:, front] * mix_scale
|
||||
)
|
||||
return x
|
||||
|
||||
def single_decode(self, hidden_states, device):
    """Decode the latents chunk by chunk along dim 1 without spatial tiling.

    Args:
        hidden_states: Latent tensor; it is split along dim 1 into pieces of
            ``self.latent_len``.
        device: Device each chunk is moved to before ``decode_naive`` runs.

    Returns:
        The decoded chunks, each moved back to CPU with dims 1 and 2 swapped,
        concatenated along dim 1.
    """
    decoded_chunks = [
        # Decode on `device`, then park the result on CPU so only one chunk
        # occupies the computation device at a time.
        self.decode_naive(piece.to(device), True).permute(0, 2, 1, 3, 4).cpu()
        for piece in hidden_states.split(self.latent_len, dim=1)
    ]
    return torch.cat(decoded_chunks, dim=1)
|
||||
|
||||
def build_1d_mask(self, length, left_bound, right_bound, border_width):
    """Build a 1-D blending weight vector for one tile axis.

    The vector is all ones, except that each side which is not on the outer
    boundary gets a linear ramp ``1/border_width, 2/border_width, ..., 1``
    rising from the tile edge inward.

    Args:
        length: Number of elements in the returned vector.
        left_bound: True if the start of the tile is on the outer boundary
            (no ramp is applied on that side).
        right_bound: True if the end of the tile is on the outer boundary.
        border_width: Number of elements covered by each ramp. A value
            <= 0 means the tiles do not overlap, so no ramp is applied.

    Returns:
        A float tensor of shape ``(length,)``.
    """
    mask = torch.ones((length,))
    # With border_width == 0 the slice `mask[-0:]` would select the whole
    # vector and the assignment of an empty ramp would raise, so bail out.
    if border_width <= 0:
        return mask
    ramp = (torch.arange(border_width) + 1) / border_width
    if not left_bound:
        mask[:border_width] = ramp
    if not right_bound:
        mask[-border_width:] = torch.flip(ramp, dims=(0,))
    return mask
|
||||
|
||||
def build_mask(self, data, is_bound, border_width):
    """Build a blending mask matching the last two dims of ``data``.

    Args:
        data: 5-D tensor; only its last two sizes (H, W) are used.
        is_bound: Four flags; entries 0 and 1 are the start/end boundary flags
            for the H axis, entries 2 and 3 for the W axis.
        border_width: Pair ``(ramp width along H, ramp width along W)``.

    Returns:
        Tensor of shape ``(1, 1, 1, H, W)`` holding, per pixel, the smaller of
        the H-axis and W-axis 1-D weights.
    """
    _b, _c, _t, height, width = data.shape
    row_weights = self.build_1d_mask(height, is_bound[0], is_bound[1], border_width[0])
    col_weights = self.build_1d_mask(width, is_bound[2], is_bound[3], border_width[1])

    # Broadcast each 1-D profile over the other axis, then keep the minimum.
    row_plane = repeat(row_weights, "H -> H W", W=width)
    col_plane = repeat(col_weights, "W -> H W", H=height)
    blended = torch.minimum(row_plane, col_plane)
    return rearrange(blended, "H W -> 1 1 1 H W")
|
||||
|
||||
def tiled_decode(self, hidden_states, device, tile_size=(34, 34), tile_stride=(16, 16)):
    """Decode latents tile by tile and blend overlapping tiles.

    Each tile covers 3 steps along dim 1 and a ``tile_size`` window along the
    last two dims. Decoded tiles are multiplied by the mask from
    ``build_mask`` and accumulated on CPU; the result is the weighted sum
    divided by the accumulated mask.

    Args:
        hidden_states: 5-D latent tensor ``(B, T, C, H, W)``.
        device: Device on which ``decode_naive`` is executed.
        tile_size: ``(height, width)`` of a tile, in latent units.
        tile_stride: ``(height, width)`` step between tile origins.

    Returns:
        CPU tensor of shape ``(B, 3, T // 3 * 17, H * 16, W * 16)``.
    """
    # NOTE(review): 3 / 17 / 16 are presumably the latent chunk length, the
    # frames decoded per chunk and the spatial upscale factor of the VAE --
    # confirm against decode_naive.
    latent_chunk, frames_per_chunk, upscale = 3, 17, 16

    B, T, _, H, W = hidden_states.shape
    size_h, size_w = tile_size
    stride_h, stride_w = tile_stride

    # Enumerate tiles; an origin is dropped when the previous tile on that
    # axis already reaches the end of the axis.
    tasks = [
        (t, t + latent_chunk, h, h + size_h, w, w + size_w)
        for t in range(0, T, latent_chunk)
        for h in range(0, H, stride_h)
        if h - stride_h < 0 or h - stride_h + size_h < H
        for w in range(0, W, stride_w)
        if w - stride_w < 0 or w - stride_w + size_w < W
    ]

    # Accumulators stay on CPU; only the current tile visits `device`.
    data_device = "cpu"
    out_dims = (T // latent_chunk * frames_per_chunk, H * upscale, W * upscale)
    weight = torch.zeros((1, 1) + out_dims, dtype=hidden_states.dtype, device=data_device)
    values = torch.zeros((B, 3) + out_dims, dtype=hidden_states.dtype, device=data_device)
    ramp_widths = ((size_h - stride_h) * upscale, (size_w - stride_w) * upscale)

    for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
        latent_tile = hidden_states[:, t:t_, :, h:h_, w:w_].to(device)
        decoded = self.decode_naive(latent_tile, True).to(data_device)

        mask = self.build_mask(
            decoded,
            is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
            border_width=ramp_widths
        ).to(dtype=hidden_states.dtype, device=data_device)

        # Destination window of this tile in the output volume.
        t0 = t // latent_chunk * frames_per_chunk
        h0 = h * upscale
        w0 = w * upscale
        region = (
            slice(None),
            slice(None),
            slice(t0, t0 + decoded.shape[2]),
            slice(h0, h0 + decoded.shape[3]),
            slice(w0, w0 + decoded.shape[4]),
        )
        values[region] += decoded * mask
        weight[region] += mask

    return values / weight
|
||||
|
||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(16, 16), smooth_scale=0.6):
    """Decode latents, then apply ``mix`` to the decoded result.

    Args:
        hidden_states: Latent tensor; it is moved to CPU before decoding.
        device: Device passed on to the chosen decoding routine.
        tiled: If true use ``tiled_decode``, otherwise ``single_decode``.
        tile_size: Forwarded to ``tiled_decode`` (ignored when not tiled).
        tile_stride: Forwarded to ``tiled_decode`` (ignored when not tiled).
        smooth_scale: Forwarded to ``mix`` as ``smooth_scale``.

    Returns:
        The value returned by ``mix``.
    """
    latents_cpu = hidden_states.to("cpu")
    if not tiled:
        decoded = self.single_decode(latents_cpu, device)
    else:
        decoded = self.tiled_decode(latents_cpu, device, tile_size, tile_stride)
    return self.mix(decoded, smooth_scale=smooth_scale)
|
||||
|
||||
@staticmethod
def state_dict_converter():
    """Create and return a fresh ``StepVideoVAEStateDictConverter`` instance."""
    converter = StepVideoVAEStateDictConverter()
    return converter
|
||||
|
||||
@@ -13,7 +13,7 @@ from PIL import Image
|
||||
from ..vram_management import enable_vram_management, AutoWrappedModule, AutoWrappedLinear
|
||||
from transformers.models.bert.modeling_bert import BertEmbeddings
|
||||
from ..models.stepvideo_dit import RMSNorm
|
||||
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Resnet3DBlock, AttnBlock, Res3DBlockUpsample, Upsample2D
|
||||
from ..models.stepvideo_vae import CausalConv, CausalConvAfterNorm, Upsample2D, BaseGroupNorm
|
||||
|
||||
|
||||
|
||||
@@ -100,10 +100,8 @@ class StepVideoPipeline(BasePipeline):
|
||||
torch.nn.Conv3d: AutoWrappedModule,
|
||||
CausalConv: AutoWrappedModule,
|
||||
CausalConvAfterNorm: AutoWrappedModule,
|
||||
Resnet3DBlock: AutoWrappedModule,
|
||||
AttnBlock: AutoWrappedModule,
|
||||
Res3DBlockUpsample: AutoWrappedModule,
|
||||
Upsample2D: AutoWrappedModule,
|
||||
BaseGroupNorm: AutoWrappedModule,
|
||||
},
|
||||
module_config = dict(
|
||||
offload_dtype=dtype,
|
||||
@@ -143,7 +141,7 @@ class StepVideoPipeline(BasePipeline):
|
||||
|
||||
|
||||
def tensor2video(self, frames):
|
||||
frames = rearrange(frames, "T C H W -> T H W C")
|
||||
frames = rearrange(frames, "C T H W -> T H W C")
|
||||
frames = ((frames.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
@@ -163,9 +161,16 @@ class StepVideoPipeline(BasePipeline):
|
||||
num_frames=204,
|
||||
cfg_scale=9.0,
|
||||
num_inference_steps=30,
|
||||
tiled=True,
|
||||
tile_size=(34, 34),
|
||||
tile_stride=(16, 16),
|
||||
smooth_scale=0.6,
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
@@ -197,7 +202,7 @@ class StepVideoPipeline(BasePipeline):
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
frames = self.vae.decode(latents)
|
||||
frames = self.vae.decode(latents, device=self.device, smooth_scale=smooth_scale, **tiler_kwargs)
|
||||
self.load_models_to_device([])
|
||||
frames = self.tensor2video(frames[0])
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ StepVideo is a state-of-the-art (SoTA) text-to-video pre-trained model with 30 b
|
||||
|
||||
For original BF16 version, please see [`./stepvideo_text_to_video.py`](./stepvideo_text_to_video.py). 80G VRAM required.
|
||||
|
||||
We also support auto-offload, which reduces the VRAM requirement to **24GB** at the cost of roughly 2x inference time. Please see [`./stepvideo_text_to_video_low_vram.py`](./stepvideo_text_to_video_low_vram.py).
|
||||
|
||||
https://github.com/user-attachments/assets/5954fdaa-a3cf-45a3-bd35-886e3cc4581b
|
||||
|
||||
For FP8 quantized version, please see [`./stepvideo_text_to_video_quantized.py`](./stepvideo_text_to_video_quantized.py). 40G VRAM required.
|
||||
|
||||
@@ -44,4 +44,7 @@ video = pipe(
|
||||
negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
|
||||
num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1
|
||||
)
|
||||
save_video(video, "video.mp4", fps=25, quality=5)
|
||||
save_video(
|
||||
video, "video.mp4", fps=25, quality=5,
|
||||
ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
|
||||
)
|
||||
|
||||
54
examples/stepvideo/stepvideo_text_to_video_low_vram.py
Normal file
54
examples/stepvideo/stepvideo_text_to_video_low_vram.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from modelscope import snapshot_download
|
||||
from diffsynth import ModelManager, StepVideoPipeline, save_video
|
||||
import torch
|
||||
|
||||
|
||||
# Download models
snapshot_download(model_id="stepfun-ai/stepvideo-t2v", cache_dir="models")

# Every checkpoint path below lives inside the downloaded snapshot directory.
model_dir = "models/stepfun-ai/stepvideo-t2v"

# Load the compiled attention for the LLM text encoder.
# If this line fails, pick the compiled file that matches your environment, or delete this line.
torch.ops.load_library(f"{model_dir}/lib/liboptimus_ths-torch2.5-cu124.cpython-310-x86_64-linux-gnu.so")

# Load models (everything is first placed on the CPU)
model_manager = ModelManager()
model_manager.load_models(
    [f"{model_dir}/hunyuan_clip/clip_text_encoder/pytorch_model.bin"],
    torch_dtype=torch.float32, device="cpu"
)

# The transformer checkpoint is split into six shards that are loaded together.
transformer_shards = [
    f"{model_dir}/transformer/diffusion_pytorch_model-{index:05d}-of-00006.safetensors"
    for index in range(1, 7)
]
model_manager.load_models(
    [f"{model_dir}/step_llm", transformer_shards],
    torch_dtype=torch.float8_e4m3fn, device="cpu"
)
model_manager.load_models(
    [f"{model_dir}/vae/vae_v2.safetensors"],
    torch_dtype=torch.bfloat16, device="cpu"
)
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

# Enable VRAM management
# This model requires 24G VRAM.
# To speed up inference, set `num_persistent_param_in_dit` to a large number or None (unlimited).
pipe.enable_vram_management(num_persistent_param_in_dit=0)

# Run!
video = pipe(
    prompt="一名宇航员在月球上发现一块石碑,上面印有“stepfun”字样,闪闪发光。超高清、HDR 视频、环境光、杜比全景声、画面稳定、流畅动作、逼真的细节、专业级构图、超现实主义、自然、生动、超细节、清晰。",
    negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
    num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1,
    tiled=True, tile_size=(34, 34), tile_stride=(16, 16)
)

# The extra ffmpeg video filter is applied while the file is being encoded.
denoise_filter = ["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
save_video(video, "video.mp4", fps=25, quality=5, ffmpeg_params=denoise_filter)
|
||||
@@ -37,7 +37,7 @@ model_manager.load_models(
|
||||
pipe = StepVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
|
||||
|
||||
# Enable VRAM management
|
||||
# This model requires 80G VRAM.
|
||||
# This model requires 40G VRAM.
|
||||
# In order to reduce VRAM required, please set `num_persistent_param_in_dit` to a small number.
|
||||
pipe.enable_vram_management(num_persistent_param_in_dit=None)
|
||||
|
||||
@@ -47,4 +47,7 @@ video = pipe(
|
||||
negative_prompt="画面暗、低分辨率、不良手、文本、缺少手指、多余的手指、裁剪、低质量、颗粒状、签名、水印、用户名、模糊。",
|
||||
num_inference_steps=30, cfg_scale=9, num_frames=51, seed=1
|
||||
)
|
||||
save_video(video, "video.mp4", fps=25, quality=5)
|
||||
save_video(
|
||||
video, "video.mp4", fps=25, quality=5,
|
||||
ffmpeg_params=["-vf", "atadenoise=0a=0.1:0b=0.1:1a=0.1:1b=0.1"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user