support ltx2 two stage pipeline & vram

This commit is contained in:
mi804
2026-01-30 16:55:40 +08:00
parent b1a2782ad7
commit 4f23caa55f
7 changed files with 227 additions and 85 deletions

View File

@@ -60,10 +60,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
]
self.post_units = [
LTX2AudioVideoPostUnit_UnPatchifier(),
]
self.model_fn = model_fn_ltx2
# self.lora_loader = LTX2LoRALoader
@staticmethod
def from_pretrained(
@@ -71,6 +69,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
device: Union[str, torch.device] = get_device_type(),
model_configs: list[ModelConfig] = [],
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
stage2_lora_config: Optional[ModelConfig] = None,
vram_limit: float = None,
):
# Initialize pipeline
@@ -90,14 +89,67 @@ class LTX2AudioVideoPipeline(BasePipeline):
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Optional
# Stage 2
if stage2_lora_config is not None:
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
# Optional, currently not used
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
return pipe
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm):
if use_two_stage_pipeline:
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device(self.in_iteration_models + ('upsampler',))
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
latent = self.video_patchifier.patchify(latent)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
self.load_models_to_device(self.in_iteration_models)
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
return inputs_shared
@torch.no_grad()
def __call__(
@@ -126,14 +178,15 @@ class LTX2AudioVideoPipeline(BasePipeline):
tile_overlap_in_pixels: Optional[int] = 128,
tile_size_in_frames: Optional[int] = 128,
tile_overlap_in_frames: Optional[int] = 24,
# Two-Stage Pipeline
use_two_stage: Optional[bool] = True,
# Special Pipelines
use_two_stage_pipeline: Optional[bool] = False,
use_distilled_pipeline: Optional[bool] = False,
# progress_bar
progress_bar_cmd=tqdm,
):
# Scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
# self.load_lora(self.dit, self.stage2_lora_path)
# Inputs
inputs_posi = {
"prompt": prompt,
@@ -148,14 +201,12 @@ class LTX2AudioVideoPipeline(BasePipeline):
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage": True
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
}
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# inputs_posi.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/text_encodings.pt"))
# inputs_nega.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/negative_text_encodings.pt"))
# Denoise
# Denoise Stage 1
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
@@ -164,12 +215,16 @@ class LTX2AudioVideoPipeline(BasePipeline):
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"])
inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])
# Denoise Stage 2
inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd)
# post-denoising, pre-decoding processing logic
for unit in self.post_units:
inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# Decode
self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
@@ -181,39 +236,21 @@ class LTX2AudioVideoPipeline(BasePipeline):
return video, decoded_audio
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
if cfg_scale != 1.0:
if inputs_shared.get("positive_only_lora", None) is not None:
self.clear_lora(verbose=0)
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
if isinstance(noise_pred_posi, tuple):
noise_pred = tuple(
n_nega + cfg_scale * (n_posi - n_nega)
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
)
else:
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
return noise_pred
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
super().__init__(take_over=True)
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
pass
if inputs_shared.get("use_two_stage_pipeline", False):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
"""
# TODO: Adjust with two stage pipeline
For two-stage pipelines, the resolution must be divisible by 64.
For one-stage pipelines, the resolution must be divisible by 32.
"""
@@ -223,8 +260,14 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
output_params=("height", "width", "num_frames"),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames):
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
if use_two_stage_pipeline:
self.width_division_factor = 64
self.height_division_factor = 64
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
if use_two_stage_pipeline:
self.width_division_factor = 32
self.height_division_factor = 32
return {"height": height, "width": width, "num_frames": num_frames}
@@ -327,10 +370,12 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
return pipe.text_encoder_post_modules.feature_extractor_linear(
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
def _preprocess_text(self,
pipe,
text: str,
padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
def _preprocess_text(
self,
pipe,
text: str,
padding_side: str = "left",
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""
Encode a given string into feature tensors suitable for downstream tasks.
Args:
@@ -339,8 +384,8 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
"""
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device)
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
projected = self._run_feature_extractor(pipe,
hidden_states=outputs.hidden_states,
@@ -362,11 +407,11 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("height", "width", "num_frames", "seed", "rand_device",),
input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
output_params=("video_noise", "audio_noise",),
)
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
@@ -390,6 +435,15 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
"audio_latent_shape": audio_latent_shape
}
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
if use_two_stage_pipeline:
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
initial_dict = stage1_dict
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
return initial_dict
else:
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
@@ -407,20 +461,6 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"),
output_params=("video_latents", "audio_latents"),
)
def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents):
video_latents = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape)
audio_latents = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape)
return {"video_latents": video_latents, "audio_latents": audio_latents}
def model_fn_ltx2(
dit: LTXModel,
video_latents=None,