support ltx-2 training

This commit is contained in:
mi804
2026-02-25 17:19:57 +08:00
parent 288bbc7128
commit 586ac9d8a6
40 changed files with 893 additions and 49 deletions

View File

@@ -607,6 +607,12 @@ ltx2_series = [
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
"model_name": "ltx2_dit",
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -614,6 +620,12 @@ ltx2_series = [
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
"model_hash": "7f7e904a53260ec0351b05f32153754b",
"model_name": "ltx2_video_vae_encoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -621,6 +633,12 @@ ltx2_series = [
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
"model_name": "ltx2_video_vae_decoder",
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -628,6 +646,12 @@ ltx2_series = [
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
"model_name": "ltx2_audio_vae_decoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
@@ -635,16 +659,34 @@ ltx2_series = [
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
# { # not used currently
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
# "model_name": "ltx2_audio_vae_encoder",
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
# },
{
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
"model_name": "ltx2_audio_vocoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
"model_name": "ltx2_audio_vae_encoder",
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
},
{
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
},
{
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
"model_name": "ltx2_text_encoder_post_modules",
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",

View File

@@ -28,6 +28,36 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
return loss
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
# video
noise = torch.randn_like(inputs["input_latents"])
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
# audio
if inputs.get("audio_input_latents") is not None:
audio_noise = torch.randn_like(inputs["audio_input_latents"])
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * pipe.scheduler.training_weight(timestep)
if inputs.get("audio_input_latents") is not None:
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
loss = loss + loss_audio
return loss
def DirectDistillLoss(pipe: BasePipeline, **inputs):
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
pipe.scheduler.training = True

View File

@@ -5,8 +5,65 @@ import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
class AudioProcessor(nn.Module):
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
def __init__(
self,
sample_rate: int = 16000,
mel_bins: int = 64,
mel_hop_length: int = 160,
n_fft: int = 1024,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate,
n_fft=n_fft,
win_length=n_fft,
hop_length=mel_hop_length,
f_min=0.0,
f_max=sample_rate / 2.0,
n_mels=mel_bins,
window_fn=torch.hann_window,
center=True,
pad_mode="reflect",
power=1.0,
mel_scale="slaney",
norm="slaney",
)
def resample_waveform(
self,
waveform: torch.Tensor,
source_rate: int,
target_rate: int,
) -> torch.Tensor:
"""Resample waveform to target sample rate if needed."""
if source_rate == target_rate:
return waveform
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
return resampled.to(device=waveform.device, dtype=waveform.dtype)
def waveform_to_mel(
self,
waveform: torch.Tensor,
waveform_sample_rate: int,
) -> torch.Tensor:
"""Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
mel = self.mel_transform(waveform)
mel = torch.log(torch.clamp(mel, min=1e-5))
mel = mel.to(device=waveform.device, dtype=waveform.dtype)
return mel.permute(0, 1, 3, 2).contiguous()
class AudioPatchifier(Patchifier):
def __init__(
self,

View File

@@ -1446,6 +1446,6 @@ class LTXModel(torch.nn.Module):
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
self._init_preprocessors(cross_pe_max_pos)
video = Modality(video_latents, video_timesteps, video_positions, video_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
return vx, ax

View File

@@ -18,7 +18,7 @@ from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
from ..models.ltx2_dit import LTXModel
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
from ..models.ltx2_upsampler import LTX2LatentUpsampler
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
from ..utils.data.media_io_ltx2 import ltx2_preprocess
@@ -50,6 +50,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
self.audio_processor: AudioProcessor = AudioProcessor()
self.in_iteration_models = ("dit",)
self.units = [
@@ -57,6 +58,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
LTX2AudioVideoUnit_ShapeChecker(),
LTX2AudioVideoUnit_PromptEmbedder(),
LTX2AudioVideoUnit_NoiseInitializer(),
LTX2AudioVideoUnit_InputAudioEmbedder(),
LTX2AudioVideoUnit_InputVideoEmbedder(),
LTX2AudioVideoUnit_InputImagesEmbedder(),
]
@@ -95,7 +97,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
stage2_lora_config.download_if_necessary()
pipe.stage2_lora_path = stage2_lora_config.path
# Optional, currently not used
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
# VRAM Management
pipe.vram_management_enabled = pipe.check_vram_management_state()
@@ -157,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
num_frames=121,
# Classifier-free guidance
cfg_scale: Optional[float] = 3.0,
cfg_merge: Optional[bool] = False,
# Scheduler
num_inference_steps: Optional[int] = 40,
# VAE tiling
@@ -186,7 +187,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
"seed": seed, "rand_device": rand_device,
"height": height, "width": width, "num_frames": num_frames,
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
"cfg_scale": cfg_scale,
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
@@ -422,7 +423,7 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
@@ -455,17 +456,48 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
output_params=("video_latents", "audio_latents"),
input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
output_params=("video_latents", "input_latents"),
onload_model_names=("video_vae_encoder")
)
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
if input_video is None:
return {"video_latents": video_noise, "audio_latents": audio_noise}
return {"video_latents": video_noise}
else:
# TODO: implement video-to-video
raise NotImplementedError("Video-to-video not implemented yet.")
pipe.load_models_to_device(self.onload_model_names)
input_video = pipe.preprocess_video(input_video)
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
if pipe.scheduler.training:
return {"video_latents": input_latents, "input_latents": input_latents}
else:
# TODO: implement video-to-video
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
def __init__(self):
super().__init__(
input_params=("input_audio", "audio_noise"),
output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
onload_model_names=("audio_vae_encoder",)
)
def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
if input_audio is None:
return {"audio_latents": audio_noise}
else:
input_audio, sample_rate = input_audio
pipe.load_models_to_device(self.onload_model_names)
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
audio_input_latents = pipe.audio_vae_encoder(input_audio)
audio_noise = torch.randn_like(audio_input_latents)
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
if pipe.scheduler.training:
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
else:
# TODO: implement video-to-video
raise NotImplementedError("Video-to-video not implemented yet.")
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
def __init__(self):
@@ -530,9 +562,12 @@ def model_fn_ltx2(
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
if denoise_mask_video is not None:
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
_, c_a, _, mel_bins = audio_latents.shape
audio_latents = audio_patchifier.patchify(audio_latents)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
if audio_latents is not None:
_, c_a, _, mel_bins = audio_latents.shape
audio_latents = audio_patchifier.patchify(audio_latents)
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
else:
audio_timesteps = None
#TODO: support gradient checkpointing in training
vx, ax = dit(
video_latents=video_latents,
@@ -546,5 +581,5 @@ def model_fn_ltx2(
)
# unpatchify
vx = video_patchifier.unpatchify_video(vx, f, h, w)
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
return vx, ax