refactor ltx2 stage2 pipeline (#1341)

* refactor ltx2 pipeline

* fix bug
This commit is contained in:
Hong Zhang
2026-03-10 13:55:40 +08:00
committed by GitHub
parent b272253956
commit d9228074bd
5 changed files with 139 additions and 134 deletions

View File

@@ -12,7 +12,7 @@ from transformers import AutoImageProcessor, Gemma3Processor
from ..core.device.npu_compatible_device import get_device_type from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler from ..diffusion import FlowMatchScheduler
from ..core import ModelConfig, gradient_checkpoint_forward from ..core import ModelConfig
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
@@ -63,6 +63,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
LTX2AudioVideoUnit_InputImagesEmbedder(), LTX2AudioVideoUnit_InputImagesEmbedder(),
LTX2AudioVideoUnit_InContextVideoEmbedder(), LTX2AudioVideoUnit_InContextVideoEmbedder(),
] ]
self.stage2_units = [
LTX2AudioVideoUnit_SwitchStage2(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_LatentsUpsampler(),
LTX2AudioVideoUnit_SetScheduleStage2(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
]
self.model_fn = model_fn_ltx2 self.model_fn = model_fn_ltx2
@staticmethod @staticmethod
@@ -72,6 +79,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
model_configs: list[ModelConfig] = [], model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config: Optional[ModelConfig] = None, stage2_lora_config: Optional[ModelConfig] = None,
stage2_lora_strength: float = 0.8,
vram_limit: float = None, vram_limit: float = None,
): ):
# Initialize pipeline # Initialize pipeline
@@ -98,53 +106,31 @@ class LTX2AudioVideoPipeline(BasePipeline):
if stage2_lora_config is not None: if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary() stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path pipe.stage2_lora_path = stage2_lora_config.path
# Optional, currently not used pipe.stage2_lora_strength = stage2_lora_strength
# VRAM Management # VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state() pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe return pipe
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
if inputs_shared["use_two_stage_pipeline"]:
if inputs_shared.get("clear_lora_before_state_two", False):
self.clear_lora()
latents = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device('upsampler',)
latents = self.upsampler(latents)
latents = self.video_vae_encoder.per_channel_statistics.normalize(latents)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
denoise_mask_video = 1.0
# input image
if inputs_shared.get("input_images", None) is not None:
initial_latents, denoise_mask_video = self.apply_input_images_to_latents(latents, initial_latents=latents, **inputs_shared.get("stage2_input_latents_apply_kwargs", {}))
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
# remove in-context video control in stage 2
inputs_shared.pop("in_context_video_latents", None)
inputs_shared.pop("in_context_video_positions", None)
# initialize latents for stage 2 def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[ if skip_stage:
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latents return inputs_shared, inputs_posi, inputs_nega
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + ( for unit in units:
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
self.load_models_to_device(self.in_iteration_models)
if not inputs_shared["use_distilled_pipeline"]:
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared
@torch.no_grad() @torch.no_grad()
def __call__( def __call__(
@@ -171,7 +157,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
# Classifier-free guidance # Classifier-free guidance
cfg_scale: Optional[float] = 3.0, cfg_scale: Optional[float] = 3.0,
# Scheduler # Scheduler
num_inference_steps: Optional[int] = 40, num_inference_steps: Optional[int] = 30,
# VAE tiling # VAE tiling
tiled: Optional[bool] = True, tiled: Optional[bool] = True,
tile_size_in_pixels: Optional[int] = 512, tile_size_in_pixels: Optional[int] = 512,
@@ -180,14 +166,14 @@ class LTX2AudioVideoPipeline(BasePipeline):
tile_overlap_in_frames: Optional[int] = 24, tile_overlap_in_frames: Optional[int] = 24,
# Special Pipelines # Special Pipelines
use_two_stage_pipeline: Optional[bool] = False, use_two_stage_pipeline: Optional[bool] = False,
stage2_spatial_upsample_factor: Optional[int] = 2,
clear_lora_before_state_two: Optional[bool] = False, clear_lora_before_state_two: Optional[bool] = False,
use_distilled_pipeline: Optional[bool] = False, use_distilled_pipeline: Optional[bool] = False,
# progress_bar # progress_bar
progress_bar_cmd=tqdm, progress_bar_cmd=tqdm,
): ):
# Scheduler # Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None)
special_case="ditilled_stage1" if use_distilled_pipeline else None)
# Inputs # Inputs
inputs_posi = { inputs_posi = {
"prompt": prompt, "prompt": prompt,
@@ -203,50 +189,22 @@ class LTX2AudioVideoPipeline(BasePipeline):
"cfg_scale": cfg_scale, "cfg_scale": cfg_scale,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor,
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
} }
for unit in self.units: # Stage 1
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd)
# Stage 2
# Denoise Stage 1 inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"])
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
# Denoise Stage 2
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
# Decode # Decode
self.load_models_to_device(['video_vae_decoder']) self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
video = self.vae_output_to_video(video) video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder']) self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"]) decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float() decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
b, _, f, h, w = latents.shape
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
for idx, input_latent in zip(input_indexes, input_latents):
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
return initial_latents, denoise_mask
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self): def __init__(self):
@@ -275,22 +233,23 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
""" """
For two-stage pipelines, the resolution must be divisible by 64. For two-stage pipelines, the resolution must be divisible by 64.
For one-stage pipelines, the resolution must be divisible by 32. For one-stage pipelines, the resolution must be divisible by 32.
This unit set height and width to stage 1 resolution, and stage_2_width and stage_2_height.
""" """
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("height", "width", "num_frames"), input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"),
output_params=("height", "width", "num_frames"), output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"),
) )
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False): def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2):
if use_two_stage_pipeline: if use_two_stage_pipeline:
self.width_division_factor = 64 height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor
self.height_division_factor = 64 height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor)
if use_two_stage_pipeline: else:
self.width_division_factor = 32 stage_2_height, stage_2_width = None, None
self.height_division_factor = 32 height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
return {"height": height, "width": width, "num_frames": num_frames} return {"height": height, "width": width, "num_frames": num_frames, "stage_2_height": stage_2_height, "stage_2_width": stage_2_width}
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
@@ -328,7 +287,7 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"), input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape") output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
) )
@@ -354,15 +313,9 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
"audio_latent_shape": audio_latent_shape "audio_latent_shape": audio_latent_shape
} }
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False): def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
if use_two_stage_pipeline: return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
initial_dict = stage1_dict
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
return initial_dict
else:
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
@@ -384,6 +337,7 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
else: else:
raise NotImplementedError("Video-to-video not implemented yet.") raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
@@ -407,11 +361,12 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
else: else:
raise NotImplementedError("Audio-to-video not supported.") raise NotImplementedError("Audio-to-video not supported.")
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
output_params=("denoise_mask_video", "input_latents_video", "stage2_input_latents_apply_kwargs"), output_params=("denoise_mask_video", "input_latents_video"),
onload_model_names=("video_vae_encoder") onload_model_names=("video_vae_encoder")
) )
@@ -423,53 +378,48 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device) latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
return latents return latents
def get_frame_conditions(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents=None, skip_apply=False): def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
frame_conditions = {} b, _, f, h, w = latents.shape
for img, index in zip(input_images, input_images_indexes): denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels) initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
# first_frame for idx, input_latent in zip(input_indexes, input_latents):
if index == 0: idx = min(max(1 + (idx-1) // 8, 0), f - 1)
if skip_apply: input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
frame_conditions = {"input_latents": [latents], "input_indexes": [0], "input_strength": input_images_strength} initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
else: denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
input_latents_video, denoise_mask_video = pipe.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength) return initial_latents, denoise_mask
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
return frame_conditions
def process(self, pipe: LTX2AudioVideoPipeline, input_images, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, use_two_stage_pipeline=False): def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
if input_images is None or len(input_images) == 0: if input_images is None or len(input_images) == 0:
return {} return {}
else: else:
if len(input_images_indexes) != len(set(input_images_indexes)): if len(input_images_indexes) != len(set(input_images_indexes)):
raise ValueError("Input images must have unique indexes.") raise ValueError("Input images must have unique indexes.")
pipe.load_models_to_device(self.onload_model_names) pipe.load_models_to_device(self.onload_model_names)
output_dicts = {} frame_conditions = {}
stage1_height = height // 2 if use_two_stage_pipeline else height for img, index in zip(input_images, input_images_indexes):
stage1_width = width // 2 if use_two_stage_pipeline else width latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
stage_1_frame_conditions = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, stage1_height, stage1_width, # first_frame
tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents) if index == 0:
output_dicts.update(stage_1_frame_conditions) input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
if use_two_stage_pipeline: frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
stage2_input_latents_apply_kwargs = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, height, width, return frame_conditions
tiled, tile_size_in_pixels, tile_overlap_in_pixels, skip_apply=True)
output_dicts.update({"stage2_input_latents_apply_kwargs": stage2_input_latents_apply_kwargs})
return output_dicts
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
def __init__(self): def __init__(self):
super().__init__( super().__init__(
input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
output_params=("in_context_video_latents", "in_context_video_positions"), output_params=("in_context_video_latents", "in_context_video_positions"),
onload_model_names=("video_vae_encoder") onload_model_names=("video_vae_encoder")
) )
def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True): def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor):
if in_context_video is None or len(in_context_video) == 0: if in_context_video is None or len(in_context_video) == 0:
raise ValueError("In-context video is None or empty.") raise ValueError("In-context video is None or empty.")
in_context_video = in_context_video[:num_frames] in_context_video = in_context_video[:num_frames]
expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor expected_height = height // in_context_downsample_factor
expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor expected_width = width // in_context_downsample_factor
current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video) current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0) h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
if current_h != h or current_w != w: if current_h != h or current_w != w:
@@ -479,14 +429,14 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f) in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
return in_context_video return in_context_video
def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True): def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
if in_context_videos is None or len(in_context_videos) == 0: if in_context_videos is None or len(in_context_videos) == 0:
return {} return {}
else: else:
pipe.load_models_to_device(self.onload_model_names) pipe.load_models_to_device(self.onload_model_names)
latents, positions = [], [] latents, positions = [], []
for in_context_video in in_context_videos: for in_context_video in in_context_videos:
in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline) in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor)
in_context_video = pipe.preprocess_video(in_context_video) in_context_video = pipe.preprocess_video(in_context_video)
in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
@@ -504,6 +454,62 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
return {"in_context_video_latents": latents, "in_context_video_positions": positions} return {"in_context_video_latents": latents, "in_context_video_positions": positions}
class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
"""
1. switch height and width to stage 2 resolution
2. clear in_context_video_latents and in_context_video_positions
3. switch stage 2 lora model
"""
def __init__(self):
super().__init__(
input_params=("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline"),
output_params=("height", "width", "in_context_video_latents", "in_context_video_positions"),
)
def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline):
stage2_params = {}
stage2_params.update({"height": stage_2_height, "width": stage_2_width})
stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None})
if clear_lora_before_state_two:
pipe.clear_lora()
if not use_distilled_pipeline:
pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
return stage2_params
class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"),
output_params=("video_latents", "audio_latents"),
)
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise):
pipe.scheduler.set_timesteps(special_case="stage2")
video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0])
audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0])
return {"video_latents": video_latents, "audio_latents": audio_latents}
class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("video_latents",),
output_params=("video_latents", "initial_latents"),
onload_model_names=("upsampler",),
)
def process(self, pipe: LTX2AudioVideoPipeline, video_latents):
if video_latents is None or pipe.upsampler is None:
raise ValueError("No upsampler or no video latents before stage 2.")
else:
pipe.load_models_to_device(self.onload_model_names)
video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents)
video_latents = pipe.upsampler(video_latents)
video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents)
return {"video_latents": video_latents, "initial_latents": video_latents}
def model_fn_ltx2( def model_fn_ltx2(
dit: LTXModel, dit: LTXModel,
video_latents=None, video_latents=None,
@@ -515,10 +521,13 @@ def model_fn_ltx2(
audio_positions=None, audio_positions=None,
audio_patchifier=None, audio_patchifier=None,
timestep=None, timestep=None,
# First Frame Conditioning
input_latents_video=None, input_latents_video=None,
denoise_mask_video=None, denoise_mask_video=None,
# In-Context Conditioning
in_context_video_latents=None, in_context_video_latents=None,
in_context_video_positions=None, in_context_video_positions=None,
# Gradient Checkpointing
use_gradient_checkpointing=False, use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False, use_gradient_checkpointing_offload=False,
**kwargs, **kwargs,

View File

@@ -45,7 +45,6 @@ video, audio = pipe(
input_images=[image], input_images=[image],
input_images_indexes=[0], input_images_indexes=[0],
input_images_strength=1.0, input_images_strength=1.0,
num_inference_steps=40,
) )
write_video_audio_ltx2( write_video_audio_ltx2(
video=video, video=video,

View File

@@ -57,7 +57,6 @@ video, audio = pipe(
num_frames=num_frames, num_frames=num_frames,
tiled=True, tiled=True,
use_two_stage_pipeline=True, use_two_stage_pipeline=True,
num_inference_steps=40,
input_images=[image], input_images=[image],
input_images_indexes=[0], input_images_indexes=[0],
input_images_strength=1.0, input_images_strength=1.0,

View File

@@ -46,7 +46,6 @@ video, audio = pipe(
input_images=[image], input_images=[image],
input_images_indexes=[0], input_images_indexes=[0],
input_images_strength=1.0, input_images_strength=1.0,
num_inference_steps=40,
) )
write_video_audio_ltx2( write_video_audio_ltx2(
video=video, video=video,

View File

@@ -58,7 +58,6 @@ video, audio = pipe(
num_frames=num_frames, num_frames=num_frames,
tiled=True, tiled=True,
use_two_stage_pipeline=True, use_two_stage_pipeline=True,
num_inference_steps=40,
input_images=[image], input_images=[image],
input_images_indexes=[0], input_images_indexes=[0],
input_images_strength=1.0, input_images_strength=1.0,