diff --git a/diffsynth/configs/vram_management_module_maps.py b/diffsynth/configs/vram_management_module_maps.py index a1813fb..908361e 100644 --- a/diffsynth/configs/vram_management_module_maps.py +++ b/diffsynth/configs/vram_management_module_maps.py @@ -210,4 +210,37 @@ VRAM_MANAGEMENT_MODULE_MAPS = { "torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule", "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", }, + "diffsynth.models.ltx2_dit.LTXModel": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": { + "torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": { + "torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_audio_vae.LTX2Vocoder": { + "torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule", + }, + "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": { + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", + "transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule", + }, } diff --git a/diffsynth/diffusion/flow_match.py b/diffsynth/diffusion/flow_match.py index aed8a35..3f1ee8c 100644 --- a/diffsynth/diffusion/flow_match.py +++ b/diffsynth/diffusion/flow_match.py @@ -124,25 +124,30 @@ class FlowMatchScheduler(): return sigmas, timesteps @staticmethod - def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, stretch=True, terminal=0.1): - dynamic_shift_len = dynamic_shift_len or 4096 - sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( - image_seq_len=dynamic_shift_len, - base_seq_len=1024, - max_seq_len=4096, - base_shift=0.95, - max_shift=2.05, - ) + def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None): num_train_timesteps = 1000 - sigma_min = 0.0 - sigma_max = 1.0 - sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength - sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] - sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) - # Shift terminal - one_minus_z = 1.0 - sigmas - scale_factor = one_minus_z[-1] / (1 - terminal) - sigmas = 1.0 - (one_minus_z / scale_factor) + if special_case == "stage2": + sigmas = torch.Tensor([0.909375, 0.725, 0.421875]) + elif special_case == "ditilled_stage1": + sigmas = torch.Tensor([0.95, 0.8, 0.5, 0.2]) + else: + dynamic_shift_len = dynamic_shift_len or 4096 + sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image( + image_seq_len=dynamic_shift_len, + base_seq_len=1024, + max_seq_len=4096, + base_shift=0.95, + max_shift=2.05, + ) + sigma_min = 0.0 + sigma_max = 1.0 + sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength + sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1] + sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1)) + # Shift terminal + one_minus_z = 1.0 - sigmas + scale_factor = one_minus_z[-1] / (1 - terminal) + sigmas = 1.0 - (one_minus_z / scale_factor) timesteps = sigmas * num_train_timesteps return sigmas, timesteps diff --git a/diffsynth/models/ltx2_text_encoder.py b/diffsynth/models/ltx2_text_encoder.py index 535d5e5..7fb94a7 100644 --- a/diffsynth/models/ltx2_text_encoder.py +++ b/diffsynth/models/ltx2_text_encoder.py @@ -363,4 +363,4 @@ class LTX2TextEncoderPostModules(torch.nn.Module): super().__init__() self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear() self.embeddings_connector = Embeddings1DConnector() - self.audio_embeddings_connector = Embeddings1DConnector() \ No newline at end of file + self.audio_embeddings_connector = Embeddings1DConnector() diff --git a/diffsynth/pipelines/ltx2_audio_video.py b/diffsynth/pipelines/ltx2_audio_video.py index 1081288..c921eac 100644 --- a/diffsynth/pipelines/ltx2_audio_video.py +++ b/diffsynth/pipelines/ltx2_audio_video.py @@ -60,10 +60,8 @@ class LTX2AudioVideoPipeline(BasePipeline): LTX2AudioVideoUnit_NoiseInitializer(), LTX2AudioVideoUnit_InputVideoEmbedder(), ] - self.post_units = [ - LTX2AudioVideoPostUnit_UnPatchifier(), - ] self.model_fn = model_fn_ltx2 + # self.lora_loader = LTX2LoRALoader @staticmethod def from_pretrained( @@ -71,6 +69,7 @@ class LTX2AudioVideoPipeline(BasePipeline): device: Union[str, torch.device] = get_device_type(), model_configs: list[ModelConfig] = [], tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config: Optional[ModelConfig] = None, vram_limit: float = None, ): # Initialize pipeline @@ -90,14 +89,67 @@ class LTX2AudioVideoPipeline(BasePipeline): pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder") pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder") pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder") - pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") - # Optional + + # Stage 2 + if stage2_lora_config is not None: + stage2_lora_config.download_if_necessary() + pipe.stage2_lora_path = stage2_lora_config.path + pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler") + # Optional, currently not used # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder") # VRAM Management pipe.vram_management_enabled = pipe.check_vram_management_state() return pipe + def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) + noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) + if cfg_scale != 1.0: + if inputs_shared.get("positive_only_lora", None) is not None: + self.clear_lora(verbose=0) + noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) + if isinstance(noise_pred_posi, tuple): + noise_pred = tuple( + n_nega + cfg_scale * (n_posi - n_nega) + for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) + ) + else: + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + return noise_pred + + def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm): + if use_two_stage_pipeline: + latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"]) + self.load_models_to_device(self.in_iteration_models + ('upsampler',)) + latent = self.upsampler(latent) + latent = self.video_vae_encoder.per_channel_statistics.normalize(latent) + latent = self.video_patchifier.patchify(latent) + self.scheduler.set_timesteps(special_case="stage2") + inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")}) + inputs_shared["video_latents"] = self.scheduler.sigmas[0] * inputs_shared["video_noise"] + (1 - self.scheduler.sigmas[0]) * latent + inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"]) + inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"] + self.load_models_to_device(self.in_iteration_models) + self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8) + models = {name: getattr(self, name) for name in self.in_iteration_models} + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device) + noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn( + self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, + **models, timestep=timestep, progress_id=progress_id + ) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, + noise_pred=noise_pred_video, **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, + noise_pred=noise_pred_audio, **inputs_shared) + inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"]) + inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) + return inputs_shared @torch.no_grad() def __call__( @@ -126,14 +178,15 @@ class LTX2AudioVideoPipeline(BasePipeline): tile_overlap_in_pixels: Optional[int] = 128, tile_size_in_frames: Optional[int] = 128, tile_overlap_in_frames: Optional[int] = 24, - # Two-Stage Pipeline - use_two_stage: Optional[bool] = True, + # Special Pipelines + use_two_stage_pipeline: Optional[bool] = False, + use_distilled_pipeline: Optional[bool] = False, # progress_bar progress_bar_cmd=tqdm, ): # Scheduler self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength) - + # self.load_lora(self.dit, self.stage2_lora_path) # Inputs inputs_posi = { "prompt": prompt, @@ -148,14 +201,12 @@ class LTX2AudioVideoPipeline(BasePipeline): "cfg_scale": cfg_scale, "cfg_merge": cfg_merge, "tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels, "tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames, - "use_two_stage": True + "use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline, } for unit in self.units: inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) - # inputs_posi.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/text_encodings.pt")) - # inputs_nega.update(torch.load("/mnt/nas1/zhanghong/project26/extern_codes/LTX-2/negative_text_encodings.pt")) - # Denoise + # Denoise Stage 1 self.load_models_to_device(self.in_iteration_models) models = {name: getattr(self, name) for name in self.in_iteration_models} for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): @@ -164,12 +215,16 @@ class LTX2AudioVideoPipeline(BasePipeline): self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **models, timestep=timestep, progress_id=progress_id ) - inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video, **inputs_shared) - inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, noise_pred=noise_pred_audio, **inputs_shared) + inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, + noise_pred=noise_pred_video, **inputs_shared) + inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id, + noise_pred=noise_pred_audio, **inputs_shared) + inputs_shared["video_latents"] = self.video_patchifier.unpatchify(inputs_shared["video_latents"], inputs_shared["video_latent_shape"]) + inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"]) + + # Denoise Stage 2 + inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd) - # post-denoising, pre-decoding processing logic - for unit in self.post_units: - inputs_shared, _, _ = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega) # Decode self.load_models_to_device(['video_vae_decoder']) video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels, @@ -181,39 +236,21 @@ class LTX2AudioVideoPipeline(BasePipeline): return video, decoded_audio - def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others): - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) - self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0) - noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others) - if cfg_scale != 1.0: - if inputs_shared.get("positive_only_lora", None) is not None: - self.clear_lora(verbose=0) - noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others) - if isinstance(noise_pred_posi, tuple): - noise_pred = tuple( - n_nega + cfg_scale * (n_posi - n_nega) - for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega) - ) - else: - noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) - else: - noise_pred = noise_pred_posi - return noise_pred - - class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit): def __init__(self): super().__init__(take_over=True) def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega): - pass + if inputs_shared.get("use_two_stage_pipeline", False): + if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None): + raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.") + if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None): + raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.") return inputs_shared, inputs_posi, inputs_nega class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): """ - # TODO: Adjust with two stage pipeline For two-stage pipelines, the resolution must be divisible by 64. For one-stage pipelines, the resolution must be divisible by 32. """ @@ -223,8 +260,14 @@ class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit): output_params=("height", "width", "num_frames"), ) - def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames): + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False): + if use_two_stage_pipeline: + self.width_division_factor = 64 + self.height_division_factor = 64 height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames) + if use_two_stage_pipeline: + self.width_division_factor = 32 + self.height_division_factor = 32 return {"height": height, "width": width, "num_frames": num_frames} @@ -327,10 +370,12 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): return pipe.text_encoder_post_modules.feature_extractor_linear( normed_concated_encoded_text_features.to(encoded_text_features_dtype)) - def _preprocess_text(self, - pipe, - text: str, - padding_side: str = "left") -> tuple[torch.Tensor, dict[str, torch.Tensor]]: + def _preprocess_text( + self, + pipe, + text: str, + padding_side: str = "left", + ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]: """ Encode a given string into feature tensors suitable for downstream tasks. Args: @@ -339,8 +384,8 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask. """ token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"] - input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.text_encoder.device) - attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.text_encoder.device) + input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device) + attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device) outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True) projected = self._run_feature_extractor(pipe, hidden_states=outputs.hidden_states, @@ -362,11 +407,11 @@ class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit): class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device",), + input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"), output_params=("video_noise", "audio_noise",), ) - def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): + def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0): video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate) video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels) video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device) @@ -390,6 +435,15 @@ class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit): "audio_latent_shape": audio_latent_shape } + def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False): + if use_two_stage_pipeline: + stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate) + stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) + initial_dict = stage1_dict + initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()}) + return initial_dict + else: + return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate) class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): @@ -407,20 +461,6 @@ class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit): raise NotImplementedError("Video-to-video not implemented yet.") -class LTX2AudioVideoPostUnit_UnPatchifier(PipelineUnit): - - def __init__(self): - super().__init__( - input_params=("video_latent_shape", "audio_latent_shape", "video_latents", "audio_latents"), - output_params=("video_latents", "audio_latents"), - ) - - def process(self, pipe: LTX2AudioVideoPipeline, video_latent_shape, audio_latent_shape, video_latents, audio_latents): - video_latents = pipe.video_patchifier.unpatchify(video_latents, output_shape=video_latent_shape) - audio_latents = pipe.audio_patchifier.unpatchify(audio_latents, output_shape=audio_latent_shape) - return {"video_latents": video_latents, "audio_latents": audio_latents} - - def model_fn_ltx2( dit: LTXModel, video_latents=None, diff --git a/diffsynth/utils/data/media_io.py b/diffsynth/utils/data/media_io.py index a95fdbc..0450186 100644 --- a/diffsynth/utils/data/media_io.py +++ b/diffsynth/utils/data/media_io.py @@ -3,6 +3,7 @@ from fractions import Fraction import torch import av from tqdm import tqdm +from PIL import Image def _resample_audio( diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py index afbdd4b..a96efbf 100644 --- a/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py +++ b/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py @@ -31,16 +31,16 @@ height, width, num_frames = 512, 768, 121 video, audio = pipe( prompt=prompt, negative_prompt=negative_prompt, - seed=43, + seed=10, height=height, width=width, num_frames=num_frames, - tiled=False, + tiled=True, ) write_video_audio_ltx2( video=video, audio=audio, - output_path='ltx2_onestage.mp4', + output_path='ltx2_onestage_oven.mp4', fps=24, audio_sample_rate=24000, ) diff --git a/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py new file mode 100644 index 0000000..b966b0a --- /dev/null +++ b/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py @@ -0,0 +1,63 @@ +import torch +from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig +from diffsynth.utils.data.media_io import write_video_audio_ltx2 + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = LTX2AudioVideoPipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config), + ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"), + stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"), +) + +prompt = """ +INT. OVEN – DAY. Static camera from inside the oven, looking outward through the slightly fogged glass door. Warm golden light glows around freshly baked cookies. The baker’s face fills the frame, eyes wide with focus, his breath fogging the glass as he leans in. Subtle reflections move across the glass as steam rises. +Baker (whispering dramatically): “Today… I achieve perfection.” +He leans even closer, nose nearly touching the glass. +""" +negative_prompt = ( + "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, " + "grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, " + "deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, " + "wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of " + "field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent " + "lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny " + "valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, " + "mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, " + "off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward " + "pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, " + "inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts." +) +height, width, num_frames = 512 * 2, 768 * 2, 121 +video, audio = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + seed=0, + height=height, + width=width, + num_frames=num_frames, + tiled=True, + use_two_stage_pipeline=True, + num_inference_steps=40, +) +write_video_audio_ltx2( + video=video, + audio=audio, + output_path='ltx2_twostage_oven.mp4', + fps=24, + audio_sample_rate=24000, +)