Merge pull request #1296 from Explorer-Dong/fix/wan_vae

fix: WanVAE2.2 encode and decode error
This commit is contained in:
Zhongjie Duan
2026-03-02 10:19:36 +08:00
committed by GitHub

View File

@@ -469,7 +469,7 @@ class Down_ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0]):
x_copy = x.clone()
for module in self.downsamples:
x = module(x, feat_cache, feat_idx)
x, feat_cache, feat_idx = module(x, feat_cache, feat_idx)
return x + self.avg_shortcut(x_copy), feat_cache, feat_idx
@@ -506,10 +506,10 @@ class Up_ResidualBlock(nn.Module):
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
x_main = x.clone()
for module in self.upsamples:
x_main = module(x_main, feat_cache, feat_idx)
x_main, feat_cache, feat_idx = module(x_main, feat_cache, feat_idx)
if self.avg_shortcut is not None:
x_shortcut = self.avg_shortcut(x, first_chunk)
return x_main + x_shortcut
return x_main + x_shortcut, feat_cache, feat_idx
else:
return x_main, feat_cache, feat_idx