From d9228074bdd48278ffeef1bd4db376a595c3e288 Mon Sep 17 00:00:00 2001 From: Hong Zhang <41229682+mi804@users.noreply.github.com> Date: Tue, 10 Mar 2026 13:55:40 +0800 Subject: [PATCH] refactor ltx2 stage2 pipeline (#1341) * refactor ltx2 pipeline * fix bug --- diffsynth/pipelines/ltx2_audio_video.py | 269 +++++++++--------- .../model_inference/LTX-2.3-I2AV-OneStage.py | 1 - .../model_inference/LTX-2.3-I2AV-TwoStage.py | 1 - .../LTX-2.3-I2AV-OneStage.py | 1 - .../LTX-2.3-I2AV-TwoStage.py | 1 - 5 files changed, 139 insertions(+), 134 deletions(-) diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 7da54ec..2b15369 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -12,7 +12,7 @@ from transformers import AutoImageProcessor, Gemma3Processor from ..core.device.npu_compatible_device import get_device_type from ..diffusion import FlowMatchScheduler -from ..core import ModelConfig, gradient_checkpoint_forward +from ..core import ModelConfig from ..diffusion.base_pipeline import BasePipeline, PipelineUnit from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer @@ -63,6 +63,13 @@ class LTX2AudioVideoPipeline(BasePipeline): LTX2AudioVideoUnit_InputImagesEmbedder(), LTX2AudioVideoUnit_InContextVideoEmbedder(), ] + self.stage2_units = [ + LTX2AudioVideoUnit_SwitchStage2(), + LTX2AudioVideoUnit_NoiseInitializer(), + LTX2AudioVideoUnit_LatentsUpsampler(), + LTX2AudioVideoUnit_SetScheduleStage2(), + LTX2AudioVideoUnit_InputImagesEmbedder(), + ] self.model_fn = model_fn_ltx2 @staticmethod @@ -72,6 +79,7 @@ class LTX2AudioVideoPipeline(BasePipeline): model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), stage2_lora_config: Optional[ModelConfig] = None, + stage2_lora_strength: float = 0.8, vram_limit: float = None, ): # Initialize pipeline @@ -98,53 +106,31 @@ class LTX2AudioVideoPipeline(BasePipeline): if stage2_lora_config is not None: stage2_lora_config.download_if_necessary() pipe.stage2_lora_path = stage2_lora_config.path - # Optional, currently not used + pipe.stage2_lora_strength = stage2_lora_strength # VRAM Management pipe.vram_management_enabled = pipe.check_vram_management_state() return pipe - def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm): - if inputs_shared["use_two_stage_pipeline"]: - if inputs_shared.get("clear_lora_before_state_two", False): - self.clear_lora() - latents = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) - self.load_models_to_device('upsampler',) - latents = self.upsampler(latents) - latents = self.video_vae_encoder.per_channel_statistics.normalize(latents) - self.scheduler.set_timesteps(special_case="stage2") - inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) - denoise_mask_video = 1.0 - # input image - if inputs_shared.get("input_images", None) is not None: - initial_latents, denoise_mask_video = self.apply_input_images_to_latents(latents, initial_latents=latents, **inputs_shared.get("stage2_input_latents_apply_kwargs", {})) - inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video}) - # remove in-context video control in stage 2 - inputs_shared.pop("in_context_video_latents", None) - inputs_shared.pop("in_context_video_positions", None) - # initialize latents for stage 2 - inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[ - "video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latents - inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + ( - 1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] + def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False): + if skip_stage: + return inputs_shared, inputs_posi, inputs_nega + for unit in units: + inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) + self.load_models_to_device(self.in_iteration_models) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, + inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) + return inputs_shared, inputs_posi, inputs_nega - self.load_models_to_device(self.in_iteration_models) - if not inputs_shared["use_distilled_pipeline"]: - self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8) - models = {name: getattr(self, name) for name in self.in_iteration_models} - for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( - self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega, - **models, timestep=timestep, progress_id=progress_id - ) - inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, - noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None), - input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared) - inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, - noise_pred=noise_pred_audio, **inputs_shared) - return inputs_shared @torch.no_grad() def __call__( @@ -171,7 +157,7 @@ class LTX2AudioVideoPipeline(BasePipeline): # Classifier-free guidance cfg_scale: Optional[float] = 3.0, # Scheduler - num_inference_steps: Optional[int] = 40, + num_inference_steps: Optional[int] = 30, # VAE tiling tiled: Optional[bool] = True, tile_size_in_pixels: Optional[int] = 512, @@ -180,14 +166,14 @@ class LTX2AudioVideoPipeline(BasePipeline): tile_overlap_in_frames: Optional[int] = 24, # Special Pipelines use_two_stage_pipeline: Optional[bool] = False, + stage2_spatial_upsample_factor: Optional[int] = 2, clear_lora_before_state_two: Optional[bool] = False, use_distilled_pipeline: Optional[bool] = False, # progress_bar progress_bar_cmd=tqdm, ): # Scheduler - self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, - special_case="ditilled_stage1" if use_distilled_pipeline else None) + self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None) # Inputs inputs_posi = { "prompt": prompt, @@ -203,50 +189,22 @@ class LTX2AudioVideoPipeline(BasePipeline): "cfg_scale": cfg_scale, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, - "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor, "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, } - for unit in self.units: - inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) - - # Denoise Stage 1 - self.load_models_to_device(self.in_iteration_models) - models = {name: getattr(self, name) for name in self.in_iteration_models} - for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): - timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) - noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( - self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, - **models, timestep=timestep, progress_id=progress_id - ) - inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, - inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared) - inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, - noise_pred=noise_pred_audio, **inputs_shared) - - # Denoise Stage 2 - inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd) - + # Stage 1 + inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd) + # Stage 2 + inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"]) # Decode self.load_models_to_device(['video_vae_decoder']) - video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, - tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames) + video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames) video = self.vae_output_to_video(video) self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"]) decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float() return video, decoded_audio - def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None): - b, _, f, h, w = latents.shape - denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video - initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents - for idx, input_latent in zip(input_indexes, input_latents): - idx = min(max(1 + (idx-1) // 8, 0), f - 1) - input_latent = input_latent.to(dtype=latents.dtype, device=latents.device) - initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent - denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength - return initial_latents, denoise_mask - class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): def __init__(self): @@ -275,22 +233,23 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): """ For two-stage pipelines, the resolution must be divisible by 64. For one-stage pipelines, the resolution must be divisible by 32. + This unit set height and width to stage 1 resolution, and stage_2_width and stage_2_height. """ def __init__(self): super().__init__( - input_params=("height", "width", "num_frames"), - output_params=("height", "width", "num_frames"), + input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"), + output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"), ) - def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False): + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2): if use_two_stage_pipeline: - self.width_division_factor = 64 - self.height_division_factor = 64 - height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) - if use_two_stage_pipeline: - self.width_division_factor = 32 - self.height_division_factor = 32 - return {"height": height, "width": width, "num_frames": num_frames} + height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor) + else: + stage_2_height, stage_2_width = None, None + height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + return {"height": height, "width": width, "num_frames": num_frames, "stage_2_height": stage_2_height, "stage_2_width": stage_2_width} class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): @@ -328,7 +287,7 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"), + input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"), output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape") ) @@ -354,15 +313,9 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): "audio_latent_shape": audio_latent_shape } - def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False): - if use_two_stage_pipeline: - stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate) - stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) - initial_dict = stage1_dict - initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()}) - return initial_dict - else: - return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): @@ -384,6 +337,7 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): else: raise NotImplementedError("Video-to-video not implemented yet.") + class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): def __init__(self): super().__init__( @@ -407,11 +361,12 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): else: raise NotImplementedError("Audio-to-video not supported.") + class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), - output_params=("denoise_mask_video", "input_latents_video", "stage2_input_latents_apply_kwargs"), + input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"), + output_params=("denoise_mask_video", "input_latents_video"), onload_model_names=("video_vae_encoder") ) @@ -423,53 +378,48 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device) return latents - def get_frame_conditions(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents=None, skip_apply=False): - frame_conditions = {} - for img, index in zip(input_images, input_images_indexes): - latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels) - # first_frame - if index == 0: - if skip_apply: - frame_conditions = {"input_latents": [latents], "input_indexes": [0], "input_strength": input_images_strength} - else: - input_latents_video, denoise_mask_video = pipe.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength) - frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}) - return frame_conditions + def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None): + b, _, f, h, w = latents.shape + denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video + initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents + for idx, input_latent in zip(input_indexes, input_latents): + idx = min(max(1 + (idx-1) // 8, 0), f - 1) + input_latent = input_latent.to(dtype=latents.dtype, device=latents.device) + initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent + denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength + return initial_latents, denoise_mask - def process(self, pipe: LTX2AudioVideoPipeline, input_images, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, use_two_stage_pipeline=False): + def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None): if input_images is None or len(input_images) == 0: return {} else: if len(input_images_indexes) != len(set(input_images_indexes)): raise ValueError("Input images must have unique indexes.") pipe.load_models_to_device(self.onload_model_names) - output_dicts = {} - stage1_height = height // 2 if use_two_stage_pipeline else height - stage1_width = width // 2 if use_two_stage_pipeline else width - stage_1_frame_conditions = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, stage1_height, stage1_width, - tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents) - output_dicts.update(stage_1_frame_conditions) - if use_two_stage_pipeline: - stage2_input_latents_apply_kwargs = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, height, width, - tiled, tile_size_in_pixels, tile_overlap_in_pixels, skip_apply=True) - output_dicts.update({"stage2_input_latents_apply_kwargs": stage2_input_latents_apply_kwargs}) - return output_dicts + frame_conditions = {} + for img, index in zip(input_images, input_images_indexes): + latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels) + # first_frame + if index == 0: + input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents) + frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}) + return frame_conditions class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), + input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"), output_params=("in_context_video_latents", "in_context_video_positions"), onload_model_names=("video_vae_encoder") ) - def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True): + def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor): if in_context_video is None or len(in_context_video) == 0: raise ValueError("In-context video is None or empty.") in_context_video = in_context_video[:num_frames] - expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor - expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor + expected_height = height // in_context_downsample_factor + expected_width = width // in_context_downsample_factor current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video) h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0) if current_h != h or current_w != w: @@ -479,14 +429,14 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f) return in_context_video - def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True): + def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels): if in_context_videos is None or len(in_context_videos) == 0: return {} else: pipe.load_models_to_device(self.onload_model_names) latents, positions = [], [] for in_context_video in in_context_videos: - in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline) + in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor) in_context_video = pipe.preprocess_video(in_context_video) in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) @@ -504,6 +454,62 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): return {"in_context_video_latents": latents, "in_context_video_positions": positions} +class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit): + """ + 1. switch height and width to stage 2 resolution + 2. clear in_context_video_latents and in_context_video_positions + 3. switch stage 2 lora model + """ + def __init__(self): + super().__init__( + input_params=("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline"), + output_params=("height", "width", "in_context_video_latents", "in_context_video_positions"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline): + stage2_params = {} + stage2_params.update({"height": stage_2_height, "width": stage_2_width}) + stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None}) + if clear_lora_before_state_two: + pipe.clear_lora() + if not use_distilled_pipeline: + pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength) + return stage2_params + + +class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"), + output_params=("video_latents", "audio_latents"), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise): + pipe.scheduler.set_timesteps(special_case="stage2") + video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0]) + audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0]) + return {"video_latents": video_latents, "audio_latents": audio_latents} + + +class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("video_latents",), + output_params=("video_latents", "initial_latents"), + onload_model_names=("upsampler",), + ) + + def process(self, pipe: LTX2AudioVideoPipeline, video_latents): + if video_latents is None or pipe.upsampler is None: + raise ValueError("No upsampler or no video latents before stage 2.") + else: + pipe.load_models_to_device(self.onload_model_names) + video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents) + video_latents = pipe.upsampler(video_latents) + video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents) + return {"video_latents": video_latents, "initial_latents": video_latents} + + def model_fn_ltx2( dit: LTXModel, video_latents=None, @@ -515,10 +521,13 @@ def model_fn_ltx2( audio_positions=None, audio_patchifier=None, timestep=None, + # First Frame Conditioning input_latents_video=None, denoise_mask_video=None, + # In-Context Conditioning in_context_video_latents=None, in_context_video_positions=None, + # Gradient Checkpointing use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, diff --git a/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py index d3cdd99..7e76f0e 100644 --- a/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py @@ -45,7 +45,6 @@ video, audio = pipe( input_images=[image], input_images_indexes=[0], input_images_strength=1.0, - num_inference_steps=40, ) write_video_audio_ltx2( video=video, diff --git a/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py index 380a281..264211b 100644 --- a/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py @@ -57,7 +57,6 @@ video, audio = pipe( num_frames=num_frames, tiled=True, use_two_stage_pipeline=True, - num_inference_steps=40, input_images=[image], input_images_indexes=[0], input_images_strength=1.0, diff --git a/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py b/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py index 8c18563..61cd20e 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py @@ -46,7 +46,6 @@ video, audio = pipe( input_images=[image], input_images_indexes=[0], input_images_strength=1.0, - num_inference_steps=40, ) write_video_audio_ltx2( video=video, diff --git a/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py b/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py index ad2dd7e..d23bd61 100644 --- a/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py @@ -58,7 +58,6 @@ video, audio = pipe( num_frames=num_frames, tiled=True, use_two_stage_pipeline=True, - num_inference_steps=40, input_images=[image], input_images_indexes=[0], input_images_strength=1.0,