mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
* support mova inference * mova media_io * add unified audio_video api & fix bug of mono audio input for ltx * support mova train * mova docs * fix bug
732 lines
40 KiB
Python
732 lines
40 KiB
Python
import torch, types
|
|
import numpy as np
|
|
from PIL import Image
|
|
from einops import repeat
|
|
from typing import Optional, Union
|
|
from einops import rearrange
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tqdm import tqdm
|
|
from typing import Optional
|
|
from transformers import AutoImageProcessor, Gemma3Processor
|
|
|
|
from ..core.device.npu_compatible_device import get_device_type
|
|
from ..diffusion import FlowMatchScheduler
|
|
from ..core import ModelConfig
|
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
|
|
|
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
|
from ..models.ltx2_dit import LTXModel
|
|
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
|
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
|
|
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
|
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
|
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
|
from ..utils.data.audio import convert_to_stereo
|
|
|
|
|
|
class LTX2AudioVideoPipeline(BasePipeline):
    """Joint audio/video generation pipeline built around the LTX-2 DiT.

    The pipeline prepares its inputs by running a list of ``PipelineUnit``
    objects, iterates the flow-matching scheduler to denoise video and audio
    latents together, optionally repeats the procedure for a second stage
    (``stage2_units``), and finally decodes both modalities.
    """

    def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
        """Create an empty pipeline; models are attached by ``from_pretrained``.

        Args:
            device: Device the pipeline computes on.
            torch_dtype: Floating point dtype used for computation.
        """
        super().__init__(
            device=device,
            torch_dtype=torch_dtype,
            height_division_factor=32,
            width_division_factor=32,
            time_division_factor=8,
            time_division_remainder=1,
        )
        self.scheduler = FlowMatchScheduler("LTX-2")

        # Model slots, filled in by `from_pretrained`.
        self.text_encoder: LTX2TextEncoder = None
        self.tokenizer: LTXVGemmaTokenizer = None
        self.processor: Gemma3Processor = None
        self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
        self.dit: LTXModel = None
        self.video_vae_encoder: LTX2VideoEncoder = None
        self.video_vae_decoder: LTX2VideoDecoder = None
        self.audio_vae_encoder: LTX2AudioEncoder = None
        self.audio_vae_decoder: LTX2AudioDecoder = None
        self.audio_vocoder: LTX2Vocoder = None
        self.upsampler: LTX2LatentUpsampler = None

        # Weight-free helpers that are always available.
        self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
        self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
        self.audio_processor: AudioProcessor = AudioProcessor()

        # Models that must be resident on the device during the denoising loop.
        self.in_iteration_models = ("dit",)
        # Units executed (in order) before the stage-1 denoising loop.
        self.units = [
            LTX2AudioVideoUnit_PipelineChecker(),
            LTX2AudioVideoUnit_ShapeChecker(),
            LTX2AudioVideoUnit_PromptEmbedder(),
            LTX2AudioVideoUnit_NoiseInitializer(),
            LTX2AudioVideoUnit_VideoRetakeEmbedder(),
            LTX2AudioVideoUnit_AudioRetakeEmbedder(),
            LTX2AudioVideoUnit_InputAudioEmbedder(),
            LTX2AudioVideoUnit_InputVideoEmbedder(),
            LTX2AudioVideoUnit_InputImagesEmbedder(),
            LTX2AudioVideoUnit_InContextVideoEmbedder(),
        ]
        # Units executed (in order) before the stage-2 denoising loop.
        self.stage2_units = [
            LTX2AudioVideoUnit_SwitchStage2(),
            LTX2AudioVideoUnit_NoiseInitializer(),
            LTX2AudioVideoUnit_LatentsUpsampler(),
            LTX2AudioVideoUnit_VideoRetakeEmbedder(),
            LTX2AudioVideoUnit_AudioRetakeEmbedder(),
            LTX2AudioVideoUnit_InputImagesEmbedder(),
            LTX2AudioVideoUnit_SetScheduleStage2(),
        ]
        self.model_fn = model_fn_ltx2

        # Both model versions share the same default negative prompt, so the
        # text is written once and referenced under both keys.
        shared_negative_prompt = (
            "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
            "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
            "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
            "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
            "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
            "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
            "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
            "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
            "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
            "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
            "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
        )
        self.default_negative_prompt = {
            "LTX-2": shared_negative_prompt,
            "LTX-2.3": shared_negative_prompt,
        }

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
        device: Union[str, torch.device] = get_device_type(),
        model_configs: Optional[list[ModelConfig]] = None,
        tokenizer_config: Optional[ModelConfig] = None,
        stage2_lora_config: Optional[ModelConfig] = None,
        stage2_lora_strength: float = 0.8,
        vram_limit: Optional[float] = None,
    ):
        """Build a pipeline and attach every model found in ``model_configs``.

        Args:
            torch_dtype: Computation dtype of the pipeline.
            device: Computation device of the pipeline.
            model_configs: Configs of the models to download and load.
                ``None`` is treated as an empty list.
            tokenizer_config: Config of the Gemma tokenizer / image processor.
                ``None`` selects ``google/gemma-3-12b-it-qat-q4_0-unquantized``.
            stage2_lora_config: Optional config of the LoRA loaded onto the DiT
                when switching to stage 2.
            stage2_lora_strength: ``alpha`` used when loading the stage-2 LoRA.
            vram_limit: Forwarded to ``download_and_load_models``.

        Returns:
            The initialised ``LTX2AudioVideoPipeline``.
        """
        # Build defaults per call: a shared default list / ModelConfig would be
        # reused (and the config mutated by `download_if_necessary`) across calls.
        if model_configs is None:
            model_configs = []
        if tokenizer_config is None:
            tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized")

        # Initialize pipeline
        pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
        model_pool = pipe.download_and_load_models(model_configs, vram_limit)

        # Fetch models
        pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
        tokenizer_config.download_if_necessary()
        pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
        image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
        pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)

        pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
        pipe.dit = model_pool.fetch_model("ltx2_dit")
        pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
        pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
        pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
        pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
        pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
        pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")

        # Stage 2
        if stage2_lora_config is not None:
            stage2_lora_config.download_if_necessary()
            pipe.stage2_lora_path = stage2_lora_config.path
            pipe.stage2_lora_strength = stage2_lora_strength

        # VRAM Management
        pipe.vram_management_enabled = pipe.check_vram_management_state()
        return pipe

    def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False, cfg_scale_from_inputs=False):
        """Run ``units`` and then the full denoising loop of one stage.

        Args:
            inputs_shared: Inputs shared by the positive and negative branch.
            inputs_posi: Inputs of the positive (conditional) branch.
            inputs_nega: Inputs of the negative (unconditional) branch.
            units: Pipeline units executed, in order, before the loop.
            cfg_scale: Classifier-free guidance scale used in the loop.
            progress_bar_cmd: Callable wrapping the timestep iterable.
            skip_stage: When true, return the inputs untouched.
            cfg_scale_from_inputs: When true, ``inputs_shared["cfg_scale"]``
                (if present once the units have run) replaces ``cfg_scale``,
                so that overrides written by a unit take effect.

        Returns:
            The updated ``(inputs_shared, inputs_posi, inputs_nega)``.
        """
        if skip_stage:
            return inputs_shared, inputs_posi, inputs_nega
        for unit in units:
            inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
        if cfg_scale_from_inputs:
            # A unit may have rewritten the guidance scale (the pipeline checker
            # sets it to 1.0 for the distilled pipeline); use the updated value
            # instead of the one captured before the units ran.
            cfg_scale = inputs_shared.get("cfg_scale", cfg_scale)
        self.load_models_to_device(self.in_iteration_models)
        models = {name: getattr(self, name) for name in self.in_iteration_models}
        for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
            timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
            noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
                **models, timestep=timestep, progress_id=progress_id
            )
            # Video and audio latents are stepped separately, each with its own
            # (optional) denoise mask and reference latents.
            inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
                                                       inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
            inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio,
                                                       inpaint_mask=inputs_shared.get("denoise_mask_audio", None), input_latents=inputs_shared.get("input_latents_audio", None), **inputs_shared)
        return inputs_shared, inputs_posi, inputs_nega

    @torch.no_grad()
    def __call__(
        self,
        # Prompt
        prompt: str,
        negative_prompt: Optional[str] = "",
        denoising_strength: float = 1.0,
        # Image-to-video
        input_images: Optional[list[Image.Image]] = None,
        input_images_indexes: Optional[list[int]] = None,
        input_images_strength: Optional[float] = 1.0,
        # In-Context Video Control
        in_context_videos: Optional[list[list[Image.Image]]] = None,
        in_context_downsample_factor: Optional[int] = 2,
        # Video-to-video
        retake_video: Optional[list[Image.Image]] = None,
        retake_video_regions: Optional[list[tuple[float, float]]] = None,
        # Audio-to-video
        retake_audio: Optional[torch.Tensor] = None,
        audio_sample_rate: Optional[int] = 48000,
        retake_audio_regions: Optional[list[tuple[float, float]]] = None,
        # Randomness
        seed: Optional[int] = None,
        rand_device: Optional[str] = "cpu",
        # Shape
        height: Optional[int] = 512,
        width: Optional[int] = 768,
        num_frames: Optional[int] = 121,
        frame_rate: Optional[int] = 24,
        # Classifier-free guidance
        cfg_scale: Optional[float] = 3.0,
        # Scheduler
        num_inference_steps: Optional[int] = 30,
        # VAE tiling
        tiled: Optional[bool] = True,
        tile_size_in_pixels: Optional[int] = 512,
        tile_overlap_in_pixels: Optional[int] = 128,
        tile_size_in_frames: Optional[int] = 128,
        tile_overlap_in_frames: Optional[int] = 24,
        # Special Pipelines
        use_two_stage_pipeline: Optional[bool] = False,
        stage2_spatial_upsample_factor: Optional[int] = 2,
        clear_lora_before_state_two: Optional[bool] = False,
        use_distilled_pipeline: Optional[bool] = False,
        # progress_bar
        progress_bar_cmd=tqdm,
    ):
        """Generate a video and its audio track from a text prompt.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            denoising_strength: Forwarded to the scheduler's ``set_timesteps``.
            input_images: Conditioning images for image-to-video.
            input_images_indexes: Frame index of each conditioning image.
                ``None`` is treated as ``[0]``.
            input_images_strength: Conditioning strength of the images.
            in_context_videos: Videos used for in-context control.
            in_context_downsample_factor: Spatial downsample factor applied to
                the in-context videos.
            retake_video: Source video for video retaking.
            retake_video_regions: ``(start, end)`` pairs selecting the part of
                ``retake_video`` to regenerate.
            retake_audio: Source waveform for audio retaking.
            audio_sample_rate: Sample rate of ``retake_audio``.
            retake_audio_regions: ``(start, end)`` pairs selecting the part of
                ``retake_audio`` to regenerate.
            seed: Random seed used for noise generation.
            rand_device: Device used for noise generation.
            height: Output height in pixels.
            width: Output width in pixels.
            num_frames: Number of output frames.
            frame_rate: Output frame rate.
            cfg_scale: Classifier-free guidance scale of stage 1.
            num_inference_steps: Number of stage-1 scheduler steps.
            tiled: Whether the video VAE works tile by tile.
            tile_size_in_pixels: Spatial tile size for the video VAE.
            tile_overlap_in_pixels: Spatial tile overlap for the video VAE.
            tile_size_in_frames: Temporal tile size for the video VAE decoder.
            tile_overlap_in_frames: Temporal tile overlap for the video VAE decoder.
            use_two_stage_pipeline: Whether to run the second stage.
            stage2_spatial_upsample_factor: Spatial factor between the stage-1
                and stage-2 resolution.
            clear_lora_before_state_two: Whether to clear loaded LoRAs before
                stage 2.
            use_distilled_pipeline: Whether to use the distilled schedule; the
                pipeline checker then forces two-stage mode and ``cfg_scale=1.0``.
            progress_bar_cmd: Callable wrapping the timestep iterable.

        Returns:
            A ``(video, decoded_audio)`` tuple.
        """
        # Fresh list per call instead of a shared mutable default.
        if input_images_indexes is None:
            input_images_indexes = [0]

        # Scheduler
        # NOTE(review): the key "ditilled_stage1" is kept verbatim; presumably it
        # matches the spelling FlowMatchScheduler expects - confirm before renaming.
        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None)
        # Inputs
        inputs_posi = {
            "prompt": prompt,
        }
        inputs_nega = {
            "negative_prompt": negative_prompt,
        }
        inputs_shared = {
            "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
            "retake_video": retake_video, "retake_video_regions": retake_video_regions,
            "retake_audio": (retake_audio, audio_sample_rate) if retake_audio is not None else None, "retake_audio_regions": retake_audio_regions,
            "in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor,
            "seed": seed, "rand_device": rand_device,
            "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate,
            "cfg_scale": cfg_scale,
            "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
            "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
            "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor,
            "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
        }
        # Stage 1: honour a cfg_scale override written by the units.
        inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(
            inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd,
            cfg_scale_from_inputs=True,
        )
        # Stage 2: always without CFG. The flag is read back from `inputs_shared`
        # because the pipeline checker may have switched it on.
        inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"])
        # Decode
        self.load_models_to_device(['video_vae_decoder'])
        video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
        video = self.vae_output_to_video(video)
        self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
        decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
        decoded_audio = self.audio_vocoder(decoded_audio)
        decoded_audio = self.output_audio_format_check(decoded_audio)
        return video, decoded_audio
