mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 13:58:15 +00:00
* ltx2.3 ic lora inference&train * temp commit * fix first frame train-inference consistency * minor fix
73 lines
3.9 KiB
Python
73 lines
3.9 KiB
Python
# Example script: LTX-2.3 audio-video generation with the
# "IC-LoRA-Motion-Track-Control" adapter, loading weights with a
# CPU-offloaded float8 layout.
#
# What the script does, top to bottom:
#   1. Builds an LTX2AudioVideoPipeline from a Gemma-3 checkpoint, the
#      LTX-2.3 22B dev checkpoint and the x2 spatial upscaler checkpoint.
#   2. Loads the motion-track-control IC-LoRA into `pipe.dit`.
#   3. Downloads the example video dataset, then reads one frame of
#      `video1.mp4` and all frames of `spatial_tracker_v2.mp4`.
#   4. Calls the pipeline in two-stage mode and writes the resulting video
#      and audio to `ltx2.3_ic_lora.mp4`.
import torch
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
from modelscope import dataset_snapshot_download
from diffsynth.utils.data import VideoData


# Dtype/device settings shared by every weight-carrying ModelConfig below
# (they are splatted in as keyword arguments): float8_e5m2 for the
# offload / onload / preparing entries, bfloat16 on CUDA for computation.
vram_config = dict(
    offload_dtype=torch.float8_e5m2,
    offload_device="cpu",
    onload_dtype=torch.float8_e5m2,
    onload_device="cpu",
    preparing_dtype=torch.float8_e5m2,
    preparing_device="cuda",
    computation_dtype=torch.bfloat16,
    computation_device="cuda",
)

# Checkpoints handed to the pipeline as `model_configs`.
model_configs = [
    ModelConfig(
        model_id="google/gemma-3-12b-it-qat-q4_0-unquantized",
        origin_file_pattern="model-*.safetensors",
        **vram_config,
    ),
    ModelConfig(
        model_id="Lightricks/LTX-2.3",
        origin_file_pattern="ltx-2.3-22b-dev.safetensors",
        **vram_config,
    ),
    ModelConfig(
        model_id="Lightricks/LTX-2.3",
        origin_file_pattern="ltx-2.3-spatial-upscaler-x2-1.0.safetensors",
        **vram_config,
    ),
]
tokenizer_config = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized")
stage2_lora_config = ModelConfig(
    model_id="Lightricks/LTX-2.3",
    origin_file_pattern="ltx-2.3-22b-distilled-lora-384.safetensors",
)

# Second element of torch.cuda.mem_get_info() converted from bytes to GiB,
# minus 0.5 GiB of headroom.
vram_limit = torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5

pipe = LTX2AudioVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=model_configs,
    tokenizer_config=tokenizer_config,
    stage2_lora_config=stage2_lora_config,
    vram_limit=vram_limit,
)

# Attach the IC-LoRA weights to the pipeline's `dit` module.
ic_lora_config = ModelConfig(
    model_id="Lightricks/LTX-2.3-22b-IC-LoRA-Motion-Track-Control",
    origin_file_pattern="ltx-2.3-22b-ic-lora-motion-track-control-ref0.5.safetensors",
)
pipe.load_lora(pipe.dit, ic_lora_config)

# Fetch only the `ltx2/` files of the example dataset into a local folder.
dataset_snapshot_download(
    "DiffSynth-Studio/example_video_dataset",
    allow_file_pattern="ltx2/*",
    local_dir="data/example_video_dataset",
)

prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing"
negative_prompt = (
    "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
    "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
    "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
    "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
    "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
    "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
    "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
    "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
    "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
    "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
    "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
)

# Output resolution, clip length and frame rate requested from the pipeline.
height = 512 * 2
width = 768 * 2
num_frames = 121
ref_scale_factor = 2
frame_rate = 24

# Size at which both reference inputs are loaded.
# NOTE(review): the extra `// 2` presumably accounts for a half-resolution
# first stage of the two-stage pipeline -- confirm against the pipeline code.
ref_height = height // ref_scale_factor // 2
ref_width = width // ref_scale_factor // 2

# Frame 0 of the example clip is passed as the image for frame index 0.
input_image = VideoData(
    "data/example_video_dataset/ltx2/video1.mp4",
    height=ref_height,
    width=ref_width,
)[0]
# All frames of the second clip are passed as the in-context video.
input_video = VideoData(
    "data/example_video_dataset/ltx2/spatial_tracker_v2.mp4",
    height=ref_height,
    width=ref_width,
).raw_data()

video, audio = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    seed=43,
    height=height,
    width=width,
    num_frames=num_frames,
    frame_rate=frame_rate,
    in_context_videos=[input_video],
    in_context_downsample_factor=ref_scale_factor,
    input_images=[input_image],
    input_images_indexes=[0],
    tiled=True,
    use_two_stage_pipeline=True,
    # NOTE(review): keyword is spelled "state_two" here; kept verbatim --
    # confirm it matches the pipeline's parameter name.
    clear_lora_before_state_two=True,
)

# The audio sample rate is read from the pipeline's vocoder.
audio_sample_rate = pipe.audio_vocoder.output_sampling_rate
write_video_audio_ltx2(
    video=video,
    audio=audio,
    output_path="ltx2.3_ic_lora.mp4",
    fps=frame_rate,
    audio_sample_rate=audio_sample_rate,
)