diff --git a/diffsynth/models/ltx2_audio_vae.py b/diffsynth/models/ltx2_audio_vae.py
index dd46941..708ded7 100644
--- a/diffsynth/models/ltx2_audio_vae.py
+++ b/diffsynth/models/ltx2_audio_vae.py
@@ -167,6 +167,20 @@ class AudioPatchifier(Patchifier):
         )
         return audio_latents
 
+    def unpatchify_audio(
+        self,
+        audio_latents: torch.Tensor,
+        channels: int,
+        mel_bins: int
+    ) -> torch.Tensor:
+        audio_latents = einops.rearrange(
+            audio_latents,
+            "b t (c f) -> b c t f",
+            c=channels,
+            f=mel_bins,
+        )
+        return audio_latents
+
     def get_patch_grid_bounds(
         self,
         output_shape: AudioLatentShape | VideoLatentShape,
diff --git a/diffsynth/models/ltx2_video_vae.py b/diffsynth/models/ltx2_video_vae.py
index 0bd1331..ebc9483 100644
--- a/diffsynth/models/ltx2_video_vae.py
+++ b/diffsynth/models/ltx2_video_vae.py
@@ -68,6 +68,24 @@ class VideoLatentPatchifier(Patchifier):
         )
         return latents
 
+    def unpatchify_video(
+        self,
+        latents: torch.Tensor,
+        frames: int,
+        height: int,
+        width: int,
+    ) -> torch.Tensor:
+        latents = einops.rearrange(
+            latents,
+            "b (f h w) (c p q) -> b c f (h p) (w q)",
+            f=frames,
+            h=height // self._patch_size[1],
+            w=width // self._patch_size[2],
+            p=self._patch_size[1],
+            q=self._patch_size[2],
+        )
+        return latents
+
     def get_patch_grid_bounds(
         self,
         output_shape: AudioLatentShape | VideoLatentShape,
diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py
index 5973e9d..4a96e36 100644
--- a/diffsynth/pipelines/ltx2_audio_video.py
+++ b/diffsynth/pipelines/ltx2_audio_video.py
@@ -127,11 +127,9 @@ class LTX2AudioVideoPipeline(BasePipeline):
         self.load_models_to_device(self.in_iteration_models + ('upsampler',))
         latent = self.upsampler(latent)
         latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
-        latent = self.video_patchifier.patchify(latent)
         self.scheduler.set_timesteps(special_case="stage2")
         inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
         inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
-        inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
         inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
         self.load_models_to_device(self.in_iteration_models)
         if not inputs_shared["use_distilled_pipeline"]:
@@ -147,8 +145,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
                 noise_pred=noise_pred_video, **inputs_shared)
             inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
                 noise_pred=noise_pred_audio, **inputs_shared)
-        inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
-        inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
         return inputs_shared
 
     @torch.no_grad()
@@ -187,7 +183,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
         # Scheduler
         self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None)
-        # self.load_lora(self.dit, self.stage2_lora_path)
 
         # Inputs
         inputs_posi = {
             "prompt": prompt,
@@ -203,6 +198,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
             "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
             "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
             "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
+            "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
         }
         for unit in self.units:
             inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
@@ -220,8 +216,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
                 noise_pred=noise_pred_video, **inputs_shared)
             inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
                 noise_pred=noise_pred_audio, **inputs_shared)
-        inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
-        inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
 
         # Denoise Stage 2
         inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
@@ -422,7 +416,6 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
         video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
         video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
         video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
-        video_noise = pipe.video_patchifier.patchify(video_noise)
         latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
         video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
 
@@ -431,7 +424,6 @@
 
         audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
         audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
-        audio_noise = pipe.audio_patchifier.patchify(audio_noise)
         audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
         return {
             "video_noise": video_noise,
@@ -473,14 +465,21 @@ def model_fn_ltx2(
     video_latents=None,
     video_context=None,
     video_positions=None,
+    video_patchifier=None,
     audio_latents=None,
     audio_context=None,
     audio_positions=None,
+    audio_patchifier=None,
     timestep=None,
     use_gradient_checkpointing=False,
     use_gradient_checkpointing_offload=False,
     **kwargs,
 ):
+    # patchify: flatten video/audio latents into token sequences for the DiT
+    b, c_v, f, h, w = video_latents.shape
+    _, c_a, _, mel_bins = audio_latents.shape
+    video_latents = video_patchifier.patchify(video_latents)
+    audio_latents = audio_patchifier.patchify(audio_latents)
     #TODO: support gradient checkpointing
     timestep = timestep.float() / 1000.
     video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
@@ -495,4 +494,7 @@
         audio_context=audio_context,
         audio_timesteps=audio_timesteps,
     )
+    # unpatchify: restore channels-first latent layouts from the token sequences
+    vx = video_patchifier.unpatchify_video(vx, f, h, w)
+    ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
     return vx, ax
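
As a sanity check on the einops patterns above, here is a minimal, self-contained round-trip sketch. The unpatchify patterns are copied from unpatchify_video / unpatchify_audio in this patch; the forward patchify patterns, the patch size p = q = 1, and all tensor sizes are assumptions inferred from those inverses, not taken from the actual Patchifier implementations.

import einops
import torch

# Hypothetical sizes: batch 1, 128 latent channels, 8 latent frames,
# a 16x16 latent grid, and an assumed spatial patch size of p = q = 1.
b, c, f, h, w = 1, 128, 8, 16, 16
p = q = 1

# Video: (b, c, f, H, W) -> (b, f*h*w tokens, c*p*q features), the assumed
# inverse of unpatchify_video's "b (f h w) (c p q) -> b c f (h p) (w q)".
video = torch.randn(b, c, f, h * p, w * q)
video_tokens = einops.rearrange(video, "b c f (h p) (w q) -> b (f h w) (c p q)", p=p, q=q)
video_back = einops.rearrange(
    video_tokens, "b (f h w) (c p q) -> b c f (h p) (w q)",
    f=f, h=h, w=w, p=p, q=q,
)
assert torch.equal(video_back, video)  # rearrange is a lossless permutation

# Audio: (b, c, t, mel) -> (b, t, c*mel), the assumed inverse of
# unpatchify_audio's "b t (c f) -> b c t f".
c_a, t, mel_bins = 8, 32, 16
audio = torch.randn(b, c_a, t, mel_bins)
audio_tokens = einops.rearrange(audio, "b c t f -> b t (c f)")
audio_back = einops.rearrange(audio_tokens, "b t (c f) -> b c t f", c=c_a, f=mel_bins)
assert torch.equal(audio_back, audio)

With this change, latents keep their channels-first layouts throughout the pipeline, scheduler steps, and noise initializer; the flattened token view exists only inside model_fn_ltx2, which is why the scattered patchify/unpatchify calls elsewhere are removed.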