|
|
|
|
|
|
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
    """Normalise the pipeline-mode flags and verify the required components."""

    def __init__(self):
        super().__init__(
            take_over=True,
            input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
            output_params=("use_two_stage_pipeline", "cfg_scale"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
        """Adjust the flags in ``inputs_shared`` and validate the pipeline.

        Raises:
            ValueError: If two-stage mode is requested without a stage-2 LoRA
                (non-distilled only) or without an upsampler model.
        """
        distilled = inputs_shared.get("use_distilled_pipeline", False)
        if distilled:
            # The distilled pipeline implies two stages and no CFG.
            inputs_shared["use_two_stage_pipeline"] = True
            inputs_shared["cfg_scale"] = 1.0
            print("Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")

        if inputs_shared.get("use_two_stage_pipeline", False):
            # The distilled pipeline is two-stage as well, but needs no LoRA.
            lora_missing = getattr(pipe, "stage2_lora_path", None) is None
            if not distilled and lora_missing:
                raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
            if getattr(pipe, "upsampler", None) is None:
                raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
        return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
|
|
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
    """
    For two-stage pipelines, the resolution must be divisible by 64.
    For one-stage pipelines, the resolution must be divisible by 32.
    This unit sets height and width to the stage 1 resolution, and also
    provides stage_2_height and stage_2_width.
    """

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"),
            output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2):
        """Return the validated stage-1 shape and the stage-2 resolution.

        ``stage_2_height`` / ``stage_2_width`` are ``None`` for one-stage runs.
        """
        if use_two_stage_pipeline:
            # Stage 1 works at the requested resolution divided by the factor.
            height = height // stage2_spatial_upsample_factor
            width = width // stage2_spatial_upsample_factor

        height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)

        if use_two_stage_pipeline:
            # Stage 2 scales the validated stage-1 resolution back up.
            stage_2_height = int(height * stage2_spatial_upsample_factor)
            stage_2_width = int(width * stage2_spatial_upsample_factor)
        else:
            stage_2_height = stage_2_width = None

        return {
            "height": height,
            "width": width,
            "num_frames": num_frames,
            "stage_2_height": stage_2_height,
            "stage_2_width": stage_2_width,
        }
