Support LTX-2 T2V (text-to-video) and I2V (image-to-video)

This commit is contained in:
mi804
2026-02-02 19:53:07 +08:00
parent 1c8a0f8317
commit f4f991d409
20 changed files with 1084 additions and 25 deletions

View File

@@ -1442,6 +1442,10 @@ class LTXModel(torch.nn.Module):
return vx, ax
def forward(self, video_latents, video_positions, video_context, video_timesteps, audio_latents, audio_positions, audio_context, audio_timesteps):
cross_pe_max_pos = None
if self.model_type.is_video_enabled() and self.model_type.is_audio_enabled():
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
self._init_preprocessors(cross_pe_max_pos)
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
vx, ax = self._forward(video=video, audio=audio, perturbations=None)

View File

@@ -1648,11 +1648,8 @@ class LTX2VideoEncoder(nn.Module):
tile_overlap_in_pixels: Optional[int] = 128,
**kwargs,
) -> torch.Tensor:
device = next(self.parameters()).device
vae_dtype = next(self.parameters()).dtype
if video.ndim == 4:
video = video.unsqueeze(0) # [C, F, H, W] -> [B, C, F, H, W]
video = video.to(device=device, dtype=vae_dtype)
# Choose encoding method based on tiling flag
if tiled:
latents = self.tiled_encode_video(

View File

@@ -8,9 +8,7 @@ import numpy as np
from PIL import Image
from tqdm import tqdm
from typing import Optional
from typing_extensions import Literal
from transformers import AutoImageProcessor, Gemma3Processor
import einops
from ..core.device.npu_compatible_device import get_device_type
from ..diffusion import FlowMatchScheduler
@@ -23,6 +21,7 @@ from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLat
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
class LTX2AudioVideoPipeline(BasePipeline):
@@ -59,6 +58,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
LTX2AudioVideoUnit_PromptEmbedder(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
]
self.model_fn = model_fn_ltx2
@@ -124,13 +124,22 @@ class LTX2AudioVideoPipeline(BasePipeline):
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
if inputs_shared["use_two_stage_pipeline"]:
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
self.load_models_to_device(self.in_iteration_models + ('upsampler',))
self.load_models_to_device('upsampler',)
latent = self.upsampler(latent)
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
self.scheduler.set_timesteps(special_case="stage2")
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
denoise_mask_video = 1.0
if inputs_shared.get("input_images", None) is not None:
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
inputs_shared["input_images_strength"], latent.clone())
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
self.load_models_to_device(self.in_iteration_models)
if not inputs_shared["use_distilled_pipeline"]:
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
@@ -142,7 +151,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, **inputs_shared)
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
return inputs_shared
@@ -154,8 +164,10 @@ class LTX2AudioVideoPipeline(BasePipeline):
prompt: str,
negative_prompt: Optional[str] = "",
# Image-to-video
input_image: Optional[Image.Image] = None,
denoising_strength: float = 1.0,
input_images: Optional[list[Image.Image]] = None,
input_images_indexes: Optional[list[int]] = None,
input_images_strength: Optional[float] = 1.0,
# Randomness
seed: Optional[int] = None,
rand_device: Optional[str] = "cpu",
@@ -191,7 +203,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
"negative_prompt": negative_prompt,
}
inputs_shared = {
"input_image": input_image,
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
@@ -212,8 +224,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
**models, timestep=timestep, progress_id=progress_id
)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
noise_pred=noise_pred_video, **inputs_shared)
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
noise_pred=noise_pred_audio, **inputs_shared)
@@ -223,13 +235,25 @@ class LTX2AudioVideoPipeline(BasePipeline):
# Decode
self.load_models_to_device(['video_vae_decoder'])
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
video = self.vae_output_to_video(video)
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
return video, decoded_audio
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
    """Blend VAE-encoded conditioning-image latents into the video latents.

    For each (frame index, image latent) pair, the image latent is written into
    `initial_latents` at the corresponding latent-frame position, and the
    denoise mask there is lowered to ``1 - input_strength`` so the denoiser
    keeps the conditioning content.  Returns the blended latents, the denoise
    mask, and the populated initial latents.
    """
    batch, _, latent_frames, lat_h, lat_w = latents.shape
    mask = torch.ones((batch, 1, latent_frames, lat_h, lat_w), dtype=latents.dtype, device=latents.device)
    if initial_latents is None:
        initial_latents = torch.zeros_like(latents)
    for frame_idx, cond_latent in zip(input_indexes, input_latents):
        # Map a pixel-frame index onto the latent time axis (temporal factor 8
        # per VIDEO_SCALE_FACTORS — TODO confirm) and clamp to the valid range.
        lat_idx = min(max(1 + (frame_idx - 1) // 8, 0), latent_frames - 1)
        cond_latent = cond_latent.to(dtype=latents.dtype, device=latents.device)
        span = slice(lat_idx, lat_idx + cond_latent.shape[2])
        initial_latents[:, :, span, :, :] = cond_latent
        mask[:, :, span, :, :] = 1.0 - input_strength
    blended = latents * mask + initial_latents * (1.0 - mask)
    return blended, mask, initial_latents
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
def __init__(self):
@@ -246,7 +270,7 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
return inputs_shared, inputs_posi, inputs_nega
@@ -459,6 +483,44 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
# TODO: implement video-to-video
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
    """Pipeline unit that VAE-encodes conditioning images and blends them into
    the initial video latents for image-to-video generation.

    For the two-stage pipeline, stage-1 latents are encoded at half resolution
    and full-resolution latents are stashed under ``stage2_input_latents`` for
    later use in ``stage2_denoise``.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
            # These must be tuples: `("video_latents")` is a plain string, not a
            # 1-tuple.  Also declare every key that process() can return.
            output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"),
            onload_model_names=("video_vae_encoder",),
        )

    def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
        """Resize, compression-preprocess, normalize and VAE-encode one PIL
        image to a ``[1, C, 1, H', W']`` latent on ``pipe.device``."""
        image = ltx2_preprocess(np.array(input_image.resize((width, height))))
        image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
        image = image / 127.5 - 1.0  # [0, 255] -> [-1, 1]
        # `repeat` is not imported by name in this module; the visible import is
        # `import einops`, so call it through the module.
        image = einops.repeat(image, "H W C -> B C F H W", B=1, F=1)
        latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
        return latent

    def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
        """Encode `input_images` and merge them into `video_latents`.

        Returns only ``video_latents`` unchanged when no conditioning images
        are given; otherwise also returns the denoise mask and initial latents
        (plus ``stage2_input_latents`` in the two-stage pipeline).
        """
        if input_images is None or len(input_images) == 0:
            return {"video_latents": video_latents}
        # NOTE(review): input_images_indexes is assumed to be a non-None list of
        # the same length as input_images — confirm against callers.
        pipe.load_models_to_device(self.onload_model_names)
        output_dicts = {}
        # Stage 1 of the two-stage pipeline runs at half resolution; the
        # upsampler brings the latents back to full resolution afterwards.
        stage1_height = height // 2 if use_two_stage_pipeline else height
        stage1_width = width // 2 if use_two_stage_pipeline else width
        stage1_latents = [
            self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
                                  tile_overlap_in_pixels) for img in input_images
        ]
        video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
        output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
        if use_two_stage_pipeline:
            # Full-resolution latents are re-applied after upsampling in stage 2.
            stage2_latents = [
                self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
                                      tile_overlap_in_pixels) for img in input_images
            ]
            output_dicts.update({"stage2_input_latents": stage2_latents})
        return output_dicts
def model_fn_ltx2(
dit: LTXModel,
@@ -471,19 +533,23 @@ def model_fn_ltx2(
audio_positions=None,
audio_patchifier=None,
timestep=None,
denoise_mask_video=None,
use_gradient_checkpointing=False,
use_gradient_checkpointing_offload=False,
**kwargs,
):
timestep = timestep.float() / 1000.
# patchify
b, c_v, f, h, w = video_latents.shape
_, c_a, _, mel_bins = audio_latents.shape
video_latents = video_patchifier.patchify(video_latents)
audio_latents = audio_patchifier.patchify(audio_latents)
#TODO: support gradient checkpointing
timestep = timestep.float() / 1000.
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
_, c_a, _, mel_bins = audio_latents.shape
audio_latents = audio_patchifier.patchify(audio_latents)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
#TODO: support gradient checkpointing in training
vx, ax = dit(
video_latents=video_latents,
video_positions=video_positions,

View File

@@ -4,6 +4,9 @@ import torch
import av
from tqdm import tqdm
from PIL import Image
import numpy as np
from io import BytesIO
from collections.abc import Generator, Iterator
def _resample_audio(
@@ -105,3 +108,42 @@ def write_video_audio_ltx2(
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
def encode_single_frame(output_file, image_array: np.ndarray, crf: float) -> None:
    """H.264-encode a single RGB frame as a one-frame MP4 into `output_file`.

    `output_file` may be a filesystem path or a writable file-like object
    (e.g. ``io.BytesIO``) — the original ``str`` annotation was wrong, since
    ``ltx2_preprocess`` passes a ``BytesIO``.  `image_array` is an ``[H, W, 3]``
    uint8 RGB array; `crf` is the x264 constant rate factor (higher = lossier).
    """
    container = av.open(output_file, "w", format="mp4")
    try:
        stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
        # Round down to the nearest multiple of 2: yuv420p requires even dimensions.
        height = image_array.shape[0] // 2 * 2
        width = image_array.shape[1] // 2 * 2
        image_array = image_array[:height, :width]
        stream.height = height
        stream.width = width
        av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
        container.mux(stream.encode(av_frame))
        # Flush the encoder so the delayed frame is actually written.
        container.mux(stream.encode())
    finally:
        container.close()
def decode_single_frame(video_file) -> np.ndarray:
    """Decode the first video frame of `video_file` to an ``[H, W, 3]`` RGB array.

    `video_file` may be a path or a readable file-like object.  The original
    ``np.array`` return annotation named the constructor function, not the
    array type; the correct type is ``np.ndarray``.

    Raises StopIteration if the container has no video stream or no frames.
    """
    container = av.open(video_file)
    try:
        stream = next(s for s in container.streams if s.type == "video")
        frame = next(container.decode(stream))
    finally:
        container.close()
    return frame.to_ndarray(format="rgb24")
def ltx2_preprocess(image: np.ndarray, crf: float = 33) -> np.ndarray:
    """Simulate video-compression artifacts on a conditioning image.

    Round-trips `image` (``[H, W, 3]`` uint8 RGB) through an in-memory
    single-frame H.264 encode/decode at the given `crf`, so conditioning images
    match the compression statistics of the training videos — presumably; TODO
    confirm against the LTX-2 reference implementation.  ``crf == 0`` disables
    the round-trip and returns the input unchanged.  Note the output may be
    cropped by one row/column: the encoder rounds dimensions down to even.

    The ``np.array`` annotations named the constructor function, not the array
    type; the correct type is ``np.ndarray``.
    """
    if crf == 0:
        return image
    with BytesIO() as output_file:
        encode_single_frame(output_file, image, crf)
        video_bytes = output_file.getvalue()
    with BytesIO(video_bytes) as video_file:
        image_array = decode_single_frame(video_file)
    return image_array