optimize stepvideo vae

This commit is contained in:
Artiprocher
2025-02-18 17:28:05 +08:00
parent f191353cf4
commit 9cff769fbd
7 changed files with 197 additions and 28 deletions

View File

@@ -14,6 +14,19 @@ import torch
from einops import rearrange
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
from einops import repeat
class BaseGroupNorm(nn.GroupNorm):
def __init__(self, num_groups, num_channels):
super().__init__(num_groups=num_groups, num_channels=num_channels)
def forward(self, x, zero_pad=False, **kwargs):
if zero_pad:
return base_group_norm_with_zero_pad(x, self, **kwargs)
else:
return base_group_norm(x, self, **kwargs)
def base_group_norm(x, norm_layer, act_silu=False, channel_last=False):
@@ -456,14 +469,14 @@ class AttnBlock(nn.Module):
):
super().__init__()
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels)
self.norm = BaseGroupNorm(num_groups=32, num_channels=in_channels)
self.q = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.k = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.v = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
self.proj_out = CausalConvChannelLast(in_channels, in_channels, kernel_size=1)
def attention(self, x, is_init=True):
x = base_group_norm(x, self.norm, act_silu=False, channel_last=True)
x = self.norm(x, act_silu=False, channel_last=True)
q = self.q(x, is_init)
k = self.k(x, is_init)
v = self.v(x, is_init)
@@ -495,12 +508,12 @@ class Resnet3DBlock(nn.Module):
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels)
self.norm1 = BaseGroupNorm(num_groups=32, num_channels=in_channels)
self.conv1 = CausalConvAfterNorm(in_channels, out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels)
self.norm2 = BaseGroupNorm(num_groups=32, num_channels=out_channels)
self.conv2 = CausalConvAfterNorm(out_channels, out_channels, kernel_size=3)
assert conv_shortcut is False
@@ -514,14 +527,14 @@ class Resnet3DBlock(nn.Module):
def forward(self, x, temb=None, is_init=True):
x = x.permute(0,2,3,4,1).contiguous()
h = base_group_norm_with_zero_pad(x, self.norm1, act_silu=True, pad_size=2)
h = self.norm1(x, zero_pad=True, act_silu=True, pad_size=2)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nn.functional.silu(temb))[:, :, None, None]
x = self.nin_shortcut(x) if self.in_channels != self.out_channels else x
h = base_group_norm_with_zero_pad(h, self.norm2, act_silu=True, pad_size=2)
h = self.norm2(h, zero_pad=True, act_silu=True, pad_size=2)
x = self.conv2(h, residual=x)
x = x.permute(0,4,1,2,3)
@@ -675,10 +688,10 @@ class Res3DBlockUpsample(nn.Module):
self.act_ = nn.SiLU(inplace=True)
self.conv1 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
self.norm1 = nn.GroupNorm(32, num_filters)
self.norm1 = BaseGroupNorm(32, num_filters)
self.conv2 = CausalConvChannelLast(num_filters, num_filters, kernel_size=[3, 3, 3])
self.norm2 = nn.GroupNorm(32, num_filters)
self.norm2 = BaseGroupNorm(32, num_filters)
self.down_sampling = down_sampling
if down_sampling:
@@ -688,7 +701,7 @@ class Res3DBlockUpsample(nn.Module):
if num_filters != input_filters or down_sampling:
self.conv3 = CausalConvChannelLast(input_filters, num_filters, kernel_size=[1, 1, 1], stride=self.down_sampling_stride)
self.norm3 = nn.GroupNorm(32, num_filters)
self.norm3 = BaseGroupNorm(32, num_filters)
def forward(self, x, is_init=False):
x = x.permute(0,2,3,4,1).contiguous()
@@ -696,14 +709,14 @@ class Res3DBlockUpsample(nn.Module):
residual = x
h = self.conv1(x, is_init)
h = base_group_norm(h, self.norm1, act_silu=True, channel_last=True)
h = self.norm1(h, act_silu=True, channel_last=True)
h = self.conv2(h, is_init)
h = base_group_norm(h, self.norm2, act_silu=False, channel_last=True)
h = self.norm2(h, act_silu=False, channel_last=True)
if self.down_sampling or self.num_filters != self.input_filters:
x = self.conv3(x, is_init)
x = base_group_norm(x, self.norm3, act_silu=False, channel_last=True)
x = self.norm3(x, act_silu=False, channel_last=True)
h.add_(x)
h = self.act_(h)
@@ -973,7 +986,7 @@ class StepVideoVAE(nn.Module):
return dec
@torch.inference_mode()
def decode(self, z):
def decode_original(self, z):
# b (nc cf) c h w -> (b nc) cf c h w -> decode -> (b nc) c cf h w -> b (nc cf) c h w
chunks = list(z.split(self.latent_len, dim=1))
@@ -998,15 +1011,104 @@ class StepVideoVAE(nn.Module):
x = self.mix(x)
return x
def mix(self, x):
remain_scale = 0.6
def mix(self, x, smooth_scale = 0.6):
remain_scale = smooth_scale
mix_scale = 1. - remain_scale
front = slice(self.frame_len - 1, x.size(1) - 1, self.frame_len)
back = slice(self.frame_len, x.size(1), self.frame_len)
x[:, back] = x[:, back] * remain_scale + x[:, front] * mix_scale
x[:, front] = x[:, front] * remain_scale + x[:, back] * mix_scale
x[:, front], x[:, back] = (
x[:, front] * remain_scale + x[:, back] * mix_scale,
x[:, back] * remain_scale + x[:, front] * mix_scale
)
return x
def single_decode(self, hidden_states, device):
chunks = list(hidden_states.split(self.latent_len, dim=1))
for i in range(len(chunks)):
chunks[i] = self.decode_naive(chunks[i].to(device), True).permute(0,2,1,3,4).cpu()
x = torch.cat(chunks, dim=1)
return x
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size=(34, 34), tile_stride=(16, 16)):
B, T, C, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for t in range(0, T, 3):
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
t_, h_, w_ = t + 3, h + size_h, w + size_w
tasks.append((t, t_, h, h_, w, w_))
# Run
data_device = "cpu"
computation_device = device
weight = torch.zeros((1, 1, T//3*17, H * 16, W * 16), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((B, 3, T//3*17, H * 16, W * 16), dtype=hidden_states.dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, t:t_, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.decode_naive(hidden_states_batch, True).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) * 16, (size_w - stride_w) * 16)
).to(dtype=hidden_states.dtype, device=data_device)
target_t = t // 3 * 17
target_h = h * 16
target_w = w * 16
values[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
target_t: target_t + hidden_states_batch.shape[2],
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
return values / weight
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(16, 16), smooth_scale=0.6):
hidden_states = hidden_states.to("cpu")
if tiled:
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_states, device)
video = self.mix(video, smooth_scale=smooth_scale)
return video
@staticmethod
def state_dict_converter():
return StepVideoVAEStateDictConverter()