refactor patchify

This commit is contained in:
mi804
2026-01-31 13:55:52 +08:00
parent 9f07d65ebb
commit 1c8a0f8317
3 changed files with 43 additions and 9 deletions

View File

@@ -167,6 +167,20 @@ class AudioPatchifier(Patchifier):
return audio_latents
def unpatchify_audio(
self,
audio_latents: torch.Tensor,
channels: int,
mel_bins: int
) -> torch.Tensor:
audio_latents = einops.rearrange(
audio_latents,
"b t (c f) -> b c t f",
c=channels,
f=mel_bins,
)
return audio_latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,