From 29663b25a65d35f129a88dadbe28e74e895c0a49 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Mon, 28 Jul 2025 16:49:28 +0800 Subject: [PATCH] fix wan2.2 vae --- diffsynth/models/wan_video_vae.py | 8 +++++--- examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index d737e2f..397a2e7 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -1075,6 +1075,7 @@ class WanVideoVAE(nn.Module): # init model self.model = VideoVAE_(z_dim=z_dim).eval().requires_grad_(False) self.upsampling_factor = 8 + self.z_dim = z_dim def build_1d_mask(self, length, left_bound, right_bound, border_width): @@ -1170,7 +1171,7 @@ class WanVideoVAE(nn.Module): out_T = (T + 3) // 4 weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) - values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) + values = torch.zeros((1, self.z_dim, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device) for h, h_, w, w_ in tqdm(tasks, desc="VAE encoding"): hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device) @@ -1221,8 +1222,8 @@ class WanVideoVAE(nn.Module): for video in videos: video = video.unsqueeze(0) if tiled: - tile_size = (tile_size[0] * 8, tile_size[1] * 8) - tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8) + tile_size = (tile_size[0] * self.upsampling_factor, tile_size[1] * self.upsampling_factor) + tile_stride = (tile_stride[0] * self.upsampling_factor, tile_stride[1] * self.upsampling_factor) hidden_state = self.tiled_encode(video, device, tile_size, tile_stride) else: hidden_state = self.single_encode(video, device) @@ -1372,3 +1373,4 @@ class WanVideoVAE38(WanVideoVAE): # init model self.model = VideoVAE38_(z_dim=z_dim, dim=dim).eval().requires_grad_(False) self.upsampling_factor = 16 + self.z_dim = z_dim diff --git a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py index 50d81c2..fa91965 100644 --- a/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py +++ b/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py @@ -19,7 +19,7 @@ pipe.enable_vram_management() video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=False, + seed=0, tiled=True, height=704, width=1248, num_frames=121, ) @@ -35,7 +35,7 @@ input_image = Image.open("data/examples/wan/cat_fightning.jpg").resize((1248, 70 video = pipe( prompt="两只可爱的橘猫戴上拳击手套,站在一个拳击台上搏斗。", negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - seed=0, tiled=False, + seed=0, tiled=True, height=704, width=1248, input_image=input_image, num_frames=121,