mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
ltx iclora train
This commit is contained in:
@@ -94,20 +94,23 @@ class BasePipeline(torch.nn.Module):
|
||||
return self
|
||||
|
||||
|
||||
def check_resize_height_width(self, height, width, num_frames=None):
|
||||
def check_resize_height_width(self, height, width, num_frames=None, verbose=1):
|
||||
# Shape check
|
||||
if height % self.height_division_factor != 0:
|
||||
height = (height + self.height_division_factor - 1) // self.height_division_factor * self.height_division_factor
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if verbose > 0:
|
||||
print(f"height % {self.height_division_factor} != 0. We round it up to {height}.")
|
||||
if width % self.width_division_factor != 0:
|
||||
width = (width + self.width_division_factor - 1) // self.width_division_factor * self.width_division_factor
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if verbose > 0:
|
||||
print(f"width % {self.width_division_factor} != 0. We round it up to {width}.")
|
||||
if num_frames is None:
|
||||
return height, width
|
||||
else:
|
||||
if num_frames % self.time_division_factor != self.time_division_remainder:
|
||||
num_frames = (num_frames + self.time_division_factor - 1) // self.time_division_factor * self.time_division_factor + self.time_division_remainder
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
if verbose > 0:
|
||||
print(f"num_frames % {self.time_division_factor} != {self.time_division_remainder}. We round it up to {num_frames}.")
|
||||
return height, width, num_frames
|
||||
|
||||
|
||||
|
||||
@@ -565,7 +565,7 @@ class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit):
|
||||
expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor
|
||||
expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor
|
||||
current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video)
|
||||
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f)
|
||||
h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f, verbose=0)
|
||||
if current_h != h or current_w != w:
|
||||
in_context_video = [img.resize((w, h)) for img in in_context_video]
|
||||
if current_f != f:
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from diffsynth.utils.data import VideoData
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Detailer", origin_file_pattern="ltx-2-19b-ic-lora-detailer.safetensors"))
|
||||
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset")
|
||||
|
||||
prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
ref_scale_factor = 1
|
||||
frame_rate = 24
|
||||
# the frame rate of the video should better be the same with the reference video
|
||||
# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor
|
||||
input_video = VideoData("data/example_video_dataset/ltx2/video1.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)
|
||||
input_video = input_video.raw_data()
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
frame_rate=frame_rate,
|
||||
in_context_videos=[input_video],
|
||||
in_context_downsample_factor=ref_scale_factor,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
clear_lora_before_state_two=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage_iclora.mp4',
|
||||
fps=frame_rate,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from diffsynth.utils.data import VideoData
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Union-Control", origin_file_pattern="ltx-2-19b-ic-lora-union-control-ref0.5.safetensors"))
|
||||
dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset")
|
||||
|
||||
prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
ref_scale_factor = 2
|
||||
frame_rate = 24
|
||||
# the frame rate of the video should better be the same with the reference video
|
||||
# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor
|
||||
input_video = VideoData("data/example_video_dataset/ltx2/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)
|
||||
input_video = input_video.raw_data()
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
frame_rate=frame_rate,
|
||||
in_context_videos=[input_video],
|
||||
in_context_downsample_factor=ref_scale_factor,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
clear_lora_before_state_two=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage_iclora.mp4',
|
||||
fps=frame_rate,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,39 @@
|
||||
# Splited Training
|
||||
accelerate launch examples/ltx2/model_training/train.py \
|
||||
--dataset_base_path data/example_video_dataset/ltx2 \
|
||||
--dataset_metadata_path data/example_video_dataset/ltx2_t2av_iclora.json \
|
||||
--data_file_keys "video,input_audio,in_context_videos" \
|
||||
--extra_inputs "input_audio,in_context_videos,in_context_downsample_factor,frame_rate" \
|
||||
--height 512 \
|
||||
--width 768 \
|
||||
--num_frames 81 \
|
||||
--dataset_repeat 1 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:text_encoder_post_modules.safetensors,DiffSynth-Studio/LTX-2-Repackage:video_vae_encoder.safetensors,DiffSynth-Studio/LTX-2-Repackage:audio_vae_encoder.safetensors,google/gemma-3-12b-it-qat-q4_0-unquantized:model-*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LTX2-T2AV-IC-LoRA-splited-cache" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--task "sft:data_process"
|
||||
|
||||
accelerate launch examples/ltx2/model_training/train.py \
|
||||
--dataset_base_path ./models/train/LTX2-T2AV-IC-LoRA-splited-cache \
|
||||
--data_file_keys "video,input_audio,in_context_videos" \
|
||||
--extra_inputs "input_audio,in_context_videos,in_context_downsample_factor,frame_rate" \
|
||||
--height 512 \
|
||||
--width 768 \
|
||||
--num_frames 81 \
|
||||
--dataset_repeat 100 \
|
||||
--model_id_with_origin_paths "DiffSynth-Studio/LTX-2-Repackage:transformer.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/LTX2-T2AV-IC-LoRA" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_k,to_q,to_v,to_out.0" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--task "sft:train"
|
||||
@@ -1,7 +1,6 @@
|
||||
import torch, os, argparse, accelerate, warnings
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.core.data.operators import LoadAudioWithTorchaudio, ToAbsolutePath, RouteByType, SequencialProcess
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -69,6 +68,7 @@ class LTX2TrainingModule(DiffusionTrainingModule):
|
||||
"height": data["video"][0].size[1],
|
||||
"width": data["video"][0].size[0],
|
||||
"num_frames": len(data["video"]),
|
||||
"frame_rate": data.get("frame_rate", 24),
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"cfg_scale": 1,
|
||||
@@ -108,12 +108,7 @@ if __name__ == "__main__":
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_video_operator(
|
||||
video_processor = UnifiedDataset.default_video_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
@@ -123,9 +118,19 @@ if __name__ == "__main__":
|
||||
num_frames=args.num_frames,
|
||||
time_division_factor=8,
|
||||
time_division_remainder=1,
|
||||
),
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=video_processor,
|
||||
special_operator_map={
|
||||
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudioWithTorchaudio(duration=float(args.num_frames) / float(args.frame_rate)),
|
||||
"in_context_videos": RouteByType(operator_map=[
|
||||
(str, video_processor),
|
||||
(list, SequencialProcess(video_processor)),
|
||||
]),
|
||||
}
|
||||
)
|
||||
model = LTX2TrainingModule(
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from diffsynth.utils.data import VideoData
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config),
|
||||
ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/LTX2-T2AV-IC-LoRA/epoch-4.safetensors")
|
||||
prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 81
|
||||
ref_scale_factor = 2
|
||||
frame_rate = 24
|
||||
input_video = VideoData("data/examples/wan/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2)
|
||||
input_video = input_video.raw_data()
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
frame_rate=frame_rate,
|
||||
tiled=True,
|
||||
in_context_videos=[input_video],
|
||||
in_context_downsample_factor=ref_scale_factor,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage_ic.mp4',
|
||||
fps=frame_rate,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
Reference in New Issue
Block a user