diff --git a/diffsynth/models/wan_video_vae.py b/diffsynth/models/wan_video_vae.py
index e43e6a8..3c2181a 100644
--- a/diffsynth/models/wan_video_vae.py
+++ b/diffsynth/models/wan_video_vae.py
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(x[:, :, i:i + 1, :, :],
+                out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
                                    feat_cache=self._feat_map,
                                    feat_idx=self._conv_idx)
             else:
-                out_ = self.decoder(x[:, :, i:i + 1, :, :],
+                out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
                                     feat_cache=self._feat_map,
                                     feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)  # may add tensor offload
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
         for i in range(iter_):
             self._enc_conv_idx = [0]
             if i == 0:
-                out = self.encoder(x[:, :, :1, :, :],
+                out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
                                    feat_cache=self._enc_feat_map,
                                    feat_idx=self._enc_conv_idx)
             else:
-                out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+                out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
                                     feat_cache=self._enc_feat_map,
                                     feat_idx=self._enc_conv_idx)
                 out = torch.cat([out, out_], 2)
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
         for i in range(iter_):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(x[:, :, i:i + 1, :, :],
+                out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
                                    feat_cache=self._feat_map,
                                    feat_idx=self._conv_idx,
                                    first_chunk=True)
             else:
-                out_ = self.decoder(x[:, :, i:i + 1, :, :],
+                out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
                                     feat_cache=self._feat_map,
                                     feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)