mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
refactor patchify
This commit is contained in:
@@ -167,6 +167,20 @@ class AudioPatchifier(Patchifier):
|
||||
|
||||
return audio_latents
|
||||
|
||||
def unpatchify_audio(
    self,
    audio_latents: torch.Tensor,
    channels: int,
    mel_bins: int
) -> torch.Tensor:
    """Split the flattened last axis of ``audio_latents`` and move channels forward.

    The input is rearranged with the einops pattern
    ``"b t (c f) -> b c t f"``: its last axis is interpreted as
    ``channels * mel_bins`` and unfolded into two axes, and the resulting
    channel axis is moved directly after the batch axis.

    Args:
        audio_latents: Tensor with three axes matching ``b t (c f)``.
            (Presumably ``t`` is the time/token axis -- confirm with caller.)
        channels: Size of the ``c`` axis recovered from the last input axis.
        mel_bins: Size of the ``f`` axis recovered from the last input axis.

    Returns:
        The rearranged tensor with axes ordered ``b c t f``.
    """
    # Sizes of the two axes folded into the trailing dimension.
    folded_axes = {"c": channels, "f": mel_bins}
    return einops.rearrange(audio_latents, "b t (c f) -> b c t f", **folded_axes)
|
||||
|
||||
def get_patch_grid_bounds(
|
||||
self,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
|
||||
@@ -68,6 +68,24 @@ class VideoLatentPatchifier(Patchifier):
|
||||
|
||||
return latents
|
||||
|
||||
def unpatchify_video(
    self,
    latents: torch.Tensor,
    frames: int,
    height: int,
    width: int,
) -> torch.Tensor:
    """Unfold a patch sequence back into a ``b c f H W`` latent tensor.

    Applies the einops pattern
    ``"b (f h w) (c p q) -> b c f (h p) (w q)"``: the second input axis is
    split into a ``frames x h x w`` patch grid and the last axis into
    ``c x p x q``, where ``p`` and ``q`` are ``self._patch_size[1]`` and
    ``self._patch_size[2]``. Each ``p x q`` patch is then laid back out in
    the two trailing axes.

    Args:
        latents: Tensor with three axes matching ``b (f h w) (c p q)``.
        frames: Size of the ``f`` axis of the patch grid.
        height: Size of the output's second-to-last axis; the grid height
            is ``height // self._patch_size[1]``.
        width: Size of the output's last axis; the grid width is
            ``width // self._patch_size[2]``.

    Returns:
        The rearranged tensor with axes ordered ``b c f (h p) (w q)``.
    """
    # NOTE(review): self._patch_size[0] is not consulted here, so `frames`
    # is used as the grid depth unchanged -- presumably the temporal patch
    # size is expected to be 1; confirm against the patchify counterpart.
    patch_h = self._patch_size[1]
    patch_w = self._patch_size[2]

    return einops.rearrange(
        latents,
        "b (f h w) (c p q) -> b c f (h p) (w q)",
        f=frames,
        h=height // patch_h,  # number of patches along the height axis
        w=width // patch_w,  # number of patches along the width axis
        p=patch_h,
        q=patch_w,
    )
|
||||
|
||||
def get_patch_grid_bounds(
|
||||
self,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
|
||||
Reference in New Issue
Block a user