support hunyuanvideo v2v

This commit is contained in:
Artiprocher
2024-12-23 20:43:47 +08:00
parent c06ea2271a
commit 405ca6be33
6 changed files with 173 additions and 16 deletions

View File

@@ -453,7 +453,7 @@ class HunyuanVideoVAEDecoder(nn.Module):
weight = torch.zeros((1, 1, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
values = torch.zeros((B, 3, (T - 1) * 4 + 1, H * 8, W * 8), dtype=torch_dtype, device=data_device)
for t, t_, h, h_, w, w_ in tqdm(tasks):
for t, t_, h, h_, w, w_ in tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, t:t_, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.forward(hidden_states_batch).to(data_device)
if t > 0: