refactor patchify

This commit is contained in:
mi804
2026-01-31 13:55:52 +08:00
parent 9f07d65ebb
commit 1c8a0f8317
3 changed files with 43 additions and 9 deletions

View File

@@ -68,6 +68,24 @@ class VideoLatentPatchifier(Patchifier):
return latents
def unpatchify_video(
self,
latents: torch.Tensor,
frames: int,
height: int,
width: int,
) -> torch.Tensor:
latents = einops.rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q)",
f=frames,
h=height // self._patch_size[1],
w=width // self._patch_size[2],
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,