mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
Merge pull request #859 from krahets/main
Fix batch decoding for Wan-Video-VAE
This commit is contained in:
@@ -1216,7 +1216,6 @@ class WanVideoVAE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def encode(self, videos, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
|
|
||||||
videos = [video.to("cpu") for video in videos]
|
videos = [video.to("cpu") for video in videos]
|
||||||
hidden_states = []
|
hidden_states = []
|
||||||
for video in videos:
|
for video in videos:
|
||||||
@@ -1234,11 +1233,18 @@ class WanVideoVAE(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
def decode(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
|
||||||
if tiled:
|
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
|
||||||
video = self.tiled_decode(hidden_states, device, tile_size, tile_stride)
|
videos = []
|
||||||
else:
|
for hidden_state in hidden_states:
|
||||||
video = self.single_decode(hidden_states, device)
|
hidden_state = hidden_state.unsqueeze(0)
|
||||||
return video
|
if tiled:
|
||||||
|
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
|
||||||
|
else:
|
||||||
|
video = self.single_decode(hidden_state, device)
|
||||||
|
video = video.squeeze(0)
|
||||||
|
videos.append(video)
|
||||||
|
videos = torch.stack(videos)
|
||||||
|
return videos
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user