From b272253956d1ff9fea7156b6b38e23dff8ef6fe6 Mon Sep 17 00:00:00 2001 From: Hong Zhang <41229682+mi804@users.noreply.github.com> Date: Mon, 9 Mar 2026 20:32:02 +0800 Subject: [PATCH] Ltx2.3 i2v training and sample frames with fixed fps (#1339) * add 2.3 i2v training scripts * add frame resampling by fixed fps * LoadVideo: add compatibility for not fix_frame_rate * refactor frame resampler * minor fix --- README.md | 4 +- README_zh.md | 4 +- diffsynth/core/data/operators.py | 67 ++++++++++++++----- diffsynth/core/data/unified_dataset.py | 2 + diffsynth/pipelines/ltx2_audio_video.py | 2 +- docs/en/Model_Details/LTX-2.md | 4 +- docs/zh/Model_Details/LTX-2.md | 4 +- .../full/LTX-2.3-I2AV-splited.sh | 35 ++++++++++ .../lora/LTX-2.3-I2AV-splited.sh | 39 +++++++++++ examples/ltx2/model_training/train.py | 11 ++- .../validate_full/LTX-2.3-I2AV.py | 54 +++++++++++++++ .../validate_lora/LTX-2.3-I2AV.py | 56 ++++++++++++++++ 12 files changed, 256 insertions(+), 26 deletions(-) create mode 100644 examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh create mode 100644 examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh create mode 100644 examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py create mode 100644 examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py diff --git a/README.md b/README.md index ed1ba58..e28ea4b 100644 --- a/README.md +++ b/README.md @@ -705,10 +705,10 @@ Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/) | Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation | |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| +|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)| diff --git a/README_zh.md b/README_zh.md index 7ac0919..f1bf7da 100644 --- a/README_zh.md +++ b/README_zh.md @@ -705,10 +705,10 @@ LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/) |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| +|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-| |[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)| diff --git a/diffsynth/core/data/operators.py b/diffsynth/core/data/operators.py index 36756d3..5e1cfa0 100644 --- a/diffsynth/core/data/operators.py +++ b/diffsynth/core/data/operators.py @@ -1,6 +1,8 @@ +import math import torch, torchvision, imageio, os import imageio.v3 as iio from PIL import Image +import torchaudio class DataProcessingPipeline: @@ -105,27 +107,59 @@ class ToList(DataProcessingOperator): return [data] -class LoadVideo(DataProcessingOperator): - def __init__(self, num_frames=81, 
time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x): +class FrameSamplerByRateMixin: + def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_rate=24, fix_frame_rate=False): self.num_frames = num_frames self.time_division_factor = time_division_factor self.time_division_remainder = time_division_remainder - # frame_processor is build in the video loader for high efficiency. - self.frame_processor = frame_processor - + self.frame_rate = frame_rate + self.fix_frame_rate = fix_frame_rate + + def get_reader(self, data: str): + return imageio.get_reader(data) + + def get_available_num_frames(self, reader): + if not self.fix_frame_rate: + return reader.count_frames() + meta_data = reader.get_meta_data() + total_original_frames = int(reader.count_frames()) + duration = meta_data["duration"] if "duration" in meta_data else total_original_frames / meta_data['fps'] + total_available_frames = math.floor(duration * self.frame_rate) + return int(total_available_frames) + def get_num_frames(self, reader): num_frames = self.num_frames - if int(reader.count_frames()) < num_frames: - num_frames = int(reader.count_frames()) + total_frames = self.get_available_num_frames(reader) + if int(total_frames) < num_frames: + num_frames = total_frames while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: num_frames -= 1 return num_frames - + + def map_single_frame_id(self, new_sequence_id: int, raw_frame_rate: float, total_raw_frames: int) -> int: + if not self.fix_frame_rate: + return new_sequence_id + target_time_in_seconds = new_sequence_id / self.frame_rate + raw_frame_index_float = target_time_in_seconds * raw_frame_rate + frame_id = int(round(raw_frame_index_float)) + frame_id = min(frame_id, total_raw_frames - 1) + return frame_id + + +class LoadVideo(DataProcessingOperator, FrameSamplerByRateMixin): + def __init__(self, num_frames=81, time_division_factor=4, 
time_division_remainder=1, frame_processor=lambda x: x, frame_rate=24, fix_frame_rate=False): + FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate) + # frame_processor is build in the video loader for high efficiency. + self.frame_processor = frame_processor + def __call__(self, data: str): - reader = imageio.get_reader(data) + reader = self.get_reader(data) + raw_frame_rate = reader.get_meta_data()['fps'] num_frames = self.get_num_frames(reader) + total_raw_frames = reader.count_frames() frames = [] for frame_id in range(num_frames): + frame_id = self.map_single_frame_id(frame_id, raw_frame_rate, total_raw_frames) frame = reader.get_data(frame_id) frame = Image.fromarray(frame) frame = self.frame_processor(frame) @@ -149,7 +183,7 @@ class LoadGIF(DataProcessingOperator): self.time_division_remainder = time_division_remainder # frame_processor is build in the video loader for high efficiency. self.frame_processor = frame_processor - + def get_num_frames(self, path): num_frames = self.num_frames images = iio.imread(path, mode="RGB") @@ -220,14 +254,17 @@ class LoadAudio(DataProcessingOperator): return input_audio -class LoadAudioWithTorchaudio(DataProcessingOperator): - def __init__(self, duration=5): - self.duration = duration +class LoadAudioWithTorchaudio(DataProcessingOperator, FrameSamplerByRateMixin): + + def __init__(self, num_frames=121, time_division_factor=8, time_division_remainder=1, frame_rate=24, fix_frame_rate=True): + FrameSamplerByRateMixin.__init__(self, num_frames, time_division_factor, time_division_remainder, frame_rate, fix_frame_rate) def __call__(self, data: str): - import torchaudio + reader = self.get_reader(data) + num_frames = self.get_num_frames(reader) + duration = num_frames / self.frame_rate waveform, sample_rate = torchaudio.load(data) - target_samples = int(self.duration * sample_rate) + target_samples = int(duration * sample_rate) current_samples = 
waveform.shape[-1] if current_samples > target_samples: waveform = waveform[..., :target_samples] diff --git a/diffsynth/core/data/unified_dataset.py b/diffsynth/core/data/unified_dataset.py index 46fecd7..9dd9c51 100644 --- a/diffsynth/core/data/unified_dataset.py +++ b/diffsynth/core/data/unified_dataset.py @@ -42,6 +42,7 @@ class UnifiedDataset(torch.utils.data.Dataset): max_pixels=1920*1080, height=None, width=None, height_division_factor=16, width_division_factor=16, num_frames=81, time_division_factor=4, time_division_remainder=1, + frame_rate=24, fix_frame_rate=False, ): return RouteByType(operator_map=[ (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[ @@ -53,6 +54,7 @@ class UnifiedDataset(torch.utils.data.Dataset): (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo( num_frames, time_division_factor, time_division_remainder, frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor), + frame_rate=frame_rate, fix_frame_rate=fix_frame_rate, )), ])), ]) diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 56bc923..7da54ec 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -436,7 +436,7 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): frame_conditions.update({"input_latents_video": input_latents_video, "denoise_mask_video": denoise_mask_video}) return frame_conditions - def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False): + def process(self, pipe: LTX2AudioVideoPipeline, input_images, video_latents, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels, input_images_indexes=[0], input_images_strength=1.0, use_two_stage_pipeline=False): if input_images is None or len(input_images) == 0: return 
{} else: diff --git a/docs/en/Model_Details/LTX-2.md b/docs/en/Model_Details/LTX-2.md index 5c5ef97..db9937a 100644 --- a/docs/en/Model_Details/LTX-2.md +++ b/docs/en/Model_Details/LTX-2.md @@ -111,10 +111,10 @@ write_video_audio_ltx2( ## Model Overview |Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| +|[Lightricks/LTX-2.3: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-| 
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)| diff --git a/docs/zh/Model_Details/LTX-2.md b/docs/zh/Model_Details/LTX-2.md index fd3de32..d1c5496 100644 --- a/docs/zh/Model_Details/LTX-2.md +++ b/docs/zh/Model_Details/LTX-2.md @@ -111,10 +111,10 @@ write_video_audio_ltx2( ## 模型总览 |模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| +|[Lightricks/LTX-2.3: 
OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)|`input_images`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-I2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-I2AV-DistilledPipeline.py)|-|-|-|-| -|[Lightricks/LTX-2.3: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|-|-|-|-| +|[Lightricks/LTX-2.3: 
OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-OneStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/full/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_full/LTX-2.3-T2AV.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV.py)| |[Lightricks/LTX-2.3: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-TwoStage.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-TwoStage.py)|-|-|-|-| |[Lightricks/LTX-2.3: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2.3)||[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-DistilledPipeline.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-DistilledPipeline.py)|-|-|-|-| 
|[Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control](https://www.modelscope.cn/models/Lightricks/LTX-2.3-22b-IC-LoRA-Union-Control)|`in_context_videos`,`in_context_downsample_factor`|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_inference_low_vram/LTX-2.3-T2AV-IC-LoRA-Union-Control.py)|-|-|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/lora/LTX-2.3-T2AV-IC-LoRA-splited.sh)|[code](https://github.com/modelscope/DiffSynth-Studio/blob/main/examples/ltx2/model_training/validate_lora/LTX-2.3-T2AV-IC-LoRA.py)| diff --git a/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh b/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh new file mode 100644 index 0000000..7c83a01 --- /dev/null +++ b/examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh @@ -0,0 +1,35 @@ +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LTX2.3-I2AV-full-splited-cache" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --task "sft:data_process" + +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2.3-I2AV-full-splited-cache \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2.3-I2AV-full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh b/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh new file mode 100644 index 0000000..939eff8 --- /dev/null +++ b/examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh @@ -0,0 +1,39 @@ +# Splited Training +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path data/example_video_dataset/ltx2 \ + --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 1 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/LTX2.3-I2AV_lora-splited-cache" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:data_process" + +accelerate launch examples/ltx2/model_training/train.py \ + --dataset_base_path ./models/train/LTX2.3-I2AV_lora-splited-cache \ + --data_file_keys "video,input_audio" \ + --extra_inputs "input_audio,input_image" \ + --height 512 \ + --width 768 \ + --num_frames 121 \ + --dataset_repeat 100 \ + --model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/LTX2.3-I2AV_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_k,to_q,to_v,to_out.0" \ + --lora_rank 32 \ + --use_gradient_checkpointing \ + --task "sft:train" diff --git a/examples/ltx2/model_training/train.py b/examples/ltx2/model_training/train.py index 980d903..46d43e8 100644 --- a/examples/ltx2/model_training/train.py +++ b/examples/ltx2/model_training/train.py @@ -60,7 +60,12 @@ class LTX2TrainingModule(DiffusionTrainingModule): def parse_extra_inputs(self, data, extra_inputs, inputs_shared): for extra_input in extra_inputs: - inputs_shared[extra_input] = data[extra_input] + if extra_input == "input_image": + inputs_shared["input_images"] = [data["video"][0]] + inputs_shared["input_images_indexes"] = [0] + inputs_shared["input_images_strength"] = 1.0 + else: + inputs_shared[extra_input] = data[extra_input] return inputs_shared def get_pipeline_inputs(self, data): @@ -123,6 +128,8 @@ if __name__ == "__main__": num_frames=args.num_frames, time_division_factor=8, time_division_remainder=1, + frame_rate=args.frame_rate, + fix_frame_rate=True, ) dataset = UnifiedDataset( base_path=args.dataset_base_path, @@ -131,7 +138,7 @@ if __name__ == "__main__": data_file_keys=args.data_file_keys.split(","), main_data_operator=video_processor, 
special_operator_map={ - "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)), + "input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(num_frames=args.num_frames, time_division_factor=8, time_division_remainder=1, frame_rate=args.frame_rate), "in_context_videos": RouteByType(operator_map=[ (str, video_processor), (list, SequencialProcess(video_processor)), diff --git a/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py b/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py new file mode 100644 index 0000000..f375ee6 --- /dev/null +++ b/examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py @@ -0,0 +1,54 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", 
**vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(path="./models/train/LTX2.3-I2AV-full/epoch-4.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) + +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
+height, width, num_frames = 512, 768, 121 +image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0] +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=False, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, + num_inference_steps=40, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2.3_onestage_i2av_first.mp4', + fps=24, + audio_sample_rate=pipe.audio_vocoder.output_sampling_rate, +) diff --git a/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py b/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py new file mode 100644 index 0000000..dc40930 --- /dev/null +++ b/examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py @@ -0,0 +1,56 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", 
**vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), +) +pipe.load_lora(pipe.dit, "models/train/LTX2.3-I2AV_lora/epoch-4.safetensors") + +prompt = "A beautiful sunset over the ocean." +negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." 
+height, width, num_frames = 512, 768, 121 +image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0] +# first frame +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + tiled=False, + input_images=[image], + input_images_indexes=[0], + input_images_strength=1.0, + num_inference_steps=40, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2.3_onestage_i2av_first.mp4', + fps=24, + audio_sample_rate=pipe.audio_vocoder.output_sampling_rate, +)