Ltx2.3 i2v training and sample frames with fixed fps (#1339)

* add 2.3 i2v training scripts

* add frame resampling by fixed fps

* LoadVideo: add compatibility for not fix_frame_rate

* refactor frame resampler

* minor fix
This commit is contained in:
Hong Zhang
2026-03-09 20:32:02 +08:00
committed by GitHub
parent 7bc5611fb8
commit b272253956
12 changed files with 256 additions and 26 deletions

View File

@@ -60,7 +60,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
for extra_input in extra_inputs:
inputs_shared[extra_input] = data[extra_input]
if extra_input == "input_image":
inputs_shared["input_images"] = [data["video"][0]]
inputs_shared["input_images_indexes"] = [0]
inputs_shared["input_images_strength"] = 1.0
else:
inputs_shared[extra_input] = data[extra_input]
return inputs_shared
def get_pipeline_inputs(self, data):
@@ -123,6 +128,8 @@ if __name__ == "__main__":
num_frames=args.num_frames,
time_division_factor=8,
time_division_remainder=1,
frame_rate=args.frame_rate,
fix_frame_rate=True,
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
@@ -131,7 +138,7 @@ if __name__ == "__main__":
data_file_keys=args.data_file_keys.split(","),
main_data_operator=video_processor,
special_operator_map={
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)),
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(num_frames=args.num_frames, time_division_factor=8, time_division_remainder=1, frame_rate=args.frame_rate),
"in_context_videos": RouteByType(operator_map=[
(str, video_processor),
(list, SequencialProcess(video_processor)),