fix: WanVAE2.2 decode error

This commit is contained in:
Mr_Dwj
2026-02-13 01:13:08 +08:00
parent 96fb0f3afe
commit bd3c5822a1

View File

@@ -509,7 +509,7 @@ class Up_ResidualBlock(nn.Module):
x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
return x_main + x_shortcut, feat_cache, feat_idx
else:
return x_main, feat_cache, feat_idx
@@ -1336,6 +1336,7 @@ class VideoVAE38_(VideoVAE_):
x = self.conv2(z)
for i in range(iter_):
self._conv_idx = [0]
# breakpoint()
if i == 0:
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
feat_cache=self._feat_map,