diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py index 397a2e7..f04da2d 100644 --- a/diffsynth/models/wan_video_vae.py +++ b/diffsynth/models/wan_video_vae.py @@ -1234,11 +1234,18 @@ class WanVideoVAE(nn.Module): def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)): - if tiled: - video = self.tiled_decode(hidden_states, device, tile_size, tile_stride) - else: - video = self.single_decode(hidden_states, device) - return video + hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states] + videos = [] + for hidden_state in hidden_states: + hidden_state = hidden_state.unsqueeze(0) + if tiled: + video = self.tiled_decode(hidden_state, device, tile_size, tile_stride) + else: + video = self.single_decode(hidden_state, device) + video = video.squeeze(0) + videos.append(video) + videos = torch.stack(videos) + return videos @staticmethod