Merge pull request #705 from modelscope/wan2.2

fix wan2.2 vae
This commit is contained in:
Zhongjie Duan
2025-07-28 17:04:40 +08:00
committed by GitHub
2 changed files with 7 additions and 5 deletions

View File

@@ -1075,6 +1075,7 @@ class WanVideoVAE(nn.Module):
# init model
self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False)
self.upsampling_factor = 8
self.z_dim = z_dim
def build_1d_mask(self, length, left_bound, right_bound, border_width):
@@ -1170,7 +1171,7 @@ class WanVideoVAE(nn.Module):
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
@@ -1221,8 +1222,8 @@ class WanVideoVAE(nn.Module):
for video in videos:
video = video.unsqueeze(0)
if tiled:
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor)
tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor)
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
@@ -1372,3 +1373,4 @@ class WanVideoVAE38(WanVideoVAE):
# init model
self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False)
self.upsampling_factor = 16
self.z_dim = z_dim