mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
Fix batch decoding for Wan VAE.
This commit is contained in:
@@ -1234,11 +1234,18 @@ class WanVideoVAE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
if tiled:
|
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
videos = []
|
||||||
else:
|
for hidden_state in hidden_states:
|
||||||
video = self.single_decode(hidden_states, device)
|
hidden_state = hidden_state.unsqueeze(0)
|
||||||
return video
|
if tiled:
|
||||||
|
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||||
|
else:
|
||||||
|
video = self.single_decode(hidden_state, device)
|
||||||
|
video = video.squeeze(0)
|
||||||
|
videos.append(video)
|
||||||
|
videos = torch.stack(videos)
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user