mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Ltx2.3 i2v training and sample frames with fixed fps (#1339)
* add 2.3 i2v training scripts * add frame resampling by fixed fps * LoadVideo: add compatibility for not fix_frame_rate * refactor frame resampler * minor fix
This commit is contained in:
35
examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh
Normal file
35
examples/ltx2/model_training/full/LTX-2.3-I2AV-splited.sh
Normal file
@@ -0,0 +1,35 @@
|
||||
# Splited Training
|
||||
accelerate launch examples/ltx2/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
|
||||
--data_file_keys "video,input_audio" \
|
||||
--extra_inputs "input_audio,input_image" \
|
||||
--height 512 \
|
||||
--width 768 \
|
||||
--num_frames 121 \
|
||||
--dataset_repeat 1 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LTX2.3-I2AV-full-splited-cache" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--task "sft:data_process"
|
||||
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \
|
||||
--dataset_base_path ./models/train/LTX2.3-I2AV-full-splited-cache \
|
||||
--data_file_keys "video,input_audio" \
|
||||
--extra_inputs "input_audio,input_image" \
|
||||
--height 512 \
|
||||
--width 768 \
|
||||
--num_frames 121 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LTX2.3-I2AV-full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--task "sft:train"
|
||||
39
examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh
Normal file
39
examples/ltx2/model_training/lora/LTX-2.3-I2AV-splited.sh
Normal file
@@ -0,0 +1,39 @@
|
||||
# Splited Training
|
||||
accelerate launch examples/ltx2/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
|
||||
--data_file_keys "video,input_audio" \
|
||||
--extra_inputs "input_audio,input_image" \
|
||||
--height 512 \
|
||||
--width 768 \
|
||||
--num_frames 121 \
|
||||
--dataset_repeat 1 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2.3-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LTX2.3-I2AV_lora-splited-cache" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--task "sft:data_process"
|
||||
|
||||
accelerate launch examples/ltx2/model_training/train.py \
|
||||
--dataset_base_path ./models/train/LTX2.3-I2AV_lora-splited-cache \
|
||||
--data_file_keys "video,input_audio" \
|
||||
--extra_inputs "input_audio,input_image" \
|
||||
--height 512 \
|
||||
--width 768 \
|
||||
--num_frames 121 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2.3-Repackage:transformer.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LTX2.3-I2AV_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--task "sft:train"
|
||||
@@ -60,7 +60,12 @@ class LTX2TrainingModule(DiffusionTrainingModule):
|
||||
|
||||
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
|
||||
for extra_input in extra_inputs:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
if extra_input == "input_image":
|
||||
inputs_shared["input_images"] = [data["video"][0]]
|
||||
inputs_shared["input_images_indexes"] = [0]
|
||||
inputs_shared["input_images_strength"] = 1.0
|
||||
else:
|
||||
inputs_shared[extra_input] = data[extra_input]
|
||||
return inputs_shared
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
@@ -123,6 +128,8 @@ if __name__ == "__main__":
|
||||
num_frames=args.num_frames,
|
||||
time_division_factor=8,
|
||||
time_division_remainder=1,
|
||||
frame_rate=args.frame_rate,
|
||||
fix_frame_rate=True,
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
@@ -131,7 +138,7 @@ if __name__ == "__main__":
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=video_processor,
|
||||
special_operator_map={
|
||||
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)),
|
||||
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(num_frames=args.num_frames, time_division_factor=8, time_division_remainder=1, frame_rate=args.frame_rate),
|
||||
"in_context_videos": RouteByType(operator_map=[
|
||||
(str, video_processor),
|
||||
(list, SequencialProcess(video_processor)),
|
||||
|
||||
54
examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py
Normal file
54
examples/ltx2/model_training/validate_full/LTX-2.3-I2AV.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from diffsynth.utils.data import VideoData
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(path="./models/train/LTX2.3-I2AV-full/epoch-4.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
|
||||
prompt = "A beautiful sunset over the ocean."
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=False,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_onestage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
56
examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py
Normal file
56
examples/ltx2/model_training/validate_lora/LTX-2.3-I2AV.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from diffsynth.utils.data import VideoData
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2.3-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "models/train/LTX2.3-I2AV_lora/epoch-4.safetensors")
|
||||
|
||||
prompt = "A beautiful sunset over the ocean."
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
image = VideoData("data/example_video_dataset/ltx2/video.mp4", height=height, width=width)[0]
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=False,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2.3_onestage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=pipe.audio_vocoder.output_sampling_rate,
|
||||
)
|
||||
Reference in New Issue
Block a user