From 5996c2b0689cdd7cc634fe6ae04b604ac7468784 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 27 Feb 2026 16:48:16 +0800 Subject: [PATCH] support inference --- diffsynth/pipelines/ltx2_audio_video.py | 94 +++++++++++++++++-- .../model_inference/LTX-2-I2AV-TwoStage.py | 1 - .../LTX-2-T2AV-IC-LoRA-Detailer.py | 77 +++++++++++++++ .../LTX-2-T2AV-IC-LoRA-Union-Control.py | 77 +++++++++++++++ 4 files changed, 238 insertions(+), 11 deletions(-) create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py create mode 100644 examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index fc0b969..c662016 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -61,6 +61,7 @@ class LTX2AudioVideoPipeline(BasePipeline): LTX2AudioVideoUnit_InputAudioEmbedder(), LTX2AudioVideoUnit_InputVideoEmbedder(), LTX2AudioVideoUnit_InputImagesEmbedder(), + LTX2AudioVideoUnit_InContextVideoEmbedder(), ] self.model_fn = model_fn_ltx2 @@ -105,6 +106,8 @@ class LTX2AudioVideoPipeline(BasePipeline): def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm): if inputs_shared["use_two_stage_pipeline"]: + if inputs_shared.get("clear_lora_before_state_two", False): + self.clear_lora() latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) self.load_models_to_device('upsampler',) latent = self.upsampler(latent) @@ -112,11 +115,17 @@ class LTX2AudioVideoPipeline(BasePipeline): self.scheduler.set_timesteps(special_case="stage2") inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) denoise_mask_video = 1.0 + # input image if inputs_shared.get("input_images", None) is not None: latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents( latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"], inputs_shared["input_images_strength"], latent.clone()) inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video}) + # remove in-context video control in stage 2 + inputs_shared.pop("in_context_video_latents") + inputs_shared.pop("in_context_video_positions") + + # initialize latents for stage 2 inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[ "video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + ( @@ -145,11 +154,14 @@ class LTX2AudioVideoPipeline(BasePipeline): # Prompt prompt: str, negative_prompt: Optional[str] = "", - # Image-to-video denoising_strength: float = 1.0, + # Image-to-video input_images: Optional[list[Image.Image]] = None, input_images_indexes: Optional[list[int]] = None, input_images_strength: Optional[float] = 1.0, + # In-Context Video Control + in_context_videos: Optional[list[list[Image.Image]]] = None, + in_context_downsample_factor: Optional[int] = 2, # Randomness seed: Optional[int] = None, rand_device: Optional[str] = "cpu", @@ -157,6 +169,7 @@ class LTX2AudioVideoPipeline(BasePipeline): height: Optional[int] = 512, width: Optional[int] = 768, num_frames=121, + frame_rate=24, # Classifier-free guidance cfg_scale: Optional[float] = 3.0, # Scheduler @@ -169,6 +182,7 @@ class LTX2AudioVideoPipeline(BasePipeline): tile_overlap_in_frames: Optional[int] = 24, # Special Pipelines use_two_stage_pipeline: Optional[bool] = False, + clear_lora_before_state_two: Optional[bool] = False, use_distilled_pipeline: Optional[bool] = False, # progress_bar progress_bar_cmd=tqdm, @@ -185,12 +199,13 @@ class LTX2AudioVideoPipeline(BasePipeline): } inputs_shared = { "input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength, + "in_context_videos": in_context_videos, "in_context_downsample_factor": in_context_downsample_factor, "seed": seed, "rand_device": rand_device, - "height": height, "width": width, "num_frames": num_frames, + "height": height, "width": width, "num_frames": num_frames, "frame_rate": frame_rate, "cfg_scale": cfg_scale, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, - "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, "clear_lora_before_state_two": clear_lora_before_state_two, "video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier, } for unit in self.units: @@ -417,8 +432,8 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"), - output_params=("video_noise", "audio_noise",), + input_params=("height", "width", "num_frames", "seed", "rand_device", "frame_rate", "use_two_stage_pipeline"), + output_params=("video_noise", "audio_noise", "video_positions", "audio_positions", "video_latent_shape", "audio_latent_shape") ) def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): @@ -471,7 +486,6 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): if pipe.scheduler.training: return {"video_latents": input_latents, "input_latents": input_latents} else: - # TODO: implement video-to-video raise NotImplementedError("Video-to-video not implemented yet.") class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): @@ -495,14 +509,13 @@ class LTX2AudioVideoUnit_InputAudioEmbedder(PipelineUnit): if pipe.scheduler.training: return {"audio_latents": audio_input_latents, "audio_input_latents": audio_input_latents, "audio_positions": audio_positions, "audio_latent_shape": audio_latent_shape} else: - # TODO: implement video-to-video - raise NotImplementedError("Video-to-video not implemented yet.") + raise NotImplementedError("Audio-to-video not supported.") class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): def __init__(self): super().__init__( input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), - output_params=("video_latents"), + output_params=("video_latents", "denoise_mask_video", "input_latents_video", "stage2_input_latents"), onload_model_names=("video_vae_encoder") ) @@ -537,6 +550,54 @@ class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit): return output_dicts +class LTX2AudioVideoUnit_InContextVideoEmbedder(PipelineUnit): + def __init__(self): + super().__init__( + input_params=("in_context_videos", "height", "width", "num_frames", "frame_rate", "in_context_downsample_factor", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"), + output_params=("in_context_video_latents", "in_context_video_positions"), + onload_model_names=("video_vae_encoder") + ) + + def check_in_context_video(self, pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline=True): + if in_context_video is None or len(in_context_video) == 0: + raise ValueError("In-context video is None or empty.") + in_context_video = in_context_video[:num_frames] + expected_height = height // in_context_downsample_factor // 2 if use_two_stage_pipeline else height // in_context_downsample_factor + expected_width = width // in_context_downsample_factor // 2 if use_two_stage_pipeline else width // in_context_downsample_factor + current_h, current_w, current_f = in_context_video[0].size[1], in_context_video[0].size[0], len(in_context_video) + h, w, f = pipe.check_resize_height_width(expected_height, expected_width, current_f) + if current_h != h or current_w != w: + in_context_video = [img.resize((w, h)) for img in in_context_video] + if current_f != f: + # pad black frames at the end + in_context_video = in_context_video + [Image.new("RGB", (w, h), (0, 0, 0))] * (f - current_f) + return in_context_video + + def process(self, pipe: LTX2AudioVideoPipeline, in_context_videos, height, width, num_frames, frame_rate, in_context_downsample_factor, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=True): + if in_context_videos is None or len(in_context_videos) == 0: + return {} + else: + pipe.load_models_to_device(self.onload_model_names) + latents, positions = [], [] + for in_context_video in in_context_videos: + in_context_video = self.check_in_context_video(pipe, in_context_video, height, width, num_frames, in_context_downsample_factor, use_two_stage_pipeline) + in_context_video = pipe.preprocess_video(in_context_video) + in_context_latents = pipe.video_vae_encoder.encode(in_context_video, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(dtype=pipe.torch_dtype, device=pipe.device) + + latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=VideoLatentShape.from_torch_shape(in_context_latents.shape), device=pipe.device) + video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float() + video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate + video_positions[:, 1, ...] *= in_context_downsample_factor # height axis + video_positions[:, 2, ...] *= in_context_downsample_factor # width axis + video_positions = video_positions.to(pipe.torch_dtype) + + latents.append(in_context_latents) + positions.append(video_positions) + latents = torch.cat(latents, dim=1) + positions = torch.cat(positions, dim=1) + return {"in_context_video_latents": latents, "in_context_video_positions": positions} + + def model_fn_ltx2( dit: LTXModel, video_latents=None, @@ -549,6 +610,8 @@ def model_fn_ltx2( audio_patchifier=None, timestep=None, denoise_mask_video=None, + in_context_video_latents=None, + in_context_video_positions=None, use_gradient_checkpointing=False, use_gradient_checkpointing_offload=False, **kwargs, @@ -558,16 +621,25 @@ def model_fn_ltx2( # patchify b, c_v, f, h, w = video_latents.shape video_latents = video_patchifier.patchify(video_latents) + seq_len_video = video_latents.shape[1] video_timesteps = timestep.repeat(1, video_latents.shape[1], 1) if denoise_mask_video is not None: video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps + + if in_context_video_latents is not None: + in_context_video_latents = video_patchifier.patchify(in_context_video_latents) + in_context_video_timesteps = timestep.repeat(1, in_context_video_latents.shape[1], 1) * 0. + video_latents = torch.cat([video_latents, in_context_video_latents], dim=1) + video_positions = torch.cat([video_positions, in_context_video_positions], dim=2) + video_timesteps = torch.cat([video_timesteps, in_context_video_timesteps], dim=1) + if audio_latents is not None: _, c_a, _, mel_bins = audio_latents.shape audio_latents = audio_patchifier.patchify(audio_latents) audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1) else: audio_timesteps = None - #TODO: support gradient checkpointing in training + vx, ax = dit( video_latents=video_latents, video_positions=video_positions, @@ -580,6 +652,8 @@ def model_fn_ltx2( use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload, ) + + vx = vx[:, :seq_len_video, ...] # unpatchify vx = video_patchifier.unpatchify_video(vx, f, h, w) ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins) if ax is not None else None diff --git a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py index bd86b34..0465803 100644 --- a/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py +++ b/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py @@ -46,7 +46,6 @@ negative_prompt = ( "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." ) height, width, num_frames = 512 * 2, 768 * 2, 121 -height, width, num_frames = 512 * 2, 768 * 2, 121 dataset_snapshot_download( dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py new file mode 100644 index 0000000..687e216 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Detailer.py @@ -0,0 +1,77 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Detailer", origin_file_pattern="ltx-2-19b-ic-lora-detailer.safetensors")) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") + +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +ref_scale_factor = 1 +frame_rate = 24 +# the frame rate of the video should better be the same with the reference video +# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor +input_video = VideoData("data/example_video_dataset/ltx2/video1.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, + tiled=True, + use_two_stage_pipeline=True, + clear_lora_before_state_two=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_iclora.mp4', + fps=frame_rate, + audio_sample_rate=24000, +) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py new file mode 100644 index 0000000..3021306 --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-IC-LoRA-Union-Control.py @@ -0,0 +1,77 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2 +from diffsynth.utils.data import VideoData +from modelscope import dataset_snapshot_download + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="transformer.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="text_encoder_post_modules.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vae_decoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="audio_vocoder.safetensors", **vram_config), + ModelConfig(model_id="DiffSynth-Studio/LTX-2-Repackage", origin_file_pattern="video_vae_encoder.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) +pipe.load_lora(pipe.dit, ModelConfig(model_id="Lightricks/LTX-2-19b-IC-LoRA-Union-Control", origin_file_pattern="ltx-2-19b-ic-lora-union-control-ref0.5.safetensors")) +dataset_snapshot_download("DiffSynth-Studio/example_video_dataset", allow_file_pattern="ltx2/*", local_dir="data/example_video_dataset") + +prompt = "[VISUAL]:Two cute orange cats, wearing boxing gloves, stand on a boxing ring and fight each other. [SOUNDS]:the sound of two cats boxing" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +ref_scale_factor = 2 +frame_rate = 24 +# the frame rate of the video should better be the same with the reference video +# the spatial resolution of the first frame should be the resolution of stage 1 video generation divided by ref_scale_factor +input_video = VideoData("data/example_video_dataset/ltx2/depth_video.mp4", height=height // ref_scale_factor // 2, width=width // ref_scale_factor // 2) +input_video = input_video.raw_data() +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=43, + height=height, + width=width, + num_frames=num_frames, + frame_rate=frame_rate, + in_context_videos=[input_video], + in_context_downsample_factor=ref_scale_factor, + tiled=True, + use_two_stage_pipeline=True, + clear_lora_before_state_two=True, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_iclora.mp4', + fps=frame_rate, + audio_sample_rate=24000, +)