mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
fix wan decoder bug
This commit is contained in:
@@ -1023,11 +1023,11 @@ class VideoVAE_(nn.Module):
|
|||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||||
feat_cache=self._feat_map,
|
feat_cache=self._feat_map,
|
||||||
feat_idx=self._conv_idx)
|
feat_idx=self._conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||||
feat_cache=self._feat_map,
|
feat_cache=self._feat_map,
|
||||||
feat_idx=self._conv_idx)
|
feat_idx=self._conv_idx)
|
||||||
out = torch.cat([out, out_], 2) # may add tensor offload
|
out = torch.cat([out, out_], 2) # may add tensor offload
|
||||||
@@ -1303,11 +1303,11 @@ class VideoVAE38_(VideoVAE_):
|
|||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._enc_conv_idx = [0]
|
self._enc_conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.encoder(x[:, :, :1, :, :],
|
out, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, :1, :, :],
|
||||||
feat_cache=self._enc_feat_map,
|
feat_cache=self._enc_feat_map,
|
||||||
feat_idx=self._enc_conv_idx)
|
feat_idx=self._enc_conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
out_, self._enc_feat_map, self._enc_conv_idx = self.encoder(x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
||||||
feat_cache=self._enc_feat_map,
|
feat_cache=self._enc_feat_map,
|
||||||
feat_idx=self._enc_conv_idx)
|
feat_idx=self._enc_conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
@@ -1337,12 +1337,12 @@ class VideoVAE38_(VideoVAE_):
|
|||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
self._conv_idx = [0]
|
self._conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
out = self.decoder(x[:, :, i:i + 1, :, :],
|
out, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||||
feat_cache=self._feat_map,
|
feat_cache=self._feat_map,
|
||||||
feat_idx=self._conv_idx,
|
feat_idx=self._conv_idx,
|
||||||
first_chunk=True)
|
first_chunk=True)
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(x[:, :, i:i + 1, :, :],
|
out_, self._feat_map, self._conv_idx = self.decoder(x[:, :, i:i + 1, :, :],
|
||||||
feat_cache=self._feat_map,
|
feat_cache=self._feat_map,
|
||||||
feat_idx=self._conv_idx)
|
feat_idx=self._conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
|
|||||||
Reference in New Issue
Block a user