refactor patchify

This commit is contained in:
mi804
2026-01-31 13:55:52 +08:00
parent 9f07d65ebb
commit 1c8a0f8317
3 changed files with 43 additions and 9 deletions

View File

@@ -167,6 +167,20 @@ class AudioPatchifier(Patchifier):
return audio_latents
def unpatchify_audio(
self,
audio_latents: torch.Tensor,
channels: int,
mel_bins: int
) -> torch.Tensor:
audio_latents = einops.rearrange(
audio_latents,
"b t (c f) -> b c t f",
c=channels,
f=mel_bins,
)
return audio_latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,

View File

@@ -68,6 +68,24 @@ class VideoLatentPatchifier(Patchifier):
return latents
def unpatchify_video(
self,
latents: torch.Tensor,
frames: int,
height: int,
width: int,
) -> torch.Tensor:
latents = einops.rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q)",
f=frames,
h=height // self._patch_size[1],
w=width // self._patch_size[2],
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents
def get_patch_grid_bounds(
self,
output_shape: AudioLatentShape | VideoLatentShape,

View File

@@ -127,11 +127,9 @@ class LTX2AudioVideoPipeline(BasePipeline):
self.load_models_to_device(self.in_iteration_models + ('upsampler',))
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
latent = self.video_patchifier.patchify(latent)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
self.load_models_to_device(self.in_iteration_models)
if not inputs_shared["use_distilled_pipeline"]:
@@ -147,8 +145,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
return inputs_shared
@torch.no_grad()
@@ -187,7 +183,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
special_case="ditilled_stage1" if use_distilled_pipeline else None)
# self.load_lora(self.dit, self.stage2_lora_path)
# Inputs
inputs_posi = {
"prompt": prompt,
@@ -203,6 +198,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -220,8 +216,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
# Denoise Stage 2
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
@@ -422,7 +416,6 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
video_noise = pipe.video_patchifier.patchify(video_noise)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
@@ -431,7 +424,6 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
audio_noise = pipe.audio_patchifier.patchify(audio_noise)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
return {
"video_noise": video_noise,
@@ -473,14 +465,21 @@ def model_fn_ltx2(
video_latents=None,
video_context=None,
video_positions=None,
video_patchifier=None,
audio_latents=None,
audio_context=None,
audio_positions=None,
audio_patchifier=None,
timestep=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
# patchify
b, c_v, f, h, w = video_latents.shape
_, c_a, _, mel_bins = audio_latents.shape
video_latents = video_patchifier.patchify(video_latents)
audio_latents = audio_patchifier.patchify(audio_latents)
#TODO: support gradient checkpointing
timestep = timestep.float() / 1000.
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
@@ -495,4 +494,7 @@ def model_fn_ltx2(
audio_context=audio_context,
audio_timesteps=audio_timesteps,
)
# unpatchify
vx = video_patchifier.unpatchify_video(vx, f, h, w)
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
return vx, ax