support ltx2 train -2

This commit is contained in:
mi804
2026-02-25 18:06:02 +08:00
parent 586ac9d8a6
commit 8e15dcd289
32 changed files with 175 additions and 39 deletions

View File

@@ -17,7 +17,7 @@ accelerate launch examples/ltx2/model_training/train.py \
--use_gradient_checkpointing \
--task "sft:data_process"
accelerate launch examples/ltx2/model_training/train.py \
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/ltx2/model_training/train.py \
--dataset_base_path ./models/train/LTX2-T2AV-full-splited-cache \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio" \

View File

@@ -22,8 +22,8 @@
accelerate launch examples/ltx2/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--height 256 \
--width 384 \
--height 512 \
--width 768 \
--num_frames 49\
--dataset_repeat 1 \
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
@@ -40,8 +40,8 @@ accelerate launch examples/ltx2/model_training/train.py \
accelerate launch examples/ltx2/model_training/train.py \
--dataset_base_path ./models/train/LTX2-T2AV-noaudio_lora-splited-cache \
--height 256 \
--width 384 \
--height 512 \
--width 768 \
--num_frames 49\
--dataset_repeat 100 \
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \

View File

@@ -1,3 +1,24 @@
# Single Stage Training not recommended for T2AV due to the large memory consumption. Please use the Splited Training instead.
# accelerate launch examples/ltx2/model_training/train.py \
# --dataset_base_path data/example_video_dataset/ltx2 \
# --dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
# --data_file_keys "video,input_audio" \
# --extra_inputs "input_audio" \
# --height 256 \
# --width 384 \
# --num_frames 25\
# --dataset_repeat 100 \
# --model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
# --learning_rate 1e-4 \
# --num_epochs 5 \
# --remove_prefix_in_ckpt "pipe.dit." \
# --output_path "./models/train/LTX2-T2AV_lora" \
# --lora_base_model "dit" \
# --lora_target_modules "to_k,to_q,to_v,to_out.0" \
# --lora_rank 32 \
# --use_gradient_checkpointing \
# --find_unused_parameters
# Splited Training
accelerate launch examples/ltx2/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
@@ -19,7 +40,6 @@ accelerate launch examples/ltx2/model_training/train.py \
--use_gradient_checkpointing \
--task "sft:data_process"
accelerate launch examples/ltx2/model_training/train.py \
--dataset_base_path ./models/train/LTX2-T2AV_lora-splited-cache \
--data_file_keys "video,input_audio" \

View File

@@ -1,19 +0,0 @@
accelerate launch examples/ltx2/model_training/train.py \
--dataset_base_path data/example_video_dataset/ltx2 \
--dataset_metadata_path data/example_video_dataset/ltx2_t2av.csv \
--data_file_keys "video,input_audio" \
--extra_inputs "input_audio" \
--height 256 \
--width 384 \
--num_frames 25\
--dataset_repeat 100 \
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors,DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/LTX2-T2AV_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
--lora_rank 32 \
--use_gradient_checkpointing \
--find_unused_parameters

View File

