support ltx2 distilled pipeline

This commit is contained in:
mi804
2026-01-30 17:40:30 +08:00
parent 4f23caa55f
commit 9f07d65ebb
5 changed files with 82 additions and 27 deletions

View File

@@ -61,7 +61,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
LTX2AudioVideoUnit_InputVideoEmbedder(),
]
self.model_fn = model_fn_ltx2
# self.lora_loader = LTX2LoRALoader
@staticmethod
def from_pretrained(
@@ -89,12 +88,12 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Optional, currently not used
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
@@ -122,8 +121,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
noise_pred = noise_pred_posi
return noise_pred
def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm):
if use_two_stage_pipeline:
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
if inputs_shared["use_two_stage_pipeline"]:
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device(self.in_iteration_models + ('upsampler',))
latent = self.upsampler(latent)
@@ -135,12 +134,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
self.load_models_to_device(self.in_iteration_models)
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
if not inputs_shared["use_distilled_pipeline"]:
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
@@ -185,7 +185,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
special_case="ditilled_stage1" if use_distilled_pipeline else None)
# self.load_lora(self.dit, self.stage2_lora_path)
# Inputs
inputs_posi = {
@@ -223,7 +224,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
# Denoise Stage 2
inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd)
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
# Decode
self.load_models_to_device(['video_vae_decoder'])
@@ -241,11 +242,17 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
super().__init__(take_over=True)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
if inputs_shared.get("use_distilled_pipeline", False):
inputs_shared["use_two_stage_pipeline"] = True
inputs_shared["cfg_scale"] = 1.0
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
if inputs_shared.get("use_two_stage_pipeline", False):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
# distill pipeline also uses two-stage, but it does not needs lora
if not inputs_shared.get("use_distilled_pipeline", False):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega