refactor patchify

This commit is contained in:
mi804
2026-01-31 13:55:52 +08:00
parent 9f07d65ebb
commit 1c8a0f8317
3 changed files with 43 additions and 9 deletions

View File

@@ -167,6 +167,20 @@ class AudioPatchifier(Patchifier):
return audio_latents
def unpatchify_audio(
self,
audio_latents: torch.Tensor,
channels: int,
mel_bins: int
) -> torch.Tensor:
audio_latents = einops.rearrange(
audio_latents,
"b t (c f) -> b c t f",
c=channels,
f=mel_bins,
)
return audio_latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,

View File

@@ -68,6 +68,24 @@ class VideoLatentPatchifier(Patchifier):
return latents
def unpatchify_video(
self,
latents: torch.Tensor,
frames: int,
height: int,
width: int,
) -> torch.Tensor:
latents = einops.rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q)",
f=frames,
h=height // self._patch_size[1],
w=width // self._patch_size[2],
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,