@@ -0,0 +1,104 @@
from safetensors.torch import save_file
from diffsynth import hash_state_dict_keys
from diffsynth.core import load_state_dict
from diffsynth.models.model_loader import ModelPool
model_pool = ModelPool()
state_dict = load_state_dict("models/Lightricks/LTX-2/ltx-2-19b-dev.safetensors")
dit_state_dict = {}
for name in state_dict:
if name.startswith("model.diffusion_model."):
new_name = name.replace("model.diffusion_model.", "")
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
continue
dit_state_dict[name] = state_dict[name]
print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}")
save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors")
model_pool.auto_load_model(
"models/DiffSynth-Studio/LTX-2-Repackage/transformer.safetensors",
)
video_vae_encoder_state_dict = {}
for name in state_dict:
if name.startswith("vae.encoder."):
video_vae_encoder_state_dict[name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
video_vae_encoder_state_dict[name] = state_dict[name]
save_file(video_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors")
print(f"video_vae_encoder keys hash: {hash_state_dict_keys(video_vae_encoder_state_dict)}")
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/video_vae_encoder.safetensors")
video_vae_decoder_state_dict = {}
for name in state_dict:
if name.startswith("vae.decoder."):
video_vae_decoder_state_dict[name] = state_dict[name]
elif name.startswith("vae.per_channel_statistics."):
video_vae_decoder_state_dict[name] = state_dict[name]
save_file(video_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors")
print(f"video_vae_decoder keys hash: {hash_state_dict_keys(video_vae_decoder_state_dict)}")
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/video_vae_decoder.safetensors")
audio_vae_decoder_state_dict = {}
for name in state_dict:
if name.startswith("audio_vae.decoder."):
audio_vae_decoder_state_dict[name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
audio_vae_decoder_state_dict[name] = state_dict[name]
save_file(audio_vae_decoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors")
print(f"audio_vae_decoder keys hash: {hash_state_dict_keys(audio_vae_decoder_state_dict)}")
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_decoder.safetensors")
audio_vae_encoder_state_dict = {}
for name in state_dict:
if name.startswith("audio_vae.encoder."):
audio_vae_encoder_state_dict[name] = state_dict[name]
elif name.startswith("audio_vae.per_channel_statistics."):
audio_vae_encoder_state_dict[name] = state_dict[name]
save_file(audio_vae_encoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors")
print(f"audio_vae_encoder keys hash: {hash_state_dict_keys(audio_vae_encoder_state_dict)}")
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vae_encoder.safetensors")
audio_vocoder_state_dict = {}
for name in state_dict:
if name.startswith("vocoder."):
audio_vocoder_state_dict[name] = state_dict[name]
save_file(audio_vocoder_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors")
print(f"audio_vocoder keys hash: {hash_state_dict_keys(audio_vocoder_state_dict)}")
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/audio_vocoder.safetensors")
text_encoder_post_modules_state_dict = {}
for name in state_dict:
if name.startswith("text_embedding_projection."):
text_encoder_post_modules_state_dict[name] = state_dict[name]
elif name.startswith("model.diffusion_model.video_embeddings_connector."):
text_encoder_post_modules_state_dict[name] = state_dict[name]
elif name.startswith("model.diffusion_model.audio_embeddings_connector."):
text_encoder_post_modules_state_dict[name] = state_dict[name]
save_file(text_encoder_post_modules_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors")
print(f"text_encoder_post_modules keys hash: {hash_state_dict_keys(text_encoder_post_modules_state_dict)}")
model_pool.auto_load_model("models/DiffSynth-Studio/LTX-2-Repackage/text_encoder_post_modules.safetensors")
state_dict = load_state_dict("models/Lightricks/LTX-2/ltx-2-19b-distilled.safetensors")
dit_state_dict = {}
for name in state_dict:
if name.startswith("model.diffusion_model."):
new_name = name.replace("model.diffusion_model.", "")
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
continue
dit_state_dict[name] = state_dict[name]
print(f"dit_state_dict keys hash: {hash_state_dict_keys(dit_state_dict)}")
save_file(dit_state_dict, "models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors")
model_pool.auto_load_model(
"models/DiffSynth-Studio/LTX-2-Repackage/transformer_distilled.safetensors",
)

View File

@@ -96,7 +96,7 @@ def ltx2_parser():
parser = add_general_config(parser)
parser = add_video_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--frame_rate", type=float, default=24, help="frame rate of the training videos. If not specified, it will be determined by the dataset.")
parser.add_argument("--frame_rate", type=float, default=24, help="frame rate of the training videos.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser

View File

@@ -27,7 +27,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
)
prompt = "A beautiful sunset over the ocean."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512, 768, 121
height, width, num_frames = 512, 768, 49
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,

View File

@@ -28,8 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV_lora/epoch-4.safetensors")
prompt = "A beautiful sunset over the ocean."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512, 768, 121
height, width, num_frames = 256, 384, 25
height, width, num_frames = 512, 768, 49
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,

View File

@@ -28,8 +28,7 @@ pipe = LTX2AudioVideoPipeline.from_pretrained(
pipe.load_lora(pipe.dit, "models/train/LTX2-T2AV-noaudio_lora/epoch-4.safetensors")
prompt = "A beautiful sunset over the ocean."
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
height, width, num_frames = 512, 768, 121
height, width, num_frames = 256, 384, 25
height, width, num_frames = 512, 768, 49
video, audio = pipe(
prompt=prompt,
negative_prompt=negative_prompt,