wanx vae tile decode

This commit is contained in:
mi804
2025-02-21 11:27:30 +08:00
parent 582e33ad51
commit 02a4c8df9f
2 changed files with 105 additions and 12 deletions

View File

@@ -1,8 +1,9 @@
from einops import rearrange
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
CACHE_T = 2
@@ -605,6 +606,7 @@ class WanXVideoVAE(nn.Module):
# init model
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
self.upsampling_factor = 8
def encode(self, videos):
"""
@@ -615,12 +617,98 @@ class WanXVideoVAE(nn.Module):
for u in videos
]
def decode(self, zs):
return [
self.model.decode(u.unsqueeze(0),
self.scale).float().clamp_(-1, 1).squeeze(0)
for u in zs
]
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
x = self.model.decode(hidden_state.unsqueeze(0), self.scale)
return x.float().clamp_(-1, 1).squeeze(0)
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
hidden_states = hidden_states.unsqueeze(0)
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
for w in range(0, W, stride_w):
if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = T*4-3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch,
is_bound=(h==0, h_>=H, w==0, w_>=W),
border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h:target_h + hidden_states_batch.shape[3],
target_w:target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h: target_h + hidden_states_batch.shape[3],
target_w: target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
values = values.float().clamp_(-1, 1).squeeze(0)
return values
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
videos.append(video)
return videos
@staticmethod

View File

@@ -16,7 +16,7 @@ def save_video(tensor,
u, nrow=nrow, normalize=normalize, value_range=value_range)
for u in tensor.unbind(2)
],
dim=1).permute(1, 2, 3, 0)
dim=1).permute(1, 2, 3, 0) #frame, h, w, 3
tensor = (tensor * 255).type(torch.uint8).cpu()
# write video
@@ -26,6 +26,8 @@ def save_video(tensor,
writer.append_data(frame)
writer.close()
torch.cuda.memory._record_memory_history()
model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
model_manager.load_models([
"models/WanX/vae.pth",
@@ -34,9 +36,12 @@ model_manager.load_models([
vae = model_manager.fetch_model('wanxvideo_vae')
latents = [torch.load('sample.pt')]
videos = vae.decode(latents)
back_encode = vae.encode(videos)
save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
videos = vae.decode(latents, device=latents[0].device, tiled=True)
# back_encode = vae.encode(videos)
torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
save_video(videos[0][None], save_file='example3.mp4', fps=16, nrow=8)
print(latents)
print(videos)
print(back_encode)
# print(back_encode)