|
|
|
|
@@ -12,7 +12,7 @@ from transformers import AutoImageProcessor, Gemma3Processor
|
|
|
|
|
|
|
|
|
|
from ..core.device.npu_compatible_device import get_device_type
|
|
|
|
|
from ..diffusion import FlowMatchScheduler
|
|
|
|
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
|
|
|
|
from ..core import ModelConfig
|
|
|
|
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
|
|
|
|
|
|
|
|
|
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
|
|
|
|
@@ -63,6 +63,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|
|
|
|
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
|
|
|
|
LTX2AudioVideoUnit_InContextVideoEmbedder(),
|
|
|
|
|
]
|
|
|
|
|
self.stage2_units = [
|
|
|
|
|
LTX2AudioVideoUnit_SwitchStage2(),
|
|
|
|
|
LTX2AudioVideoUnit_NoiseInitializer(),
|
|
|
|
|
LTX2AudioVideoUnit_LatentsUpsampler(),
|
|
|
|
|
LTX2AudioVideoUnit_SetScheduleStage2(),
|
|
|
|
|
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
|
|
|
|
]
|
|
|
|
|
self.model_fn = model_fn_ltx2
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
@@ -72,6 +79,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|
|
|
|
model_configs: list[ModelConfig] = [],
|
|
|
|
|
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
|
|
|
|
stage2_lora_config: Optional[ModelConfig] = None,
|
|
|
|
|
stage2_lora_strength: float = 0.8,
|
|
|
|
|
vram_limit: float = None,
|
|
|
|
|
):
|
|
|
|
|
# Initialize pipeline
|
|
|
|
|
@@ -98,53 +106,31 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|
|
|
|
if stage2_lora_config is not None:
|
|
|
|
|
stage2_lora_config.download_if_necessary()
|
|
|
|
|
pipe.stage2_lora_path = stage2_lora_config.path
|
|
|
|
|
# Optional, currently not used
|
|
|
|
|
pipe.stage2_lora_strength = stage2_lora_strength
|
|
|
|
|
|
|
|
|
|
# VRAM Management
|
|
|
|
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
|
|
|
|
return pipe
|
|
|
|
|
|
|
|
|
|
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
|
|
|
|
if inputs_shared["use_two_stage_pipeline"]:
|
|
|
|
|
if inputs_shared.get("clear_lora_before_state_two", False):
|
|
|
|
|
self.clear_lora()
|
|
|
|
|
latents = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
|
|
|
|
self.load_models_to_device('upsampler',)
|
|
|
|
|
latents = self.upsampler(latents)
|
|
|
|
|
latents = self.video_vae_encoder.per_channel_statistics.normalize(latents)
|
|
|
|
|
self.scheduler.set_timesteps(special_case="stage2")
|
|
|
|
|
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
|
|
|
|
denoise_mask_video = 1.0
|
|
|
|
|
# input image
|
|
|
|
|
if inputs_shared.get("input_images", None) is not None:
|
|
|
|
|
initial_latents, denoise_mask_video = self.apply_input_images_to_latents(latents, initial_latents=latents, **inputs_shared.get("stage2_input_latents_apply_kwargs", {}))
|
|
|
|
|
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
|
|
|
|
# remove in-context video control in stage 2
|
|
|
|
|
inputs_shared.pop("in_context_video_latents", None)
|
|
|
|
|
inputs_shared.pop("in_context_video_positions", None)
|
|
|
|
|
|
|
|
|
|
# initialize latents for stage 2
|
|
|
|
|
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
|
|
|
|
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latents
|
|
|
|
|
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
|
|
|
|
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
|
|
|
|
def denoise_stage(self, inputs_shared, inputs_posi, inputs_nega, units, cfg_scale=1.0, progress_bar_cmd=tqdm, skip_stage=False):
|
|
|
|
|
if skip_stage:
|
|
|
|
|
return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
|
for unit in units:
|
|
|
|
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
|
|
|
|
self.load_models_to_device(self.in_iteration_models)
|
|
|
|
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
|
|
|
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
|
|
|
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
|
|
|
|
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
|
|
|
|
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
|
|
|
|
|
**models, timestep=timestep, progress_id=progress_id
|
|
|
|
|
)
|
|
|
|
|
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
|
|
|
|
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
|
|
|
|
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared)
|
|
|
|
|
return inputs_shared, inputs_posi, inputs_nega
|
|
|
|
|
|
|
|
|
|
self.load_models_to_device(self.in_iteration_models)
|
|
|
|
|
if not inputs_shared["use_distilled_pipeline"]:
|
|
|
|
|
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
|
|
|
|
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
|
|
|
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
|
|
|
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
|
|
|
|
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
|
|
|
|
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
|
|
|
|
|
**models, timestep=timestep, progress_id=progress_id
|
|
|
|
|
)
|
|
|
|
|
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
|
|
|
|
|
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
|
|
|
|
|
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
|
|
|
|
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
|
|
|
|
noise_pred=noise_pred_audio, **inputs_shared)
|
|
|
|
|
return inputs_shared
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def __call__(
|
|
|
|
|
@@ -171,7 +157,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|
|
|
|
# Classifier-free guidance
|
|
|
|
|
cfg_scale: Optional[float] = 3.0,
|
|
|
|
|
# Scheduler
|
|
|
|
|
num_inference_steps: Optional[int] = 40,
|
|
|
|
|
num_inference_steps: Optional[int] = 30,
|
|
|
|
|
# VAE tiling
|
|
|
|
|
tiled: Optional[bool] = True,
|
|
|
|
|
tile_size_in_pixels: Optional[int] = 512,
|
|
|
|
|
@@ -180,14 +166,14 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|
|
|
|
tile_overlap_in_frames: Optional[int] = 24,
|
|
|
|
|
# Special Pipelines
|
|
|
|
|
use_two_stage_pipeline: Optional[bool] = False,
|
|
|
|
|
stage2_spatial_upsample_factor: Optional[int] = 2,
|
|
|
|
|
clear_lora_before_state_two: Optional[bool] = False,
|
|
|
|
|
use_distilled_pipeline: Optional[bool] = False,
|
|
|
|
|
# progress_bar
|
|
|
|
|
progress_bar_cmd=tqdm,
|
|
|
|
|
):
|
|
|
|
|
# Scheduler
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
|
|
|
|
|
special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
|
|
|
|
# Inputs
|
|
|
|
|
inputs_posi = {
|
|
|
|
|
"prompt": prompt,
|
|
|
|
|
@@ -203,50 +189,22 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|
|
|
|
"cfg_scale": cfg_scale,
|
|
|
|
|
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
|
|
|
|
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
|
|
|
|
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two,
|
|
|
|
|
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "stage2_spatial_upsample_factor": stage2_spatial_upsample_factor,
|
|
|
|
|
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
|
|
|
|
}
|
|
|
|
|
for unit in self.units:
|
|
|
|
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
|
|
|
|
|
|
|
|
|
# Denoise Stage 1
|
|
|
|
|
self.load_models_to_device(self.in_iteration_models)
|
|
|
|
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
|
|
|
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
|
|
|
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
|
|
|
|
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
|
|
|
|
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
|
|
|
|
|
**models, timestep=timestep, progress_id=progress_id
|
|
|
|
|
)
|
|
|
|
|
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
|
|
|
|
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
|
|
|
|
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
|
|
|
|
noise_pred=noise_pred_audio, **inputs_shared)
|
|
|
|
|
|
|
|
|
|
# Denoise Stage 2
|
|
|
|
|
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
|
|
|
|
|
|
|
|
|
|
# Stage 1
|
|
|
|
|
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.units, cfg_scale, progress_bar_cmd)
|
|
|
|
|
# Stage 2
|
|
|
|
|
inputs_shared, inputs_posi, inputs_nega = self.denoise_stage(inputs_shared, inputs_posi, inputs_nega, self.stage2_units, 1.0, progress_bar_cmd, not inputs_shared["use_two_stage_pipeline"])
|
|
|
|
|
# Decode
|
|
|
|
|
self.load_models_to_device(['video_vae_decoder'])
|
|
|
|
|
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
|
|
|
|
|
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
|
|
|
|
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
|
|
|
|
video = self.vae_output_to_video(video)
|
|
|
|
|
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
|
|
|
|
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
|
|
|
|
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
|
|
|
|
return video, decoded_audio
|
|
|
|
|
|
|
|
|
|
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
|
|
|
|
|
b, _, f, h, w = latents.shape
|
|
|
|
|
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
|
|
|
|
|
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
|
|
|
|
for idx, input_latent in zip(input_indexes, input_latents):
|
|
|
|
|
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
|
|
|
|
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
|
|
|
|
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
|
|
|
|
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
|
|
|
|
return initial_latents, denoise_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
@@ -275,22 +233,23 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
|
|
|
|
|
"""
|
|
|
|
|
For two-stage pipelines, the resolution must be divisible by 64.
|
|
|
|
|
For one-stage pipelines, the resolution must be divisible by 32.
|
|
|
|
|
This unit set height and width to stage 1 resolution, and stage_2_width and stage_2_height.
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("height", "width", "num_frames"),
|
|
|
|
|
output_params=("height", "width", "num_frames"),
|
|
|
|
|
input_params=("height", "width", "num_frames", "use_two_stage_pipeline", "stage2_spatial_upsample_factor"),
|
|
|
|
|
output_params=("height", "width", "num_frames", "stage_2_height", "stage_2_width"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False, stage2_spatial_upsample_factor=2):
|
|
|
|
|
if use_two_stage_pipeline:
|
|
|
|
|
self.width_division_factor = 64
|
|
|
|
|
self.height_division_factor = 64
|
|
|
|
|
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
|
|
|
|
|
if use_two_stage_pipeline:
|
|
|
|
|
self.width_division_factor = 32
|
|
|
|
|
self.height_division_factor = 32
|
|
|
|
|
return {"height": height, "width": width, "num_frames": num_frames}
|
|
|
|
|
height, width = height // stage2_spatial_upsample_factor, width // stage2_spatial_upsample_factor
|
|
|
|
|
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
|
|
|
|
|
stage_2_height, stage_2_width = int(height * stage2_spatial_upsample_factor), int(width * stage2_spatial_upsample_factor)
|
|
|
|
|
else:
|
|
|
|
|
stage_2_height, stage_2_width = None, None
|
|
|
|
|
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
|
|
|
|
|
return {"height": height, "width": width, "num_frames": num_frames, "stage_2_height": stage_2_height, "stage_2_width": stage_2_width}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
|
|
|
|
@@ -328,7 +287,7 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
|
|
|
|
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"),
|
|
|
|
|
input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate"),
|
|
|
|
|
output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -354,15 +313,9 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
|
|
|
|
"audio_latent_shape": audio_latent_shape
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
|
|
|
|
|
if use_two_stage_pipeline:
|
|
|
|
|
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
|
|
|
|
|
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
|
|
|
|
initial_dict = stage1_dict
|
|
|
|
|
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
|
|
|
|
|
return initial_dict
|
|
|
|
|
else:
|
|
|
|
|
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
|
|
|
|
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
@@ -384,6 +337,7 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError("Video-to-video not implemented yet.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
@@ -407,11 +361,12 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError("Audio-to-video not supported.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
|
|
|
|
output_params=("denoise_mask_video", "input_latents_video", "stage2_input_latents_apply_kwargs"),
|
|
|
|
|
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "initial_latents"),
|
|
|
|
|
output_params=("denoise_mask_video", "input_latents_video"),
|
|
|
|
|
onload_model_names=("video_vae_encoder")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -423,53 +378,48 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
|
|
|
|
latents = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
|
|
|
|
return latents
|
|
|
|
|
|
|
|
|
|
def get_frame_conditions(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents=None, skip_apply=False):
|
|
|
|
|
frame_conditions = {}
|
|
|
|
|
for img, index in zip(input_images, input_images_indexes):
|
|
|
|
|
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
|
|
|
|
|
# first_frame
|
|
|
|
|
if index == 0:
|
|
|
|
|
if skip_apply:
|
|
|
|
|
frame_conditions = {"input_latents": [latents], "input_indexes": [0], "input_strength": input_images_strength}
|
|
|
|
|
else:
|
|
|
|
|
input_latents_video, denoise_mask_video = pipe.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength)
|
|
|
|
|
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
|
|
|
|
|
return frame_conditions
|
|
|
|
|
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength=1.0, initial_latents=None, denoise_mask_video=None):
|
|
|
|
|
b, _, f, h, w = latents.shape
|
|
|
|
|
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device) if denoise_mask_video is None else denoise_mask_video
|
|
|
|
|
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
|
|
|
|
for idx, input_latent in zip(input_indexes, input_latents):
|
|
|
|
|
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
|
|
|
|
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
|
|
|
|
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
|
|
|
|
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
|
|
|
|
return initial_latents, denoise_mask
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, input_images, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, use_two_stage_pipeline=False):
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, input_images, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, initial_latents=None):
|
|
|
|
|
if input_images is None or len(input_images) == 0:
|
|
|
|
|
return {}
|
|
|
|
|
else:
|
|
|
|
|
if len(input_images_indexes) != len(set(input_images_indexes)):
|
|
|
|
|
raise ValueError("Input images must have unique indexes.")
|
|
|
|
|
pipe.load_models_to_device(self.onload_model_names)
|
|
|
|
|
output_dicts = {}
|
|
|
|
|
stage1_height = height // 2 if use_two_stage_pipeline else height
|
|
|
|
|
stage1_width = width // 2 if use_two_stage_pipeline else width
|
|
|
|
|
stage_1_frame_conditions = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, stage1_height, stage1_width,
|
|
|
|
|
tiled, tile_size_in_pixels, tile_overlap_in_pixels, video_latents)
|
|
|
|
|
output_dicts.update(stage_1_frame_conditions)
|
|
|
|
|
if use_two_stage_pipeline:
|
|
|
|
|
stage2_input_latents_apply_kwargs = self.get_frame_conditions(pipe, input_images, input_images_indexes, input_images_strength, height, width,
|
|
|
|
|
tiled, tile_size_in_pixels, tile_overlap_in_pixels, skip_apply=True)
|
|
|
|
|
output_dicts.update({"stage2_input_latents_apply_kwargs": stage2_input_latents_apply_kwargs})
|
|
|
|
|
return output_dicts
|
|
|
|
|
frame_conditions = {}
|
|
|
|
|
for img, index in zip(input_images, input_images_indexes):
|
|
|
|
|
latents = self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels)
|
|
|
|
|
# first_frame
|
|
|
|
|
if index == 0:
|
|
|
|
|
input_latents_video, denoise_mask_video = self.apply_input_images_to_latents(video_latents, [latents], [0], input_images_strength, initial_latents)
|
|
|
|
|
frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video})
|
|
|
|
|
return frame_conditions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
|
|
|
|
input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
|
|
|
|
|
output_params=("in_context_video_latents", "in_context_video_positions"),
|
|
|
|
|
onload_model_names=("video_vae_encoder")
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True):
|
|
|
|
|
def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor):
|
|
|
|
|
if in_context_video is None or len(in_context_video) == 0:
|
|
|
|
|
raise ValueError("In-context video is None or empty.")
|
|
|
|
|
in_context_video = in_context_video[:num_frames]
|
|
|
|
|
expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor
|
|
|
|
|
expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor
|
|
|
|
|
expected_height = height // in_context_downsample_factor
|
|
|
|
|
expected_width = width // in_context_downsample_factor
|
|
|
|
|
current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
|
|
|
|
|
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
|
|
|
|
|
if current_h != h or current_w != w:
|
|
|
|
|
@@ -479,14 +429,14 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
|
|
|
|
|
in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f)
|
|
|
|
|
return in_context_video
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True):
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
|
|
|
|
if in_context_videos is None or len(in_context_videos) == 0:
|
|
|
|
|
return {}
|
|
|
|
|
else:
|
|
|
|
|
pipe.load_models_to_device(self.onload_model_names)
|
|
|
|
|
latents, positions = [], []
|
|
|
|
|
for in_context_video in in_context_videos:
|
|
|
|
|
in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline)
|
|
|
|
|
in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor)
|
|
|
|
|
in_context_video = pipe.preprocess_video(in_context_video)
|
|
|
|
|
in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
|
|
|
|
|
|
|
|
|
|
@@ -504,6 +454,62 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
|
|
|
|
|
return {"in_context_video_latents": latents, "in_context_video_positions": positions}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_SwitchStage2(PipelineUnit):
|
|
|
|
|
"""
|
|
|
|
|
1. switch height and width to stage 2 resolution
|
|
|
|
|
2. clear in_context_video_latents and in_context_video_positions
|
|
|
|
|
3. switch stage 2 lora model
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("stage_2_height", "stage_2_width", "clear_lora_before_state_two", "use_distilled_pipeline"),
|
|
|
|
|
output_params=("height", "width", "in_context_video_latents", "in_context_video_positions"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, stage_2_height, stage_2_width, clear_lora_before_state_two, use_distilled_pipeline):
|
|
|
|
|
stage2_params = {}
|
|
|
|
|
stage2_params.update({"height": stage_2_height, "width": stage_2_width})
|
|
|
|
|
stage2_params.update({"in_context_video_latents": None, "in_context_video_positions": None})
|
|
|
|
|
if clear_lora_before_state_two:
|
|
|
|
|
pipe.clear_lora()
|
|
|
|
|
if not use_distilled_pipeline:
|
|
|
|
|
pipe.load_lora(pipe.dit, pipe.stage2_lora_path, alpha=pipe.stage2_lora_strength)
|
|
|
|
|
return stage2_params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_SetScheduleStage2(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("video_latents", "video_noise", "audio_latents", "audio_noise"),
|
|
|
|
|
output_params=("video_latents", "audio_latents"),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, video_latents, video_noise, audio_latents, audio_noise):
|
|
|
|
|
pipe.scheduler.set_timesteps(special_case="stage2")
|
|
|
|
|
video_latents = pipe.scheduler.add_noise(video_latents, video_noise, pipe.scheduler.timesteps[0])
|
|
|
|
|
audio_latents = pipe.scheduler.add_noise(audio_latents, audio_noise, pipe.scheduler.timesteps[0])
|
|
|
|
|
return {"video_latents": video_latents, "audio_latents": audio_latents}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LTX2AudioVideoUnit_LatentsUpsampler(PipelineUnit):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super().__init__(
|
|
|
|
|
input_params=("video_latents",),
|
|
|
|
|
output_params=("video_latents", "initial_latents"),
|
|
|
|
|
onload_model_names=("upsampler",),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def process(self, pipe: LTX2AudioVideoPipeline, video_latents):
|
|
|
|
|
if video_latents is None or pipe.upsampler is None:
|
|
|
|
|
raise ValueError("No upsampler or no video latents before stage 2.")
|
|
|
|
|
else:
|
|
|
|
|
pipe.load_models_to_device(self.onload_model_names)
|
|
|
|
|
video_latents = pipe.video_vae_encoder.per_channel_statistics.un_normalize(video_latents)
|
|
|
|
|
video_latents = pipe.upsampler(video_latents)
|
|
|
|
|
video_latents = pipe.video_vae_encoder.per_channel_statistics.normalize(video_latents)
|
|
|
|
|
return {"video_latents": video_latents, "initial_latents": video_latents}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model_fn_ltx2(
|
|
|
|
|
dit: LTXModel,
|
|
|
|
|
video_latents=None,
|
|
|
|
|
@@ -515,10 +521,13 @@ def model_fn_ltx2(
|
|
|
|
|
audio_positions=None,
|
|
|
|
|
audio_patchifier=None,
|
|
|
|
|
timestep=None,
|
|
|
|
|
# First Frame Conditioning
|
|
|
|
|
input_latents_video=None,
|
|
|
|
|
denoise_mask_video=None,
|
|
|
|
|
# In-Context Conditioning
|
|
|
|
|
in_context_video_latents=None,
|
|
|
|
|
in_context_video_positions=None,
|
|
|
|
|
# Gradient Checkpointing
|
|
|
|
|
use_gradient_checkpointing=False,
|
|
|
|
|
use_gradient_checkpointing_offload=False,
|
|
|
|
|
**kwargs,
|
|
|
|
|
|