From 2cefc20ed66295ba269e55c5a89d6fac8ad362c3 Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Fri, 21 Feb 2025 12:58:45 +0800
Subject: [PATCH] wanx tiled encode

---
 diffsynth/models/wanx_vae.py | 102 ++++++++++++++++++++++++++++-------
 examples/WanX/test_vae.py    |   9 ++--
 2 files changed, 88 insertions(+), 23 deletions(-)

diff --git a/diffsynth/models/wanx_vae.py b/diffsynth/models/wanx_vae.py
index 90bf325..a207b1a 100644
--- a/diffsynth/models/wanx_vae.py
+++ b/diffsynth/models/wanx_vae.py
@@ -608,21 +608,6 @@ class WanXVideoVAE(nn.Module):
         self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
         self.upsampling_factor = 8
 
-    def encode(self, videos):
-        """
-        videos: A list of videos each with shape [C, T, H, W].
-        """
-        return [
-            self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
-            for u in videos
-        ]
-
-
-    def single_decode(self, hidden_state, device):
-        hidden_state = hidden_state.to(device)
-        x = self.model.decode(hidden_state.unsqueeze(0), self.scale)
-        return x.float().clamp_(-1, 1).squeeze(0)
-
 
     def build_1d_mask(self, length, left_bound, right_bound, border_width):
         x = torch.ones((length,))
@@ -647,7 +632,6 @@ class WanXVideoVAE(nn.Module):
 
 
     def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
-        hidden_states = hidden_states.unsqueeze(0)
         _, _, T, H, W = hidden_states.shape
         size_h, size_w = tile_size
         stride_h, stride_w = tile_stride
@@ -664,7 +648,7 @@ class WanXVideoVAE(nn.Module):
         data_device = "cpu"
         computation_device = device
 
-        out_T = T*4-3
+        out_T = T * 4 - 3
         weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
         values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
 
@@ -695,18 +679,100 @@ class WanXVideoVAE(nn.Module):
                 target_w: target_w + hidden_states_batch.shape[4],
             ] += mask
         values = values / weight
-        values = values.float().clamp_(-1, 1).squeeze(0)
+        values = values.float().clamp_(-1, 1)
         return values
 
 
+    def tiled_encode(self, video, device, tile_size, tile_stride):
+        _, _, T, H, W = video.shape
+        size_h, size_w = tile_size
+        stride_h, stride_w = tile_stride
+
+        # Split tasks
+        tasks = []
+        for h in range(0, H, stride_h):
+            if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
+            for w in range(0, W, stride_w):
+                if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
+                h_, w_ = h + size_h, w + size_w
+                tasks.append((h, h_, w, w_))
+
+        data_device = "cpu"
+        computation_device = device
+
+        out_T = (T + 3) // 4
+        weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
+        values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
+
+        for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
+            hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
+            hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
+
+            mask = self.build_mask(
+                hidden_states_batch,
+                is_bound=(h==0, h_>=H, w==0, w_>=W),
+                border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
+            ).to(dtype=video.dtype, device=data_device)
+
+            target_h = h // self.upsampling_factor
+            target_w = w // self.upsampling_factor
+            values[
+                :,
+                :,
+                :,
+                target_h:target_h + hidden_states_batch.shape[3],
+                target_w:target_w + hidden_states_batch.shape[4],
+            ] += hidden_states_batch * mask
+            weight[
+                :,
+                :,
+                :,
+                target_h: target_h + hidden_states_batch.shape[3],
+                target_w: target_w + hidden_states_batch.shape[4],
+            ] += mask
+        values = values / weight
+        values = values.float()
+        return values
+
+
+    def single_encode(self, video, device):
+        video = video.to(device)
+        x = self.model.encode(video, self.scale)
+        return x.float()
+
+
+    def single_decode(self, hidden_state, device):
+        hidden_state = hidden_state.to(device)
+        video = self.model.decode(hidden_state, self.scale)
+        return video.float().clamp_(-1, 1)
+
+
+    def encode(self, videos, device, tiled=False, tile_size=(272, 272), tile_stride=(144, 128)):
+
+        videos = [video.to("cpu") for video in videos]
+        hidden_states = []
+        for video in videos:
+            video = video.unsqueeze(0)
+            if tiled:
+                assert tile_size[0] % self.upsampling_factor == 0 and tile_size[1] % self.upsampling_factor == 0, f"tile_size must be divisible by {self.upsampling_factor}"
+                hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
+            else:
+                hidden_state = self.single_encode(video, device)
+            hidden_state = hidden_state.squeeze(0)
+            hidden_states.append(hidden_state)
+        return hidden_states
+
+
     def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
         hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
         videos = []
         for hidden_state in hidden_states:
+            hidden_state = hidden_state.unsqueeze(0)
             if tiled:
                 video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
             else:
                 video = self.single_decode(hidden_state, device)
+            video = video.squeeze(0)
             videos.append(video)
         return videos
 
diff --git a/examples/WanX/test_vae.py b/examples/WanX/test_vae.py
index 450f1a3..7b44615 100644
--- a/examples/WanX/test_vae.py
+++ b/examples/WanX/test_vae.py
@@ -37,11 +37,10 @@ vae = model_manager.fetch_model('wanxvideo_vae')
 
 latents = [torch.load('sample.pt')]
 videos = vae.decode(latents, device=latents[0].device, tiled=True)
-# back_encode = vae.encode(videos)
+back_encode = vae.encode(videos, device=latents[0].device, tiled=True)
+videos_back_encode = vae.decode(back_encode, device=latents[0].device, tiled=False)
 
 torch.cuda.memory._dump_snapshot("my_snapshot.pickle")
-save_video(videos[0][None], save_file='example3.mp4', fps=16, nrow=8)
-print(latents)
-print(videos)
-# print(back_encode)
+save_video(videos[0][None], save_file='example.mp4', fps=16, nrow=1)
+save_video(videos_back_encode[0][None], save_file='example_backencode.mp4', fps=16, nrow=1)
 
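
A note for review: both tiled paths lean on `build_mask` and `build_1d_mask`, whose bodies fall outside the hunk context above. Below is a minimal sketch of the feathered-blend idea they implement, assuming a linear ramp across the tile overlap combined per-axis via a pointwise min; treat the exact ramp shape as an assumption rather than the verbatim repository code.

```python
import torch

def build_1d_mask(length, left_bound, right_bound, border_width):
    # Per-axis weight profile: full weight everywhere, with a linear
    # fade across the overlap shared with a neighboring tile.
    x = torch.ones((length,))
    ramp = (torch.arange(border_width) + 1) / border_width
    if not left_bound:                    # tile does not touch the top/left frame edge
        x[:border_width] = ramp           # fade in
    if not right_bound:                   # tile does not touch the bottom/right frame edge
        x[-border_width:] = ramp.flip(0)  # fade out
    return x

def build_mask(data, is_bound, border_width):
    # 2D feather mask over (H, W), returned as (1, 1, 1, H, W) so it
    # broadcasts across the batch, channel, and time dimensions.
    _, _, _, H, W = data.shape
    h = build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
    w = build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
    mask = torch.min(h[:, None].expand(H, W), w[None, :].expand(H, W))
    return mask[None, None, None]
```

In both `tiled_encode` and `tiled_decode`, each tile's output is multiplied by this mask and accumulated into `values`, while the mask alone accumulates into `weight`; the final `values / weight` is then a weighted average wherever tiles overlap, which is what hides the seams.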