diff --git a/diffsynth/diffusion/base_pipeline.py b/diffsynth/diffusion/base_pipeline.py index 4d046ab..8e3649a 100644 --- a/diffsynth/diffusion/base_pipeline.py +++ b/diffsynth/diffusion/base_pipeline.py @@ -147,6 +147,19 @@ class BasePipeline(torch.nn.Module): video = [self.vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output] return video + def output_audio_format_check(self, audio_output): + # output standard format: [C, T], output dtype: float() + # remove batch dim + if audio_output.ndim == 3: + audio_output = audio_output.squeeze(0) + # Transform to stereo + if audio_output.shape[0] == 1: + audio_output = audio_output.repeat(2, 1) + elif audio_output.shape[0] == 2: + pass + else: + raise ValueError("The output audio should be [C, T] or [1, C, T] or [2, C, T].") + return audio_output.float() def load_models_to_device(self, model_names): if self.vram_management_enabled: diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index c8ee359..ba25f6d 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -141,7 +141,6 @@ class LTX2AudioVideoPipeline(BasePipeline): pipe.vram_management_enabled = pipe.check_vram_management_state() return pipe - def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False): if skip_stage: return inputs_shared, inputs_posi, inputs_nega @@ -160,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline): inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) return inputs_shared, inputs_posi, inputs_nega - @torch.no_grad() def __call__( self, @@ -231,7 +229,8 @@ class LTX2AudioVideoPipeline(BasePipeline): video = self.vae_output_to_video(video) self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) decoded_audio = 
self.audio_vae_decoder(inputs_shared["audio_latents"]) - decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float() + decoded_audio = self.audio_vocoder(decoded_audio) + decoded_audio = self.output_audio_format_check(decoded_audio) return video, decoded_audio @@ -394,8 +393,8 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"), - output_params=("denoise_mask_video", "input_latents_video"), + input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"), + output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"), onload_model_names=("video_vae_encoder") ) @@ -418,20 +417,30 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength return initial_latents, denoise_mask - def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None): + def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, frame_rate, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None): if input_images is None or len(input_images) == 0: return {} else: if len(input_images_indexes) != len(set(input_images_indexes)): raise ValueError("Input images must have unique indexes.") pipe.load_models_to_device(self.onload_model_names) - frame_conditions = {} + 
frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []} for img, index in zip(input_images, input_images_indexes): latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels) - # first_frame + # first_frame by replacing latents if index == 0: input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents) frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}) + # other frames by adding reference latents + else: + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float() + video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate + video_positions = video_positions.to(pipe.torch_dtype) + frame_conditions["ref_frames_latents"].append(latents) + frame_conditions["ref_frames_positions"].append(video_positions) + if len(frame_conditions["ref_frames_latents"]) == 0: + frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None}) return frame_conditions @@ -553,6 +562,9 @@ def model_fn_ltx2( # First Frame Conditioning input_latents_video=None, denoise_mask_video=None, + # Other Frames Conditioning + ref_frames_latents=None, + ref_frames_positions=None, # In-Context Conditioning in_context_video_latents=None, in_context_video_positions=None, @@ -568,17 +580,24 @@ def model_fn_ltx2( video_latents = video_patchifier.patchify(video_latents) seq_len_video = video_latents.shape[1] video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) - if denoise_mask_video is not None: + # First frame conditioning by replacing the video latents + if input_latents_video is not None: denoise_mask_video = 
video_patchifier.patchify(denoise_mask_video) video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video) video_timesteps = denoise_mask_video * video_timesteps - - if in_context_video_latents is not None: - in_context_video_latents = video_patchifier.patchify(in_context_video_latents) - in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0. - video_latents = torch.cat([video_latents, in_context_video_latents], dim=1) - video_positions = torch.cat([video_positions, in_context_video_positions], dim=2) - video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1) + + # Conditioning by replacing the video latents + total_ref_latents = ref_frames_latents if ref_frames_latents is not None else [] + total_ref_positions = ref_frames_positions if ref_frames_positions is not None else [] + total_ref_latents += [in_context_video_latents] if in_context_video_latents is not None else [] + total_ref_positions += [in_context_video_positions] if in_context_video_positions is not None else [] + if len(total_ref_latents) > 0: + for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions): + ref_frames_latent = video_patchifier.patchify(ref_frames_latent) + ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0. 
+ video_latents = torch.cat([video_latents, ref_frames_latent], dim=1) + video_positions = torch.cat([video_positions, ref_frames_position], dim=2) + video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1) if audio_latents is not None: _, c_a, _, mel_bins = audio_latents.shape diff --git a/diffsynth/utils/data/media_io_ltx2.py b/diffsynth/utils/data/media_io_ltx2.py index 5526ca9..689d8af 100644 --- a/diffsynth/utils/data/media_io_ltx2.py +++ b/diffsynth/utils/data/media_io_ltx2.py @@ -43,13 +43,10 @@ def _write_audio( ) -> None: if samples.ndim == 1: samples = samples[:, None] - - if samples.shape[1] != 2 and samples.shape[0] == 2: - samples = samples.T - - if samples.shape[1] != 2: - raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.") - + if samples.shape[0] == 1: + samples = samples.repeat(2, 1) + assert samples.ndim == 2 and samples.shape[0] == 2, "audio samples must be [C, S] or [S], C must be 1 or 2" + samples = samples.T # Convert to int16 packed for ingestion; resampler converts to encoder fmt. if samples.dtype != torch.int16: samples = torch.clip(samples, -1.0, 1.0) @@ -69,10 +66,17 @@ def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: """ Prepare the audio stream for writing. 
""" - audio_stream = container.add_stream("aac", rate=audio_sample_rate) - audio_stream.codec_context.sample_rate = audio_sample_rate + audio_stream = container.add_stream("aac") + supported_sample_rates = audio_stream.codec_context.codec.audio_rates + if supported_sample_rates: + best_rate = min(supported_sample_rates, key=lambda x: abs(x - audio_sample_rate)) + if best_rate != audio_sample_rate: + print(f"Using closest supported audio sample rate: {best_rate}") + else: + best_rate = audio_sample_rate + audio_stream.codec_context.sample_rate = best_rate audio_stream.codec_context.layout = "stereo" - audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate) + audio_stream.codec_context.time_base = Fraction(1, best_rate) return audio_stream def write_video_audio_ltx2( @@ -80,8 +84,31 @@ def write_video_audio_ltx2( audio: torch.Tensor | None, output_path: str, fps: int = 24, - audio_sample_rate: int | None = 24000, + audio_sample_rate: int | None = None, ) -> None: + """ + Writes a sequence of images and an audio tensor to a video file. + + This function utilizes PyAV (or a similar multimedia library) to encode a list of PIL images into a video stream + and multiplex a PyTorch tensor as the audio stream into the output container. + + Args: + video (list[Image.Image]): A list of PIL Image objects representing the video frames. + The length of this list determines the total duration of the video based on the FPS. + audio (torch.Tensor | None): The audio data as a PyTorch tensor. + The shape is typically (channels, samples). If no audio is required, pass None. + channels can be 1 or 2. 1 for mono, 2 for stereo. + output_path (str): The file path (including extension) where the output video will be saved. + fps (int, optional): The frame rate (frames per second) for the video. Defaults to 24. + audio_sample_rate (int | None, optional): The sample rate (e.g., 44100, 48000) for the audio. 
+ If the audio tensor is provided and this is None, the function attempts to infer the rate + based on the audio tensor's length and the video duration. + Raises: + ValueError: If an audio tensor is provided but the sample rate cannot be determined. + """ + duration = len(video) / fps + if audio_sample_rate is None: + audio_sample_rate = int(audio.shape[-1] / duration) width, height = video[0].size container = av.open(output_path, mode="w") diff --git a/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py index 264211b..0dd0854 100644 --- a/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py @@ -26,7 +26,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained( stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2.3", origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors"), ) -prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”" +prompt = "Two cute orange cats, wearing boxing gloves, stand in a boxing ring and fight each other. They are punching each other fast and yelling: 'I will win!'" negative_prompt = ( "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " @@ -41,12 +41,9 @@ negative_prompt = ( "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
) height, width, num_frames = 512 * 2, 768 * 2, 121 -dataset_snapshot_download( - dataset_id="DiffSynth-Studio/examples_in_diffsynth", - local_dir="./", - allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"] -) -image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height)) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") +first_frame = Image.open("data/example_video_dataset/ltx2/first_frame.png").convert("RGB").resize((width, height)) +last_frame = Image.open("data/example_video_dataset/ltx2/last_frame.png").convert("RGB").resize((width, height)) # first frame video, audio = pipe( prompt=prompt, @@ -57,7 +54,7 @@ video, audio = pipe( num_frames=num_frames, tiled=True, use_two_stage_pipeline=True, - input_images=[image], + input_images=[first_frame], input_images_indexes=[0], input_images_strength=1.0, ) @@ -66,5 +63,26 @@ write_video_audio_ltx2( audio=audio, output_path='ltx2.3_twostage_i2av_first.mp4', fps=24, - audio_sample_rate=pipe.audio_vocoder.output_sampling_rate, +) +pipe.clear_lora() + +# This example uses the first and last frames for demonstration. However, you can use any frames by setting input_images and input_indexes. Note that input_indexes must be within the range of num_frames. +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=42, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, + input_images=[first_frame, last_frame], + input_images_indexes=[0, num_frames-1], + input_images_strength=1.0, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2.3_twostage_i2av_first_last.mp4', + fps=24, )