|
|
|
|
|
|
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
    """Encode the positive and the negative prompt into video/audio contexts."""

    def __init__(self):
        super().__init__(
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt"},
            input_params_nega={"prompt": "negative_prompt"},
            output_params=("video_context", "audio_context"),
            onload_model_names=("text_encoder", "text_encoder_post_modules"),
        )

    def _preprocess_text(
        self,
        pipe,
        text: str,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Tokenize ``text`` and run the text encoder.

        Returns:
            The encoder's ``hidden_states`` and the attention mask tensor.
        """
        pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
        # Element 0 of each pair is used as token id, element 1 as mask value.
        token_ids = [pair[0] for pair in pairs]
        mask_values = [pair[1] for pair in pairs]
        input_ids = torch.tensor([token_ids], device=pipe.device)
        attention_mask = torch.tensor([mask_values], device=pipe.device)
        encoded = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        return encoded.hidden_states, attention_mask

    def encode_prompt(self, pipe, text, padding_side="left"):
        """Return ``(video_encoding, audio_encoding, attention_mask)`` for ``text``."""
        hidden_states, attention_mask = self._preprocess_text(pipe, text)
        return pipe.text_encoder_post_modules.process_hidden_states(
            hidden_states, attention_mask, padding_side
        )

    def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
        """Load the text models and embed ``prompt`` for both modalities."""
        pipe.load_models_to_device(self.onload_model_names)
        video_context, audio_context, _unused_mask = self.encode_prompt(pipe, prompt)
        return {"video_context": video_context, "audio_context": audio_context}
|
|
|
|
|
|
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
    """Create initial noise, latent shapes and positions for video and audio."""

    def __init__(self):
        super().__init__(
            input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
            output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape"),
        )

    def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
        """Build the noise tensors and positional grids for the given shape."""
        pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)

        # --- video ---
        video_latent_shape = VideoLatentShape.from_pixel_shape(shape=pixel_shape, latent_channels=128)
        video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
        patch_bounds = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
        video_positions = get_pixel_coords(patch_bounds, VIDEO_SCALE_FACTORS, True).float()
        # The first coordinate axis is divided by the frame rate.
        video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
        video_positions = video_positions.to(pipe.torch_dtype)

        # --- audio ---
        audio_latent_shape = AudioLatentShape.from_video_pixel_shape(pixel_shape)
        audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
        audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)

        return dict(
            video_noise=video_noise,
            audio_noise=audio_noise,
            video_positions=video_positions,
            audio_positions=audio_positions,
            video_latent_shape=video_latent_shape,
            audio_latent_shape=audio_latent_shape,
        )

    def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
        """Unit entry point; delegates to ``process_stage``."""
        return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
|
|
|
|
|
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
    """Provide the initial video latents.

    Outside of scheduler training mode (or without an input video) the latents
    are the initial noise; otherwise the input video is encoded by the video VAE.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
            output_params=("video_latents", "input_latents"),
            # One-element tuple: without the trailing comma this would be a
            # plain string rather than a collection of model names.
            onload_model_names=("video_vae_encoder",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
        """Return ``video_latents`` (and ``input_latents`` when a video is encoded)."""
        if input_video is None or not pipe.scheduler.training:
            return {"video_latents": video_noise}
        else:
            pipe.load_models_to_device(self.onload_model_names)
            input_video = pipe.preprocess_video(input_video)
            input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
            return {"video_latents": input_latents, "input_latents": input_latents}
|
|
|
|
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
    """Provide the initial audio latents.

    Outside of scheduler training mode (or without input audio) the latents are
    the initial noise; otherwise the waveform is encoded by the audio VAE.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_audio", "audio_noise"),
            output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
            onload_model_names=("audio_vae_encoder",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
        """Return the audio latents, plus shape/positions when audio is encoded.

        ``input_audio`` is unpacked as a ``(waveform, sample_rate)`` pair.
        """
        encode_needed = input_audio is not None and pipe.scheduler.training
        if not encode_needed:
            return {"audio_latents": audio_noise}

        waveform, sample_rate = input_audio
        waveform = convert_to_stereo(waveform)
        pipe.load_models_to_device(self.onload_model_names)
        mel = pipe.audio_processor.waveform_to_mel(waveform.unsqueeze(0), waveform_sample_rate=sample_rate)
        mel = mel.to(dtype=pipe.torch_dtype)
        audio_input_latents = pipe.audio_vae_encoder(mel)
        # Shape and positions follow the encoded latents.
        audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
        audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
        return {
            "audio_latents": audio_input_latents,
            "audio_input_latents": audio_input_latents,
            "audio_positions": audio_positions,
            "audio_latent_shape": audio_latent_shape,
        }
|
|
|
|
|
|
class LTX2AudioVideoUnit_VideoRetakeEmbedder(PipelineUnit):
    """Encode a video to retake and build its denoise mask."""

    def __init__(self):
        super().__init__(
            input_params=("retake_video", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "video_positions", "retake_video_regions"),
            output_params=("input_latents_video", "denoise_mask_video"),
            # One-element tuple: without the trailing comma this would be a
            # plain string rather than a collection of model names.
            onload_model_names=("video_vae_encoder",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, retake_video, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_positions, retake_video_regions=None):
        """Return the encoded retake video and its mask, or ``{}`` without a video.

        The mask starts as all zeros and is set to one at every latent position
        whose time span overlaps one of ``retake_video_regions``.
        """
        if retake_video is None:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        resized_video = [frame.resize((width, height)) for frame in retake_video]
        input_video = pipe.preprocess_video(resized_video)
        input_latents_video = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)

        b, c, f, h, w = input_latents_video.shape
        denoise_mask_video = torch.zeros((b, 1, f, h, w), device=input_latents_video.device, dtype=input_latents_video.dtype)
        if retake_video_regions is not None and len(retake_video_regions) > 0:
            # Start/end bounds of every token; they do not depend on the region,
            # so they are extracted once.
            t_start, t_end = video_positions[0, 0].unbind(dim=-1)
            for start_time, end_time in retake_video_regions:
                # Interval overlap test between the token and the region.
                in_region = (t_end >= start_time) & (t_start <= end_time)
                in_region = pipe.video_patchifier.unpatchify_video(in_region.unsqueeze(0).unsqueeze(-1), f, h, w)
                denoise_mask_video = torch.where(in_region, torch.ones_like(denoise_mask_video), denoise_mask_video)

        return {"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}
|
|
|
|
|
|
class LTX2AudioVideoUnit_AudioRetakeEmbedder(PipelineUnit):
    """
    Functionality of audio2video, audio retaking.
    """

    def __init__(self):
        super().__init__(
            input_params=("retake_audio", "seed", "rand_device", "retake_audio_regions"),
            output_params=("input_latents_audio", "audio_noise", "audio_positions", "audio_latent_shape", "denoise_mask_audio"),
            onload_model_names=("audio_vae_encoder",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, retake_audio, seed, rand_device, retake_audio_regions=None):
        """Encode the retake audio and build its denoise mask.

        ``retake_audio`` is unpacked as a ``(waveform, sample_rate)`` pair.
        Returns ``{}`` when no retake audio is given.
        """
        if retake_audio is None:
            return {}

        waveform, sample_rate = retake_audio
        waveform = convert_to_stereo(waveform)
        pipe.load_models_to_device(self.onload_model_names)
        mel = pipe.audio_processor.waveform_to_mel(waveform.unsqueeze(0), waveform_sample_rate=sample_rate)
        mel = mel.to(dtype=pipe.torch_dtype, device=pipe.device)
        input_latents_audio = pipe.audio_vae_encoder(mel)
        audio_latent_shape = AudioLatentShape.from_torch_shape(input_latents_audio.shape)
        audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
        # Regenerate noise for the new shape if retake_audio is provided, to avoid shape mismatch.
        audio_noise = pipe.generate_noise(input_latents_audio.shape, seed=seed, rand_device=rand_device)

        batch, _channels, steps, _freq = input_latents_audio.shape
        denoise_mask_audio = torch.zeros((batch, 1, steps, 1), device=input_latents_audio.device, dtype=input_latents_audio.dtype)
        has_regions = retake_audio_regions is not None and len(retake_audio_regions) > 0
        if has_regions:
            for start_time, end_time in retake_audio_regions:
                # Start/end bounds of every audio token.
                token_start = audio_positions[:, 0, :, 0]
                token_end = audio_positions[:, 0, :, 1]
                # Interval overlap test between the token and the region.
                overlaps = (token_end >= start_time) & (token_start <= end_time)
                overlaps = pipe.audio_patchifier.unpatchify_audio(overlaps.unsqueeze(-1), 1, 1)
                denoise_mask_audio = torch.where(overlaps, torch.ones_like(denoise_mask_audio), denoise_mask_audio)

        return dict(
            input_latents_audio=input_latents_audio,
            denoise_mask_audio=denoise_mask_audio,
            audio_noise=audio_noise,
            audio_positions=audio_positions,
            audio_latent_shape=audio_latent_shape,
        )
|
|
|
|
|
|
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
    """Turn conditioning images into latent-space frame conditions.

    An image at index 0 is written directly into the video latents (with a
    matching denoise mask); images at other indexes are returned as reference
    latents together with their positions.
    """

    def __init__(self):
        super().__init__(
            input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "frame_rate", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "input_latents_video", "denoise_mask_video"),
            output_params=("denoise_mask_video", "input_latents_video", "ref_frames_latents", "ref_frames_positions"),
            # One-element tuple: without the trailing comma this would be a
            # plain string rather than a collection of model names.
            onload_model_names=("video_vae_encoder",),
        )

    def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
        """Resize, preprocess and encode one image with the video VAE encoder."""
        image = ltx2_preprocess(np.array(input_image.resize((width, height))))
        image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
        # Rescale pixel values (assumed to be in [0, 255] - TODO confirm) to [-1, 1].
        image = image / 127.5 - 1.0
        # Add singleton batch and frame axes and move channels first.
        image = repeat(image, "H W C -> B C F H W", B=1, F=1)
        latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
        return latents

    def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, input_latents_video=None, denoise_mask_video=None):
        """Write image latents into the reference latents and update the mask.

        ``input_latents_video`` and ``denoise_mask_video`` are modified in place
        when given; otherwise they start as zeros / ones shaped like ``latents``.

        Returns:
            ``(input_latents_video, denoise_mask)``.
        """
        b, _, f, h, w = latents.shape
        denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
        input_latents_video = torch.zeros_like(latents) if input_latents_video is None else input_latents_video
        for idx, input_latent in zip(input_indexes, input_latents):
            # Map a pixel-frame index to a latent-frame index (frame 0 -> 0,
            # frames 1..8 -> 1, ...), clamped to the valid latent range.
            idx = min(max(1 + (idx-1) // 8, 0), f - 1)
            input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
            input_latents_video[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
            denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
        return input_latents_video, denoise_mask

    def process(
        self,
        pipe: LTX2AudioVideoPipeline,
        video_latents,
        input_images,
        height,
        width,
        frame_rate,
        tiled,
        tile_size_in_pixels,
        tile_overlap_in_pixels,
        input_images_indexes=None,
        input_images_strength=1.0,
        input_latents_video=None,
        denoise_mask_video=None,
    ):
        """Build the frame conditions for ``input_images``.

        Args:
            input_images_indexes: Frame index of each image; ``None`` is
                treated as ``[0]``.

        Returns:
            ``{}`` without images, otherwise a dict with ``input_latents_video``,
            ``denoise_mask_video``, ``ref_frames_latents`` and
            ``ref_frames_positions`` (the last two are ``None`` when no image
            has a non-zero index).

        Raises:
            ValueError: If ``input_images_indexes`` contains duplicates.
        """
        if input_images is None or len(input_images) == 0:
            return {}
        # Fresh list per call instead of a shared mutable default; this also
        # covers an explicit ``None``.
        if input_images_indexes is None:
            input_images_indexes = [0]
        if len(input_images_indexes) != len(set(input_images_indexes)):
            raise ValueError("Input images must have unique indexes.")
        pipe.load_models_to_device(self.onload_model_names)
        frame_conditions = {"input_latents_video": None, "denoise_mask_video": None, "ref_frames_latents": [], "ref_frames_positions": []}
        for img, index in zip(input_images, input_images_indexes):
            latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
            # first_frame by replacing latents
            if index == 0:
                input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(
                    video_latents, [latents], [0], input_images_strength, input_latents_video, denoise_mask_video)
                frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
            # other frames by adding reference latents
            else:
                latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(latents.shape), device=pipe.device)
                video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, False).float()
                # Shift the first coordinate axis by the frame index and divide
                # it by the frame rate.
                video_positions[:, 0, ...] = (video_positions[:, 0, ...] + index) / frame_rate
                video_positions = video_positions.to(pipe.torch_dtype)
                frame_conditions["ref_frames_latents"].append(latents)
                frame_conditions["ref_frames_positions"].append(video_positions)
        if len(frame_conditions["ref_frames_latents"]) == 0:
            frame_conditions.update({"ref_frames_latents": None, "ref_frames_positions": None})
        return frame_conditions
|
|
|
|
|
|
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
    """Encode in-context control videos into latents and positions."""

    def __init__(self):
        super().__init__(
            input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
            output_params=("in_context_video_latents", "in_context_video_positions"),
            # One-element tuple: without the trailing comma this would be a
            # plain string rather than a collection of model names.
            onload_model_names=("video_vae_encoder",),
        )

    def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor):
        """Trim, resize and pad one in-context video to a valid shape.

        The video is cut to at most ``num_frames`` frames, resized to the
        downsampled resolution validated by ``check_resize_height_width``, and
        padded with black frames at the end if more frames are required.

        Raises:
            ValueError: If ``in_context_video`` is ``None`` or empty.
        """
        if in_context_video is None or len(in_context_video) == 0:
            raise ValueError("In-context video is None or empty.")
        in_context_video = in_context_video[:num_frames]
        expected_height = height // in_context_downsample_factor
        expected_width = width // in_context_downsample_factor
        # PIL reports size as (width, height).
        current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
        h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
        if current_h != h or current_w != w:
            in_context_video = [img.resize((w, h)) for img in in_context_video]
        if current_f != f:
            # pad black frames at the end
            in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
        return in_context_video

    def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
        """Return concatenated latents/positions of all in-context videos.

        Returns ``{}`` when no in-context video is given.
        """
        if in_context_videos is None or len(in_context_videos) == 0:
            return {}
        else:
            pipe.load_models_to_device(self.onload_model_names)
            latents, positions = [], []
            for in_context_video in in_context_videos:
                in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor)
                in_context_video = pipe.preprocess_video(in_context_video)
                in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)

                latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device)
                video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
                video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
                # Scale the spatial coordinates back up by the downsample factor.
                video_positions[:, 1, ...] *= in_context_downsample_factor  # height axis
                video_positions[:, 2, ...] *= in_context_downsample_factor  # width axis
                video_positions = video_positions.to(pipe.torch_dtype)

                latents.append(in_context_latents)
                positions.append(video_positions)
            latents = torch.cat(latents, dim=1)
            positions = torch.cat(positions, dim=1)
            return {"in_context_video_latents": latents, "in_context_video_positions": positions}
|
|
|
|
|
|
class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
    """
    Move the shared pipeline state from stage 1 to stage 2:

    1. replace height and width with the stage 2 resolution
    2. reset in_context_video_latents and in_context_video_positions to None
    3. reset input_latents_video and denoise_mask_video to None
    4. optionally clear the loaded LoRA and, unless the distilled pipeline
       is used, load the stage 2 LoRA into the DiT
    """

    def __init__(self):
        consumed = ("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline")
        # NOTE(review): process() also returns "input_latents_video" and
        # "denoise_mask_video", which are not declared here -- confirm whether
        # the base class relies on output_params being exhaustive.
        produced = ("height", "width", "in_context_video_latents", "in_context_video_positions")
        super().__init__(input_params=consumed, output_params=produced)

    def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline):
        """Swap LoRA weights as requested and return the stage 2 overrides."""
        if clear_lora_before_state_two:
            pipe.clear_lora()
        if not use_distilled_pipeline:
            pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
        return {
            "height": stage_2_height,
            "width": stage_2_width,
            "in_context_video_latents": None,
            "in_context_video_positions": None,
            "input_latents_video": None,
            "denoise_mask_video": None,
        }
|
|
|
|
|
|
class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit):
    """Switch the scheduler to its stage 2 schedule and re-noise both latents."""

    def __init__(self):
        super().__init__(
            input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"),
            output_params=("video_latents", "audio_latents"),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise):
        """Add noise to the video and audio latents at the first stage 2 timestep."""
        scheduler = pipe.scheduler
        scheduler.set_timesteps(special_case="stage2")
        # Both modalities are noised at the first timestep of the new schedule.
        first_timestep = scheduler.timesteps[0]
        noised_video = scheduler.add_noise(video_latents, video_noise, first_timestep)
        noised_audio = scheduler.add_noise(audio_latents, audio_noise, first_timestep)
        return {"video_latents": noised_video, "audio_latents": noised_audio}
|
|
|
|
|
|
class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
    """Run the video latents through the pipeline's latent upsampler."""

    def __init__(self):
        super().__init__(
            input_params=("video_latents",),
            output_params=("video_latents",),
            onload_model_names=("upsampler",),
        )

    def process(self, pipe: LTX2AudioVideoPipeline, video_latents):
        """Un-normalize, upsample and re-normalize the video latents.

        Raises:
            ValueError: If there are no video latents or the pipeline has no
                upsampler loaded.
        """
        if video_latents is None or pipe.upsampler is None:
            raise ValueError("No upsampler or no video latents before stage 2.")

        pipe.load_models_to_device(self.onload_model_names)
        # The upsampler is applied to un-normalized latents; the result is
        # normalized again with the same per-channel statistics.
        statistics = pipe.video_vae_encoder.per_channel_statistics
        upsampled = pipe.upsampler(statistics.un_normalize(video_latents))
        return {"video_latents": statistics.normalize(upsampled)}
|
|
|
|
|
|
def model_fn_ltx2(
    dit: LTXModel,
    video_latents=None,
    video_context=None,
    video_positions=None,
    video_patchifier=None,
    audio_latents=None,
    audio_context=None,
    audio_positions=None,
    audio_patchifier=None,
    timestep=None,
    # First Frame Conditioning
    input_latents_video=None,
    denoise_mask_video=None,
    # Other Frames Conditioning
    ref_frames_latents=None,
    ref_frames_positions=None,
    # In-Context Conditioning
    in_context_video_latents=None,
    in_context_video_positions=None,
    # Audio Inputs
    input_latents_audio=None,
    denoise_mask_audio=None,
    # Gradient Checkpointing
    use_gradient_checkpointing=False,
    use_gradient_checkpointing_offload=False,
    **kwargs,
):
    """Run one forward pass of the LTX-2 DiT on video (and optional audio) latents.

    The latents are patchified, conditioning is applied, the DiT is called and
    its outputs are unpatchified again.

    Args:
        dit: The model; called with keyword arguments only.
        video_latents: 5-D tensor unpacked as (batch, channels, frames, height, width).
        video_context, audio_context: Forwarded to ``dit`` unchanged.
        video_positions: Positions of the video tokens; reference positions
            are appended along dim 2.
        video_patchifier: Provides ``patchify`` and ``unpatchify_video``.
        audio_latents: Optional 4-D tensor; dims 1 and 3 are later passed to
            ``unpatchify_audio``.
        audio_positions: Forwarded to ``dit`` unchanged.
        audio_patchifier: Provides ``patchify`` and ``unpatchify_audio``.
        timestep: Tensor timestep; divided by 1000 before use.
        input_latents_video, denoise_mask_video: When ``input_latents_video``
            is given, tokens are blended as
            ``latents * mask + input * (1 - mask)`` and the per-token
            timesteps are multiplied by the mask.
        ref_frames_latents, ref_frames_positions: Optional parallel lists of
            reference latents and positions. They are read but never modified.
        in_context_video_latents, in_context_video_positions: Optional extra
            reference entry appended after the reference frames.
        input_latents_audio, denoise_mask_audio: Audio counterpart of the
            video blending above.
        use_gradient_checkpointing, use_gradient_checkpointing_offload:
            Forwarded to ``dit``.
        **kwargs: Accepted and ignored.

    Returns:
        Tuple ``(vx, ax)``: the unpatchified video prediction restricted to
        the original video tokens, and the unpatchified audio prediction or
        ``None`` when ``dit`` returns no audio output.
    """
    timestep = timestep.float() / 1000.

    # patchify
    _, _, f, h, w = video_latents.shape
    video_latents = video_patchifier.patchify(video_latents)
    seq_len_video = video_latents.shape[1]
    video_timesteps = timestep.repeat(1, seq_len_video, 1)

    # First frame conditioning by replacing the video latents
    if input_latents_video is not None:
        denoise_mask_video = video_patchifier.patchify(denoise_mask_video)
        video_latents = video_latents * denoise_mask_video + video_patchifier.patchify(input_latents_video) * (1.0 - denoise_mask_video)
        video_timesteps = denoise_mask_video * video_timesteps

    # Reference conditioning by appending the reference video or frame latents.
    # Copy the caller's lists: extending them in place would add the in-context
    # entry again on every call (i.e. on every denoising step).
    total_ref_latents = list(ref_frames_latents) if ref_frames_latents is not None else []
    total_ref_positions = list(ref_frames_positions) if ref_frames_positions is not None else []
    if in_context_video_latents is not None:
        total_ref_latents.append(in_context_video_latents)
    if in_context_video_positions is not None:
        total_ref_positions.append(in_context_video_positions)
    for ref_frames_latent, ref_frames_position in zip(total_ref_latents, total_ref_positions):
        ref_frames_latent = video_patchifier.patchify(ref_frames_latent)
        # Reference tokens get timestep 0.
        ref_frames_timestep = timestep.repeat(1, ref_frames_latent.shape[1], 1) * 0.
        video_latents = torch.cat([video_latents, ref_frames_latent], dim=1)
        video_positions = torch.cat([video_positions, ref_frames_position], dim=2)
        video_timesteps = torch.cat([video_timesteps, ref_frames_timestep], dim=1)

    if audio_latents is not None:
        _, c_a, _, mel_bins = audio_latents.shape
        audio_latents = audio_patchifier.patchify(audio_latents)
        audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
    else:
        audio_timesteps = None
    if input_latents_audio is not None:
        denoise_mask_audio = audio_patchifier.patchify(denoise_mask_audio)
        audio_latents = audio_latents * denoise_mask_audio + audio_patchifier.patchify(input_latents_audio) * (1.0 - denoise_mask_audio)
        audio_timesteps = denoise_mask_audio * audio_timesteps

    vx, ax = dit(
        video_latents=video_latents,
        video_positions=video_positions,
        video_context=video_context,
        video_timesteps=video_timesteps,
        audio_latents=audio_latents,
        audio_positions=audio_positions,
        audio_context=audio_context,
        audio_timesteps=audio_timesteps,
        sigma=timestep,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )

    # Drop the predictions for the appended reference tokens.
    vx = vx[:, :seq_len_video, ...]
    # unpatchify
    vx = video_patchifier.unpatchify_video(vx, f, h, w)
    ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
    return vx, ax
|