diff --git a/diffsynth/models/wanx_vae.py b/diffsynth/models/wanx_vae.py
index 2f0c8a7..90bf325 100644
--- a/diffsynth/models/wanx_vae.py
+++ b/diffsynth/models/wanx_vae.py
@@ -1,8 +1,9 @@
-from einops import rearrange
+from einops import rearrange, repeat
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from tqdm import tqdm
 
 
 CACHE_T = 2
@@ -605,6 +606,7 @@ class WanXVideoVAE(nn.Module):
 
         # init model
         self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
+        self.upsampling_factor = 8
 
     def encode(self, videos):
         """
@@ -615,12 +617,98 @@
             for u in videos
         ]
 
-    def decode(self, zs):
-        return [
-            self.model.decode(u.unsqueeze(0),
-                              self.scale).float().clamp_(-1, 1).squeeze(0)
-            for u in zs
-        ]
+    def single_decode(self, hidden_state, device):
+        hidden_state = hidden_state.to(device)
+        x = self.model.decode(hidden_state.unsqueeze(0), self.scale)
+        return x.float().clamp_(-1, 1).squeeze(0)
+
+    def build_1d_mask(self, length, left_bound, right_bound, border_width):
+        # Linear ramp that feathers non-boundary edges; boundary edges keep weight 1.
+        x = torch.ones((length,))
+        if not left_bound:
+            x[:border_width] = (torch.arange(border_width) + 1) / border_width
+        if not right_bound:
+            x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
+        return x
+
+    def build_mask(self, data, is_bound, border_width):
+        _, _, _, H, W = data.shape
+        h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
+        w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
+
+        h = repeat(h, "H -> H W", H=H, W=W)
+        w = repeat(w, "W -> H W", H=H, W=W)
+
+        mask = torch.stack([h, w]).min(dim=0).values
+        mask = rearrange(mask, "H W -> 1 1 1 H W")
+        return mask
+
+    def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
+        hidden_states = hidden_states.unsqueeze(0)
+        _, _, T, H, W = hidden_states.shape
+        size_h, size_w = tile_size
+        stride_h, stride_w = tile_stride
+
+        # Split the latent plane into overlapping tiles, skipping tiles that
+        # a previous tile already covers up to the boundary.
+        tasks = []
+        for h in range(0, H, stride_h):
+            if h - stride_h >= 0 and h - stride_h + size_h >= H:
+                continue
+            for w in range(0, W, stride_w):
+                if w - stride_w >= 0 and w - stride_w + size_w >= W:
+                    continue
+                h_, w_ = h + size_h, w + size_w
+                tasks.append((h, h_, w, w_))
+
+        # Decode each tile on `device` but accumulate results on the CPU,
+        # so peak GPU memory is bounded by the per-tile cost.
+        data_device = "cpu"
+        computation_device = device
+
+        # The causal VAE turns T latent frames into 4 * T - 3 video frames.
+        out_T = T * 4 - 3
+        out_H = H * self.upsampling_factor
+        out_W = W * self.upsampling_factor
+        weight = torch.zeros((1, 1, out_T, out_H, out_W), dtype=hidden_states.dtype, device=data_device)
+        values = torch.zeros((1, 3, out_T, out_H, out_W), dtype=hidden_states.dtype, device=data_device)
+
+        for h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
+            hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
+            hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
+
+            mask = self.build_mask(
+                hidden_states_batch,
+                is_bound=(h == 0, h_ >= H, w == 0, w_ >= W),
+                border_width=((size_h - stride_h) * self.upsampling_factor,
+                              (size_w - stride_w) * self.upsampling_factor)
+            ).to(dtype=hidden_states.dtype, device=data_device)
+
+            target_h = h * self.upsampling_factor
+            target_w = w * self.upsampling_factor
+            values[
+                :, :, :,
+                target_h:target_h + hidden_states_batch.shape[3],
+                target_w:target_w + hidden_states_batch.shape[4],
+            ] += hidden_states_batch * mask
+            weight[
+                :, :, :,
+                target_h:target_h + hidden_states_batch.shape[3],
+                target_w:target_w + hidden_states_batch.shape[4],
+            ] += mask
+        # Feathered masks from adjacent tiles do not sum exactly to 1,
+        # so normalize by the accumulated weights before clamping.
+        values = values / weight
+        values = values.float().clamp_(-1, 1).squeeze(0)
+        return values
+
+    def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
+        hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
+        videos = []
+        for hidden_state in hidden_states:
+            if tiled:
+                video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
+            else:
+                video = self.single_decode(hidden_state, device)
+            videos.append(video)
+        return videos
 
 
     @staticmethod
diff --git a/examples/WanX/test_vae.py b/examples/WanX/test_vae.py
index 83523d1..450f1a3 100644
--- a/examples/WanX/test_vae.py
+++ b/examples/WanX/test_vae.py
@@ -16,7 +16,7 @@ def save_video(tensor,
                 u, nrow=nrow, normalize=normalize, value_range=value_range)
             for u in tensor.unbind(2)
         ],
-        dim=1).permute(1, 2, 3, 0)
+        dim=1).permute(1, 2, 3, 0)  # (frame, h, w, 3)
     tensor = (tensor * 255).type(torch.uint8).cpu()
 
     # write video
@@ -26,6 +26,8 @@ def save_video(tensor,
         writer.append_data(frame)
     writer.close()
 
+torch.cuda.memory._record_memory_history()
+
 model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
 model_manager.load_models([
     "models/WanX/vae.pth",
@@ -34,9 +36,12 @@ model_manager.load_models([
 vae = model_manager.fetch_model('wanxvideo_vae')
 
 latents = [torch.load('sample.pt')]
-videos = vae.decode(latents)
-back_encode = vae.encode(videos)
-save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
+videos = vae.decode(latents, device=latents[0].device, tiled=True)
+# back_encode = vae.encode(videos)
+
+torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
+
+save_video(videos[0][None], save_file='example3.mp4', fps=16, nrow=8)
 print(latents)
 print(videos)
-print(back_encode)
+# print(back_encode)
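
Usage note (not part of the patch): a minimal sketch of the new `decode` signature, reusing the `ModelManager` setup from `examples/WanX/test_vae.py`; the latent shape below is made up for illustration. `tile_size` and `tile_stride` are in latent units, so the defaults decode overlapping 272x272-pixel tiles (34 * 8) with 128- and 144-pixel overlaps after the 8x spatial upsampling.

```python
import torch
from diffsynth import ModelManager  # import path assumed from the example scripts

model_manager = ModelManager(torch_dtype=torch.float, device="cuda")
model_manager.load_models(["models/WanX/vae.pth"])
vae = model_manager.fetch_model('wanxvideo_vae')

# Hypothetical 16-channel latent: 21 latent frames at 60x104 latent resolution.
latent = torch.randn(16, 21, 60, 104)

# Tiles are decoded on `device` and accumulated on the CPU, so peak GPU
# memory is set by the per-tile cost rather than the full video.
videos = vae.decode([latent], device="cuda", tiled=True,
                    tile_size=(34, 34), tile_stride=(18, 16))
print(videos[0].shape)  # (3, 4*21-3, 60*8, 104*8) = (3, 81, 480, 832), values in [-1, 1]
```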
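Why `tiled_decode` ends with `values / weight`: the linear ramps produced by `build_1d_mask` on two adjacent tiles sum to (b + 1) / b across an overlap of width b, not to 1, so the blend has to be renormalized by the accumulated mask weights. A standalone 1D sketch mirroring the ramp in this patch:

```python
import torch

def build_1d_mask(length, left_bound, right_bound, border_width):
    # Same ramp as WanXVideoVAE.build_1d_mask in this patch.
    x = torch.ones((length,))
    if not left_bound:
        x[:border_width] = (torch.arange(border_width) + 1) / border_width
    if not right_bound:
        x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
    return x

# Two adjacent length-8 tiles overlapping by 4 samples (border_width = 4).
left = build_1d_mask(8, left_bound=True, right_bound=False, border_width=4)
right = build_1d_mask(8, left_bound=False, right_bound=True, border_width=4)

weight = torch.zeros(12)
weight[:8] += left   # left tile covers samples 0..7
weight[4:] += right  # right tile covers samples 4..11
print(weight)  # overlap sums to 1.25 = (4 + 1) / 4, hence the normalization
```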
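Profiling note: the pickle written by `torch.cuda.memory._dump_snapshot("my_snapshot.pickle")` can be inspected by dropping it into https://pytorch.org/memory_viz; recent PyTorch versions can also render it offline with `python -m torch.cuda._memory_viz trace_plot my_snapshot.pickle -o my_snapshot.html` (treat the exact invocation as an assumption for your PyTorch version).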