mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
support ltx-2 training
This commit is contained in:
@@ -607,6 +607,12 @@ ltx2_series = [
|
|||||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"model_hash": "c567aaa37d5ed7454c73aa6024458661",
|
||||||
|
"model_name": "ltx2_dit",
|
||||||
|
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
@@ -614,6 +620,12 @@ ltx2_series = [
|
|||||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"model_hash": "7f7e904a53260ec0351b05f32153754b",
|
||||||
|
"model_name": "ltx2_video_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
@@ -621,6 +633,12 @@ ltx2_series = [
|
|||||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"model_hash": "dc6029ca2825147872b45e35a2dc3a97",
|
||||||
|
"model_name": "ltx2_video_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
@@ -628,6 +646,12 @@ ltx2_series = [
|
|||||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"model_hash": "7d7823dde8f1ea0b50fb07ac329dd4cb",
|
||||||
|
"model_name": "ltx2_audio_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
@@ -635,16 +659,34 @@ ltx2_series = [
|
|||||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||||
},
|
},
|
||||||
# { # not used currently
|
{
|
||||||
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
"model_hash": "f471360f6b24bef702ab73133d9f8bb9",
|
||||||
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
"model_name": "ltx2_audio_vocoder",
|
||||||
# "model_name": "ltx2_audio_vae_encoder",
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||||
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||||
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
},
|
||||||
# },
|
|
||||||
{
|
{
|
||||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_audio_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_hash": "29338f3b95e7e312a3460a482e4f4554",
|
||||||
|
"model_name": "ltx2_audio_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_text_encoder_post_modules",
|
||||||
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_hash": "981629689c8be92a712ab3c5eb4fc3f6",
|
||||||
"model_name": "ltx2_text_encoder_post_modules",
|
"model_name": "ltx2_text_encoder_post_modules",
|
||||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||||
|
|||||||
@@ -28,6 +28,36 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def FlowMatchSFTAudioVideoLoss(pipe: BasePipeline, **inputs):
|
||||||
|
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||||
|
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||||
|
|
||||||
|
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||||
|
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
|
||||||
|
# video
|
||||||
|
noise = torch.randn_like(inputs["input_latents"])
|
||||||
|
inputs["video_latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
|
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
|
# audio
|
||||||
|
if inputs.get("audio_input_latents") is not None:
|
||||||
|
audio_noise = torch.randn_like(inputs["audio_input_latents"])
|
||||||
|
inputs["audio_latents"] = pipe.scheduler.add_noise(inputs["audio_input_latents"], audio_noise, timestep)
|
||||||
|
training_target_audio = pipe.scheduler.training_target(inputs["audio_input_latents"], audio_noise, timestep)
|
||||||
|
|
||||||
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
|
noise_pred, noise_pred_audio = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
|
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||||
|
if inputs.get("audio_input_latents") is not None:
|
||||||
|
loss_audio = torch.nn.functional.mse_loss(noise_pred_audio.float(), training_target_audio.float())
|
||||||
|
loss_audio = loss_audio * pipe.scheduler.training_weight(timestep)
|
||||||
|
loss = loss + loss_audio
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
def DirectDistillLoss(pipe: BasePipeline, **inputs):
|
||||||
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
pipe.scheduler.set_timesteps(inputs["num_inference_steps"])
|
||||||
pipe.scheduler.training = True
|
pipe.scheduler.training = True
|
||||||
|
|||||||
@@ -5,8 +5,65 @@ import einops
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import torchaudio
|
||||||
from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
|
from .ltx2_common import VideoLatentShape, AudioLatentShape, Patchifier, NormType, build_normalization_layer
|
||||||
|
|
||||||
|
|
||||||
|
class AudioProcessor(nn.Module):
|
||||||
|
"""Converts audio waveforms to log-mel spectrograms with optional resampling."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
mel_bins: int = 64,
|
||||||
|
mel_hop_length: int = 160,
|
||||||
|
n_fft: int = 1024,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.mel_transform = torchaudio.transforms.MelSpectrogram(
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
n_fft=n_fft,
|
||||||
|
win_length=n_fft,
|
||||||
|
hop_length=mel_hop_length,
|
||||||
|
f_min=0.0,
|
||||||
|
f_max=sample_rate / 2.0,
|
||||||
|
n_mels=mel_bins,
|
||||||
|
window_fn=torch.hann_window,
|
||||||
|
center=True,
|
||||||
|
pad_mode="reflect",
|
||||||
|
power=1.0,
|
||||||
|
mel_scale="slaney",
|
||||||
|
norm="slaney",
|
||||||
|
)
|
||||||
|
|
||||||
|
def resample_waveform(
|
||||||
|
self,
|
||||||
|
waveform: torch.Tensor,
|
||||||
|
source_rate: int,
|
||||||
|
target_rate: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Resample waveform to target sample rate if needed."""
|
||||||
|
if source_rate == target_rate:
|
||||||
|
return waveform
|
||||||
|
resampled = torchaudio.functional.resample(waveform, source_rate, target_rate)
|
||||||
|
return resampled.to(device=waveform.device, dtype=waveform.dtype)
|
||||||
|
|
||||||
|
def waveform_to_mel(
|
||||||
|
self,
|
||||||
|
waveform: torch.Tensor,
|
||||||
|
waveform_sample_rate: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Convert waveform to log-mel spectrogram [batch, channels, time, n_mels]."""
|
||||||
|
waveform = self.resample_waveform(waveform, waveform_sample_rate, self.sample_rate)
|
||||||
|
|
||||||
|
mel = self.mel_transform(waveform)
|
||||||
|
mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||||
|
|
||||||
|
mel = mel.to(device=waveform.device, dtype=waveform.dtype)
|
||||||
|
return mel.permute(0, 1, 3, 2).contiguous()
|
||||||
|
|
||||||
|
|
||||||
class AudioPatchifier(Patchifier):
|
class AudioPatchifier(Patchifier):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1446,6 +1446,6 @@ class LTXModel(torch.nn.Module):
|
|||||||
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
cross_pe_max_pos = max(self.positional_embedding_max_pos[0], self.audio_positional_embedding_max_pos[0])
|
||||||
self._init_preprocessors(cross_pe_max_pos)
|
self._init_preprocessors(cross_pe_max_pos)
|
||||||
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
video = Modality(video_latents, video_timesteps, video_positions, video_context)
|
||||||
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context)
|
audio = Modality(audio_latents, audio_timesteps, audio_positions, audio_context) if audio_latents is not None else None
|
||||||
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
|
vx, ax = self._forward(video=video, audio=audio, perturbations=None)
|
||||||
return vx, ax
|
return vx, ax
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
|||||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||||
from ..models.ltx2_dit import LTXModel
|
from ..models.ltx2_dit import LTXModel
|
||||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier, AudioProcessor
|
||||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||||
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||||
@@ -50,6 +50,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
|
|
||||||
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
||||||
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
||||||
|
self.audio_processor: AudioProcessor = AudioProcessor()
|
||||||
|
|
||||||
self.in_iteration_models = ("dit",)
|
self.in_iteration_models = ("dit",)
|
||||||
self.units = [
|
self.units = [
|
||||||
@@ -57,6 +58,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
LTX2AudioVideoUnit_ShapeChecker(),
|
LTX2AudioVideoUnit_ShapeChecker(),
|
||||||
LTX2AudioVideoUnit_PromptEmbedder(),
|
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||||
|
LTX2AudioVideoUnit_InputAudioEmbedder(),
|
||||||
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||||
]
|
]
|
||||||
@@ -95,7 +97,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
stage2_lora_config.download_if_necessary()
|
stage2_lora_config.download_if_necessary()
|
||||||
pipe.stage2_lora_path = stage2_lora_config.path
|
pipe.stage2_lora_path = stage2_lora_config.path
|
||||||
# Optional, currently not used
|
# Optional, currently not used
|
||||||
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||||
|
|
||||||
# VRAM Management
|
# VRAM Management
|
||||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
@@ -157,7 +159,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
num_frames=121,
|
num_frames=121,
|
||||||
# Classifier-free guidance
|
# Classifier-free guidance
|
||||||
cfg_scale: Optional[float] = 3.0,
|
cfg_scale: Optional[float] = 3.0,
|
||||||
cfg_merge: Optional[bool] = False,
|
|
||||||
# Scheduler
|
# Scheduler
|
||||||
num_inference_steps: Optional[int] = 40,
|
num_inference_steps: Optional[int] = 40,
|
||||||
# VAE tiling
|
# VAE tiling
|
||||||
@@ -186,7 +187,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
|
|||||||
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||||
"seed": seed, "rand_device": rand_device,
|
"seed": seed, "rand_device": rand_device,
|
||||||
"height": height, "width": width, "num_frames": num_frames,
|
"height": height, "width": width, "num_frames": num_frames,
|
||||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
"cfg_scale": cfg_scale,
|
||||||
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||||
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
||||||
@@ -422,7 +423,7 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
|||||||
|
|
||||||
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
||||||
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
||||||
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
|
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=128)
|
||||||
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||||
|
|
||||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
||||||
@@ -455,17 +456,48 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
|||||||
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
|
input_params=("input_video", "video_noise", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels"),
|
||||||
output_params=("video_latents", "audio_latents"),
|
output_params=("video_latents", "input_latents"),
|
||||||
onload_model_names=("video_vae_encoder")
|
onload_model_names=("video_vae_encoder")
|
||||||
)
|
)
|
||||||
|
|
||||||
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
|
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
||||||
if input_video is None:
|
if input_video is None:
|
||||||
return {"video_latents": video_noise, "audio_latents": audio_noise}
|
return {"video_latents": video_noise}
|
||||||
else:
|
else:
|
||||||
# TODO: implement video-to-video
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
raise NotImplementedError("Video-to-video not implemented yet.")
|
input_video = pipe.preprocess_video(input_video)
|
||||||
|
input_latents = pipe.video_vae_encoder.encode(input_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"video_latents": input_latents, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
# TODO: implement video-to-video
|
||||||
|
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_audio", "audio_noise"),
|
||||||
|
output_params=("audio_latents", "audio_input_latents", "audio_positions", "audio_latent_shape"),
|
||||||
|
onload_model_names=("audio_vae_encoder",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, input_audio, audio_noise):
|
||||||
|
if input_audio is None:
|
||||||
|
return {"audio_latents": audio_noise}
|
||||||
|
else:
|
||||||
|
input_audio, sample_rate = input_audio
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
input_audio = pipe.audio_processor.waveform_to_mel(input_audio.unsqueeze(0), waveform_sample_rate=sample_rate).to(dtype=pipe.torch_dtype)
|
||||||
|
audio_input_latents = pipe.audio_vae_encoder(input_audio)
|
||||||
|
audio_noise = torch.randn_like(audio_input_latents)
|
||||||
|
audio_latent_shape = AudioLatentShape.from_torch_shape(audio_input_latents.shape)
|
||||||
|
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape}
|
||||||
|
else:
|
||||||
|
# TODO: implement video-to-video
|
||||||
|
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||||
|
|
||||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -530,9 +562,12 @@ def model_fn_ltx2(
|
|||||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||||
if denoise_mask_video is not None:
|
if denoise_mask_video is not None:
|
||||||
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
||||||
_, c_a, _, mel_bins = audio_latents.shape
|
if audio_latents is not None:
|
||||||
audio_latents = audio_patchifier.patchify(audio_latents)
|
_, c_a, _, mel_bins = audio_latents.shape
|
||||||
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||||
|
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||||
|
else:
|
||||||
|
audio_timesteps = None
|
||||||
#TODO: support gradient checkpointing in training
|
#TODO: support gradient checkpointing in training
|
||||||
vx, ax = dit(
|
vx, ax = dit(
|
||||||
video_latents=video_latents,
|
video_latents=video_latents,
|
||||||
@@ -546,5 +581,5 @@ def model_fn_ltx2(
|
|||||||
)
|
)
|
||||||
# unpatchify
|
# unpatchify
|
||||||
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
||||||
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
|
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None
|
||||||
return vx, ax
|
return vx, ax
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
)
|
)
|
||||||
@@ -40,3 +44,12 @@ write_video_audio_ltx2(
|
|||||||
fps=24,
|
fps=24,
|
||||||
audio_sample_rate=24000,
|
audio_sample_rate=24000,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_encoder.safetensors", **vram_config),
|
||||||
|
# ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer_distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
|||||||
@@ -17,7 +17,11 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
|
|||||||
device="cuda",
|
device="cuda",
|
||||||
model_configs=[
|
model_configs=[
|
||||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
|||||||
35
examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh
Normal file
35
examples/ltx2/model_training/full/LTX-2-T2AV-splited.sh
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
# Splited Training
|
||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
|
||||||
|
--data_file_keys "video,input_audio" \
|
||||||
|
--extra_inputs "input_audio" \
|
||||||
|
--height 512 \
|
||||||
|
--width 768 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV-full-splited-cache" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path ./models/train/LTX2-T2AV-full-splited-cache \
|
||||||
|
--data_file_keys "video,input_audio" \
|
||||||
|
--extra_inputs "input_audio" \
|
||||||
|
--height 512 \
|
||||||
|
--width 768 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV-full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--task "sft:train"
|
||||||
56
examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh
Normal file
56
examples/ltx2/model_training/lora/LTX-2-T2AV-noaudio.sh
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# single stage training
|
||||||
|
# accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
# --dataset_base_path data/example_video_dataset/ltx2 \
|
||||||
|
# --dataset_metadata_path data/example_video_dataset/ltx2_t2v.csv \
|
||||||
|
# --height 256 \
|
||||||
|
# --width 384 \
|
||||||
|
# --num_frames 25\
|
||||||
|
# --dataset_repeat 100 \
|
||||||
|
# --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||||
|
# --learning_rate 1e-4 \
|
||||||
|
# --num_epochs 5 \
|
||||||
|
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
# --output_path "./models/train/LTX2-T2AV-noaudio_lora" \
|
||||||
|
# --lora_base_model "dit" \
|
||||||
|
# --lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||||
|
# --lora_rank 32 \
|
||||||
|
# --use_gradient_checkpointing \
|
||||||
|
# --find_unused_parameters
|
||||||
|
|
||||||
|
|
||||||
|
# Splited Training
|
||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
|
||||||
|
--height 256 \
|
||||||
|
--width 384 \
|
||||||
|
--num_frames 49\
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV-noaudio_lora-splited-cache" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
|
||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \
|
||||||
|
--height 256 \
|
||||||
|
--width 384 \
|
||||||
|
--num_frames 49\
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV-noaudio_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--task "sft:train"
|
||||||
40
examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh
Normal file
40
examples/ltx2/model_training/lora/LTX-2-T2AV-splited.sh
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Splited Training
|
||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
|
||||||
|
--data_file_keys "video,input_audio" \
|
||||||
|
--extra_inputs "input_audio" \
|
||||||
|
--height 512 \
|
||||||
|
--width 768 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 1 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV_lora-splited-cache" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--task "sft:data_process"
|
||||||
|
|
||||||
|
|
||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path ./models/train/LTX2-T2AV_lora-splited-cache \
|
||||||
|
--data_file_keys "video,input_audio" \
|
||||||
|
--extra_inputs "input_audio" \
|
||||||
|
--height 512 \
|
||||||
|
--width 768 \
|
||||||
|
--num_frames 49 \
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--task "sft:train"
|
||||||
19
examples/ltx2/model_training/lora/LTX-2-T2AV.sh
Normal file
19
examples/ltx2/model_training/lora/LTX-2-T2AV.sh
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
accelerate launch examples/ltx2/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||||
|
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
|
||||||
|
--data_file_keys "video,input_audio" \
|
||||||
|
--extra_inputs "input_audio" \
|
||||||
|
--height 256 \
|
||||||
|
--width 384 \
|
||||||
|
--num_frames 25\
|
||||||
|
--dataset_repeat 100 \
|
||||||
|
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/LTX2-T2AV_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--find_unused_parameters
|
||||||
162
examples/ltx2/model_training/train.py
Normal file
162
examples/ltx2/model_training/train.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
import torch, os, argparse, accelerate, warnings
|
||||||
|
from diffsynth.core import UnifiedDataset
|
||||||
|
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath
|
||||||
|
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.diffusion import *
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_paths=None, model_id_with_origin_paths=None,
|
||||||
|
tokenizer_path=None,
|
||||||
|
trainable_models=None,
|
||||||
|
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||||
|
preset_lora_path=None, preset_lora_model=None,
|
||||||
|
use_gradient_checkpointing=True,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
extra_inputs=None,
|
||||||
|
fp8_models=None,
|
||||||
|
offload_models=None,
|
||||||
|
device="cpu",
|
||||||
|
task="sft",
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Warning
|
||||||
|
if not use_gradient_checkpointing:
|
||||||
|
warnings.warn("Gradient checkpointing is detected as disabled. To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.")
|
||||||
|
use_gradient_checkpointing = True
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||||
|
tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||||
|
self.pipe = LTX2AudioVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||||
|
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||||
|
# Training mode
|
||||||
|
self.switch_pipe_to_training_mode(
|
||||||
|
self.pipe, trainable_models,
|
||||||
|
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||||
|
preset_lora_path, preset_lora_model,
|
||||||
|
task=task,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Store other configs
|
||||||
|
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||||
|
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||||
|
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||||
|
self.fp8_models = fp8_models
|
||||||
|
self.task = task
|
||||||
|
self.task_to_loss = {
|
||||||
|
"sft:data_process": lambda pipe, *args: args,
|
||||||
|
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTAudioVideoLoss(pipe, **inputs_shared, **inputs_posi),
|
||||||
|
}
|
||||||
|
|
||||||
|
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
|
||||||
|
for extra_input in extra_inputs:
|
||||||
|
inputs_shared[extra_input] = data[extra_input]
|
||||||
|
return inputs_shared
|
||||||
|
|
||||||
|
def get_pipeline_inputs(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {}
|
||||||
|
inputs_shared = {
|
||||||
|
# Assume you are using this pipeline for inference,
|
||||||
|
# please fill in the input parameters.
|
||||||
|
"input_video": data["video"],
|
||||||
|
"height": data["video"][0].size[1],
|
||||||
|
"width": data["video"][0].size[0],
|
||||||
|
"num_frames": len(data["video"]),
|
||||||
|
# Please do not modify the following parameters
|
||||||
|
# unless you clearly know what this will cause.
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"tiled": False,
|
||||||
|
"rand_device": self.pipe.device,
|
||||||
|
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||||
|
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||||
|
"video_patchifier": self.pipe.video_patchifier,
|
||||||
|
"audio_patchifier": self.pipe.audio_patchifier,
|
||||||
|
}
|
||||||
|
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
def forward(self, data, inputs=None):
|
||||||
|
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||||
|
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||||
|
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def ltx2_parser():
|
||||||
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
|
parser = add_general_config(parser)
|
||||||
|
parser = add_video_size_config(parser)
|
||||||
|
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||||
|
parser.add_argument("--frame_rate", type=float, default=24, help="frame rate of the training videos. If not specified, it will be determined by the dataset.")
|
||||||
|
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = ltx2_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
accelerator = accelerate.Accelerator(
|
||||||
|
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||||
|
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||||
|
)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
metadata_path=args.dataset_metadata_path,
|
||||||
|
repeat=args.dataset_repeat,
|
||||||
|
data_file_keys=args.data_file_keys.split(","),
|
||||||
|
main_data_operator=UnifiedDataset.default_video_operator(
|
||||||
|
base_path=args.dataset_base_path,
|
||||||
|
max_pixels=args.max_pixels,
|
||||||
|
height=args.height,
|
||||||
|
width=args.width,
|
||||||
|
height_division_factor=16,
|
||||||
|
width_division_factor=16,
|
||||||
|
num_frames=args.num_frames,
|
||||||
|
time_division_factor=4,
|
||||||
|
time_division_remainder=1,
|
||||||
|
),
|
||||||
|
special_operator_map={
|
||||||
|
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
model = LTX2TrainingModule(
|
||||||
|
model_paths=args.model_paths,
|
||||||
|
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||||
|
tokenizer_path=args.tokenizer_path,
|
||||||
|
trainable_models=args.trainable_models,
|
||||||
|
lora_base_model=args.lora_base_model,
|
||||||
|
lora_target_modules=args.lora_target_modules,
|
||||||
|
lora_rank=args.lora_rank,
|
||||||
|
lora_checkpoint=args.lora_checkpoint,
|
||||||
|
preset_lora_path=args.preset_lora_path,
|
||||||
|
preset_lora_model=args.preset_lora_model,
|
||||||
|
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||||
|
extra_inputs=args.extra_inputs,
|
||||||
|
fp8_models=args.fp8_models,
|
||||||
|
offload_models=args.offload_models,
|
||||||
|
task=args.task,
|
||||||
|
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||||
|
)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
args.output_path,
|
||||||
|
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||||
|
)
|
||||||
|
launcher_map = {
|
||||||
|
"sft:data_process": launch_data_process_task,
|
||||||
|
"direct_distill:data_process": launch_data_process_task,
|
||||||
|
"sft": launch_training_task,
|
||||||
|
"sft:train": launch_training_task,
|
||||||
|
"direct_distill": launch_training_task,
|
||||||
|
"direct_distill:train": launch_training_task,
|
||||||
|
}
|
||||||
|
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||||
47
examples/ltx2/model_training/validate_full/LTX-2-T2AV.py
Normal file
47
examples/ltx2/model_training/validate_full/LTX-2-T2AV.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(path="./models/train/LTX2-T2AV-full/epoch-4.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
prompt = "A beautiful sunset over the ocean."
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
cfg_scale=4.0
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
49
examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py
Normal file
49
examples/ltx2/model_training/validate_lora/LTX-2-T2AV.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV_lora/epoch-4.safetensors")
|
||||||
|
prompt = "A beautiful sunset over the ocean."
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
height, width, num_frames = 256, 384, 25
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
cfg_scale=4.0
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors")
|
||||||
|
prompt = "A beautiful sunset over the ocean."
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
height, width, num_frames = 256, 384, 25
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
cfg_scale=4.0
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
104
examples/ltx2/scripts/split_model_statedicts.py
Normal file
104
examples/ltx2/scripts/split_model_statedicts.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
from safetensors.torch import save_file
|
||||||
|
from diffsynth import hash_state_dict_keys
|
||||||
|
from diffsynth.core import load_state_dict
|
||||||
|
from diffsynth.models.model_loader import ModelPool
|
||||||
|
|
||||||
|
model_pool = ModelPool()
|
||||||
|
state_dict = load_state_dict("models/Lightricks/LTX-2/ltx-2-19b-dev.safetensors")
|
||||||
|
|
||||||
|
dit_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("model.diffusion_model."):
|
||||||
|
new_name = name.replace("model.diffusion_model.", "")
|
||||||
|
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
|
||||||
|
continue
|
||||||
|
dit_state_dict[name] = state_dict[name]
|
||||||
|
|
||||||
|
print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}")
|
||||||
|
save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors")
|
||||||
|
model_pool.auto_load_model(
|
||||||
|
"models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
video_vae_encoder_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("vae.encoder."):
|
||||||
|
video_vae_encoder_state_dict[name] = state_dict[name]
|
||||||
|
elif name.startswith("vae.per_channel_statistics."):
|
||||||
|
video_vae_encoder_state_dict[name] = state_dict[name]
|
||||||
|
|
||||||
|
save_file(video_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors")
|
||||||
|
print(f"video_vae_encoder keys hash: {hash_state_dict_keys(video_vae_encoder_state_dict)}")
|
||||||
|
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
video_vae_decoder_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("vae.decoder."):
|
||||||
|
video_vae_decoder_state_dict[name] = state_dict[name]
|
||||||
|
elif name.startswith("vae.per_channel_statistics."):
|
||||||
|
video_vae_decoder_state_dict[name] = state_dict[name]
|
||||||
|
save_file(video_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors")
|
||||||
|
print(f"video_vae_decoder keys hash: {hash_state_dict_keys(video_vae_decoder_state_dict)}")
|
||||||
|
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
audio_vae_decoder_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("audio_vae.decoder."):
|
||||||
|
audio_vae_decoder_state_dict[name] = state_dict[name]
|
||||||
|
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||||
|
audio_vae_decoder_state_dict[name] = state_dict[name]
|
||||||
|
save_file(audio_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors")
|
||||||
|
print(f"audio_vae_decoder keys hash: {hash_state_dict_keys(audio_vae_decoder_state_dict)}")
|
||||||
|
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
audio_vae_encoder_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("audio_vae.encoder."):
|
||||||
|
audio_vae_encoder_state_dict[name] = state_dict[name]
|
||||||
|
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||||
|
audio_vae_encoder_state_dict[name] = state_dict[name]
|
||||||
|
save_file(audio_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors")
|
||||||
|
print(f"audio_vae_encoder keys hash: {hash_state_dict_keys(audio_vae_encoder_state_dict)}")
|
||||||
|
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
audio_vocoder_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("vocoder."):
|
||||||
|
audio_vocoder_state_dict[name] = state_dict[name]
|
||||||
|
save_file(audio_vocoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors")
|
||||||
|
print(f"audio_vocoder keys hash: {hash_state_dict_keys(audio_vocoder_state_dict)}")
|
||||||
|
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
text_encoder_post_modules_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("text_embedding_projection."):
|
||||||
|
text_encoder_post_modules_state_dict[name] = state_dict[name]
|
||||||
|
elif name.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||||
|
text_encoder_post_modules_state_dict[name] = state_dict[name]
|
||||||
|
elif name.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||||
|
text_encoder_post_modules_state_dict[name] = state_dict[name]
|
||||||
|
save_file(text_encoder_post_modules_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors")
|
||||||
|
print(f"text_encoder_post_modules keys hash: {hash_state_dict_keys(text_encoder_post_modules_state_dict)}")
|
||||||
|
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
state_dict = load_state_dict("models/Lightricks/LTX-2/ltx-2-19b-distilled.safetensors")
|
||||||
|
dit_state_dict = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("model.diffusion_model."):
|
||||||
|
new_name = name.replace("model.diffusion_model.", "")
|
||||||
|
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
|
||||||
|
continue
|
||||||
|
dit_state_dict[name] = state_dict[name]
|
||||||
|
|
||||||
|
print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}")
|
||||||
|
save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors")
|
||||||
|
model_pool.auto_load_model(
|
||||||
|
"models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors",
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user