mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
89
README.md
89
README.md
@@ -532,6 +532,95 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Examples</summary>
|
||||
|
||||
Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
|
||||
|
||||
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
</details>
|
||||
|
||||
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
|
||||
|
||||
<details>
|
||||
|
||||
89
README_zh.md
89
README_zh.md
@@ -532,6 +532,95 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>示例代码</summary>
|
||||
|
||||
LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
</details>
|
||||
|
||||
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
|
||||
|
||||
<details>
|
||||
|
||||
@@ -599,4 +599,68 @@ z_image_series = [
|
||||
},
|
||||
]
|
||||
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
ltx2_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_dit",
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_video_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_video_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vocoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
# { # not used currently
|
||||
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
# "model_name": "ltx2_audio_vae_encoder",
|
||||
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
# },
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "33917f31c4a79196171154cca39f165e",
|
||||
"model_name": "ltx2_text_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
|
||||
"model_name": "ltx2_latent_upsampler",
|
||||
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||
},
|
||||
]
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
|
||||
|
||||
@@ -210,4 +210,37 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.ltx2_dit.LTXModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
|
||||
if k_pattern != required_in_pattern:
|
||||
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
||||
if v_pattern != required_in_pattern:
|
||||
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
|
||||
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
|
||||
return q, k, v
|
||||
|
||||
|
||||
|
||||
@@ -318,7 +318,14 @@ class BasePipeline(torch.nn.Module):
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
if isinstance(noise_pred_posi, tuple):
|
||||
# Separately handling different output types of latents, eg. video and audio latents.
|
||||
noise_pred = tuple(
|
||||
n_nega + cfg_scale * (n_posi - n_nega)
|
||||
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
return noise_pred
|
||||
|
||||
@@ -4,13 +4,14 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
@@ -144,7 +145,35 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||
timesteps[timestep_id] = timestep
|
||||
return sigmas, timesteps
|
||||
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||
num_train_timesteps = 1000
|
||||
if special_case == "stage2":
|
||||
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
|
||||
elif special_case == "ditilled_stage1":
|
||||
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
|
||||
else:
|
||||
dynamic_shift_len = dynamic_shift_len or 4096
|
||||
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
|
||||
image_seq_len=dynamic_shift_len,
|
||||
base_seq_len=1024,
|
||||
max_seq_len=4096,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
)
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
|
||||
# Shift terminal
|
||||
one_minus_z = 1.0 - sigmas
|
||||
scale_factor = one_minus_z[-1] / (1 - terminal)
|
||||
sigmas = 1.0 - (one_minus_z / scale_factor)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
|
||||
1351
diffsynth/models/ltx2_audio_vae.py
Normal file
1351
diffsynth/models/ltx2_audio_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
371
diffsynth/models/ltx2_common.py
Normal file
371
diffsynth/models/ltx2_common.py
Normal file
@@ -0,0 +1,371 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Protocol, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class VideoPixelShape(NamedTuple):
|
||||
"""
|
||||
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
frames: int
|
||||
height: int
|
||||
width: int
|
||||
fps: float
|
||||
|
||||
|
||||
class SpatioTemporalScaleFactors(NamedTuple):
|
||||
"""
|
||||
Describes the spatiotemporal downscaling between decoded video space and
|
||||
the corresponding VAE latent grid.
|
||||
"""
|
||||
|
||||
time: int
|
||||
width: int
|
||||
height: int
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "SpatioTemporalScaleFactors":
|
||||
return cls(time=8, width=32, height=32)
|
||||
|
||||
|
||||
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
|
||||
|
||||
|
||||
class VideoLatentShape(NamedTuple):
|
||||
"""
|
||||
Shape of the tensor representing video in VAE latent space.
|
||||
The latent representation is a 5D tensor with dimensions ordered as
|
||||
(batch, channels, frames, height, width). Spatial and temporal dimensions
|
||||
are downscaled relative to pixel space according to the VAE's scale factors.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
height: int
|
||||
width: int
|
||||
|
||||
def to_torch_shape(self) -> torch.Size:
|
||||
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
|
||||
|
||||
@staticmethod
|
||||
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
|
||||
return VideoLatentShape(
|
||||
batch=shape[0],
|
||||
channels=shape[1],
|
||||
frames=shape[2],
|
||||
height=shape[3],
|
||||
width=shape[4],
|
||||
)
|
||||
|
||||
def mask_shape(self) -> "VideoLatentShape":
|
||||
return self._replace(channels=1)
|
||||
|
||||
@staticmethod
|
||||
def from_pixel_shape(
|
||||
shape: VideoPixelShape,
|
||||
latent_channels: int = 128,
|
||||
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
|
||||
) -> "VideoLatentShape":
|
||||
frames = (shape.frames - 1) // scale_factors[0] + 1
|
||||
height = shape.height // scale_factors[1]
|
||||
width = shape.width // scale_factors[2]
|
||||
|
||||
return VideoLatentShape(
|
||||
batch=shape.batch,
|
||||
channels=latent_channels,
|
||||
frames=frames,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
|
||||
return self._replace(
|
||||
channels=3,
|
||||
frames=(self.frames - 1) * scale_factors.time + 1,
|
||||
height=self.height * scale_factors.height,
|
||||
width=self.width * scale_factors.width,
|
||||
)
|
||||
|
||||
|
||||
class AudioLatentShape(NamedTuple):
|
||||
"""
|
||||
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
|
||||
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
mel_bins: int
|
||||
|
||||
def to_torch_shape(self) -> torch.Size:
|
||||
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
|
||||
|
||||
def mask_shape(self) -> "AudioLatentShape":
|
||||
return self._replace(channels=1, mel_bins=1)
|
||||
|
||||
@staticmethod
|
||||
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
|
||||
return AudioLatentShape(
|
||||
batch=shape[0],
|
||||
channels=shape[1],
|
||||
frames=shape[2],
|
||||
mel_bins=shape[3],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_duration(
|
||||
batch: int,
|
||||
duration: float,
|
||||
channels: int = 8,
|
||||
mel_bins: int = 16,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
) -> "AudioLatentShape":
|
||||
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
|
||||
|
||||
return AudioLatentShape(
|
||||
batch=batch,
|
||||
channels=channels,
|
||||
frames=round(duration * latents_per_second),
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_video_pixel_shape(
|
||||
shape: VideoPixelShape,
|
||||
channels: int = 8,
|
||||
mel_bins: int = 16,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
) -> "AudioLatentShape":
|
||||
return AudioLatentShape.from_duration(
|
||||
batch=shape.batch,
|
||||
duration=float(shape.frames) / float(shape.fps),
|
||||
channels=channels,
|
||||
mel_bins=mel_bins,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=hop_length,
|
||||
audio_latent_downsample_factor=audio_latent_downsample_factor,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LatentState:
|
||||
"""
|
||||
State of latents during the diffusion denoising process.
|
||||
Attributes:
|
||||
latent: The current noisy latent tensor being denoised.
|
||||
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
|
||||
positions: Positional indices for each latent element, used for positional embeddings.
|
||||
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
|
||||
"""
|
||||
|
||||
latent: torch.Tensor
|
||||
denoise_mask: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
clean_latent: torch.Tensor
|
||||
|
||||
def clone(self) -> "LatentState":
|
||||
return LatentState(
|
||||
latent=self.latent.clone(),
|
||||
denoise_mask=self.denoise_mask.clone(),
|
||||
positions=self.positions.clone(),
|
||||
clean_latent=self.clean_latent.clone(),
|
||||
)
|
||||
|
||||
|
||||
class NormType(Enum):
|
||||
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||
|
||||
GROUP = "group"
|
||||
PIXEL = "pixel"
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
For each element along the chosen dimension, this layer normalizes the tensor
|
||||
by the root-mean-square of its values across that dimension:
|
||||
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Dimension along which to compute the RMS (typically channels).
|
||||
eps: Small constant added for numerical stability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply RMS normalization along the configured dimension.
|
||||
"""
|
||||
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
# Normalize by the root-mean-square (RMS).
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
def build_normalization_layer(
|
||||
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a normalization layer based on the normalization type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
num_groups: Number of groups for group normalization
|
||||
normtype: Type of normalization: "group" or "pixel"
|
||||
Returns:
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if normtype == NormType.PIXEL:
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
|
||||
"""Root-mean-square (RMS) normalize `x` over its last dimension.
|
||||
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
|
||||
shape and forwards `weight` and `eps`.
|
||||
"""
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Modality:
|
||||
"""
|
||||
Input data for a single modality (video or audio) in the transformer.
|
||||
Bundles the latent tokens, timestep embeddings, positional information,
|
||||
and text conditioning context for processing by the diffusion transformer.
|
||||
"""
|
||||
|
||||
latent: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||
positions: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
|
||||
context: torch.Tensor
|
||||
enabled: bool = True
|
||||
context_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
def to_denoised(
|
||||
sample: torch.Tensor,
|
||||
velocity: torch.Tensor,
|
||||
sigma: float | torch.Tensor,
|
||||
calc_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert the sample and its denoising velocity to denoised sample.
|
||||
Returns:
|
||||
Denoised sample
|
||||
"""
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.to(calc_dtype)
|
||||
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
|
||||
|
||||
|
||||
|
||||
class Patchifier(Protocol):
|
||||
"""
|
||||
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
|
||||
"""
|
||||
|
||||
def patchify(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
"""
|
||||
Convert latent tensors into flattened patch tokens.
|
||||
Args:
|
||||
latents: Latent tensor to patchify.
|
||||
Returns:
|
||||
Flattened patch tokens tensor.
|
||||
"""
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
|
||||
Args:
|
||||
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
|
||||
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
|
||||
VideoLatentShape.
|
||||
Returns:
|
||||
Dense latent tensor restored from the flattened representation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
...
|
||||
"""
|
||||
Returns the patch size as a tuple of (temporal, height, width) dimensions
|
||||
"""
|
||||
|
||||
def get_patch_grid_bounds(
|
||||
self,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
device: torch.device | None = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
"""
|
||||
Compute metadata describing where each latent patch resides within the
|
||||
grid specified by `output_shape`.
|
||||
Args:
|
||||
output_shape: Target grid layout for the patches.
|
||||
device: Target device for the returned tensor.
|
||||
Returns:
|
||||
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||
"""
|
||||
|
||||
|
||||
def get_pixel_coords(
|
||||
latent_coords: torch.Tensor,
|
||||
scale_factors: SpatioTemporalScaleFactors,
|
||||
causal_fix: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
|
||||
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
|
||||
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
|
||||
Args:
|
||||
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
|
||||
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
|
||||
per axis.
|
||||
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
|
||||
that treat frame zero differently still yield non-negative timestamps.
|
||||
"""
|
||||
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
|
||||
broadcast_shape = [1] * latent_coords.ndim
|
||||
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
|
||||
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
|
||||
|
||||
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
|
||||
pixel_coords = latent_coords * scale_tensor
|
||||
|
||||
if causal_fix:
|
||||
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
|
||||
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
|
||||
return pixel_coords
|
||||
1451
diffsynth/models/ltx2_dit.py
Normal file
1451
diffsynth/models/ltx2_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
366
diffsynth/models/ltx2_text_encoder.py
Normal file
366
diffsynth/models/ltx2_text_encoder.py
Normal file
@@ -0,0 +1,366 @@
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||
FeedForward)
|
||||
from .ltx2_common import rms_norm
|
||||
|
||||
|
||||
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
|
||||
def __init__(self):
|
||||
config = Gemma3Config(
|
||||
**{
|
||||
"architectures": ["Gemma3ForConditionalGeneration"],
|
||||
"boi_token_index": 255999,
|
||||
"dtype": "bfloat16",
|
||||
"eoi_token_index": 256000,
|
||||
"eos_token_id": [1, 106],
|
||||
"image_token_index": 262144,
|
||||
"initializer_range": 0.02,
|
||||
"mm_tokens_per_image": 256,
|
||||
"model_type": "gemma3",
|
||||
"text_config": {
|
||||
"_sliding_window_pattern": 6,
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": None,
|
||||
"cache_implementation": "hybrid",
|
||||
"dtype": "bfloat16",
|
||||
"final_logit_softcapping": None,
|
||||
"head_dim": 256,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 3840,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 15360,
|
||||
"layer_types": [
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
|
||||
],
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "gemma3_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 48,
|
||||
"num_key_value_heads": 8,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_local_base_freq": 10000,
|
||||
"rope_scaling": {
|
||||
"factor": 8.0,
|
||||
"rope_type": "linear"
|
||||
},
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": 1024,
|
||||
"sliding_window_pattern": 6,
|
||||
"use_bidirectional_attention": False,
|
||||
"use_cache": True,
|
||||
"vocab_size": 262208
|
||||
},
|
||||
"transformers_version": "4.57.3",
|
||||
"vision_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 896,
|
||||
"intermediate_size": 4304,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"vision_use_head": False
|
||||
}
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
class LTXVGemmaTokenizer:
|
||||
"""
|
||||
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
||||
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
||||
ensuring correct settings and output formatting for downstream consumption.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_path: str, max_length: int = 1024):
|
||||
"""
|
||||
Initialize the tokenizer.
|
||||
Args:
|
||||
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
||||
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
||||
"""
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, local_files_only=True, model_max_length=max_length
|
||||
)
|
||||
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
||||
self.tokenizer.padding_side = "left"
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
||||
"""
|
||||
Tokenize the given text and return token IDs and attention weights.
|
||||
Args:
|
||||
text (str): The input string to tokenize.
|
||||
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
||||
If False (default), omits the indices.
|
||||
Returns:
|
||||
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
||||
A dictionary with a "gemma" key mapping to:
|
||||
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
||||
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
||||
Example:
|
||||
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
||||
>>> tokenizer.tokenize_with_weights("hello world")
|
||||
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
||||
"""
|
||||
text = text.strip()
|
||||
encoded = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = encoded.input_ids
|
||||
attention_mask = encoded.attention_mask
|
||||
tuples = [
|
||||
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
||||
]
|
||||
out = {"gemma": tuples}
|
||||
|
||||
if not return_word_ids:
|
||||
# Return only (token_id, attention_mask) pairs, omitting token position
|
||||
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
"""
|
||||
Feature extractor module for Gemma models.
|
||||
This module applies a single linear projection to the input tensor.
|
||||
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||
Attributes:
|
||||
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the GemmaFeaturesExtractorProjLinear module.
|
||||
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||
"""
|
||||
super().__init__()
|
||||
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the feature extractor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||
"""
|
||||
return self.aggregate_embed(x)
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pe: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(torch.nn.Module):
|
||||
"""
|
||||
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
|
||||
layers, and register usage.
|
||||
Args:
|
||||
attention_head_dim (int): Dimension of each attention head (default=128).
|
||||
num_attention_heads (int): Number of attention heads (default=30).
|
||||
num_layers (int): Number of transformer layers (default=2).
|
||||
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
|
||||
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
|
||||
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
|
||||
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
|
||||
register replacement. (default=128)
|
||||
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
|
||||
double_precision_rope (bool): Use double precision rope calculation (default=False).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 30,
|
||||
num_layers: int = 2,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list[int] | None = [4096],
|
||||
causal_temporal_positioning: bool = False,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = (
|
||||
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
|
||||
)
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
_BasicTransformerBlock1D(
|
||||
dim=self.inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = torch.nn.Parameter(
|
||||
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||
)
|
||||
|
||||
def _replace_padded_with_learnable_registers(
|
||||
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
|
||||
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
|
||||
f"{self.num_learnable_registers}."
|
||||
)
|
||||
|
||||
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
|
||||
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
|
||||
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
|
||||
|
||||
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||
|
||||
attention_mask = torch.full_like(
|
||||
attention_mask,
|
||||
0.0,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass of Embeddings1DConnector.
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
|
||||
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
|
||||
"""
|
||||
if self.num_learnable_registers:
|
||||
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
|
||||
|
||||
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
|
||||
freqs_cis = precompute_freqs_cis(
|
||||
indices_grid=indices_grid,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=hidden_states.dtype,
|
||||
theta=self.positional_embedding_theta,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
rope_type=self.rope_type,
|
||||
freq_grid_generator=freq_grid_generator,
|
||||
)
|
||||
|
||||
for block in self.transformer_1d_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
|
||||
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextEncoderPostModules(torch.nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||
self.embeddings_connector = Embeddings1DConnector()
|
||||
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||
313
diffsynth/models/ltx2_upsampler.py
Normal file
313
diffsynth/models/ltx2_upsampler.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
from .ltx2_video_vae import LTX2VideoEncoder
|
||||
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
"""
|
||||
N-dimensional pixel shuffle operation for upsampling tensors.
|
||||
Args:
|
||||
dims (int): Number of dimensions to apply pixel shuffle to.
|
||||
- 1: Temporal (e.g., frames)
|
||||
- 2: Spatial (e.g., height and width)
|
||||
- 3: Spatiotemporal (e.g., depth, height, width)
|
||||
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
|
||||
For dims=1, only the first value is used.
|
||||
For dims=2, the first two values are used.
|
||||
For dims=3, all three values are used.
|
||||
The input tensor is rearranged so that the channel dimension is split into
|
||||
smaller channels and upscaling factors, and the upscaling factors are moved
|
||||
into the corresponding spatial/temporal dimensions.
|
||||
Note:
|
||||
This operation is equivalent to the patchifier operation in for the models. Consider
|
||||
using this class instead.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dims: {self.dims}")
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
"""
|
||||
Residual block with two convolutional layers, group normalization, and SiLU activation.
|
||||
Args:
|
||||
channels (int): Number of input and output channels.
|
||||
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
|
||||
if not specified.
|
||||
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class BlurDownsample(torch.nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert isinstance(stride, int)
|
||||
assert stride >= 1
|
||||
assert kernel_size >= 3
|
||||
assert kernel_size % 2 == 1
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||
# The 2D kernel is constructed as the outer product and normalized.
|
||||
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
if self.dims == 2:
|
||||
return self._apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, _, f, _, _ = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self._apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
|
||||
c = x2d.shape[1]
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
return x2d
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class SpatialRationalResampler(torch.nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
Args:
|
||||
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
||||
scale (`float`): Spatial scaling factor. Supported values are:
|
||||
- 0.75: Downsample by 3/4 (reduce spatial size)
|
||||
- 1.5: Upsample by 3/2 (increase spatial size)
|
||||
- 2.0: Upsample by 2x (double spatial size)
|
||||
- 4.0: Upsample by 4x (quadruple spatial size)
|
||||
Any other value will raise a ValueError.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, _, f, _, _ = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2LatentUpsampler(torch.nn.Module):
|
||||
"""
|
||||
Model to upsample VAE latents spatially and/or temporally.
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
spatial_scale (`float`): Scale factor for spatial upsampling
|
||||
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, _, f, _, _ = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
# remove the first frame after upsampling.
|
||||
# This is done because the first frame encodes one pixel frame.
|
||||
x = x[:, :, 1:, :, :]
|
||||
elif isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
|
||||
"""
|
||||
Apply upsampling to the latent representation using the provided upsampler,
|
||||
with normalization and un-normalization based on the video encoder's per-channel statistics.
|
||||
Args:
|
||||
latent: Input latent tensor of shape [B, C, F, H, W].
|
||||
video_encoder: VideoEncoder with per_channel_statistics for normalization.
|
||||
upsampler: LTX2LatentUpsampler module to perform upsampling.
|
||||
Returns:
|
||||
torch.Tensor: Upsampled and re-normalized latent tensor.
|
||||
"""
|
||||
latent = video_encoder.per_channel_statistics.un_normalize(latent)
|
||||
latent = upsampler(latent)
|
||||
latent = video_encoder.per_channel_statistics.normalize(latent)
|
||||
return latent
|
||||
2317
diffsynth/models/ltx2_video_vae.py
Normal file
2317
diffsynth/models/ltx2_video_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
550
diffsynth/pipelines/ltx2_audio_video.py
Normal file
550
diffsynth/pipelines/ltx2_audio_video.py
Normal file
@@ -0,0 +1,550 @@
|
||||
import torch, types
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat
|
||||
from typing import Optional, Union
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from transformers import AutoImageProcessor, Gemma3Processor
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||
from ..models.ltx2_dit import LTXModel
|
||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||
|
||||
|
||||
class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
height_division_factor=32,
|
||||
width_division_factor=32,
|
||||
time_division_factor=8,
|
||||
time_division_remainder=1,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("LTX-2")
|
||||
self.text_encoder: LTX2TextEncoder = None
|
||||
self.tokenizer: LTXVGemmaTokenizer = None
|
||||
self.processor: Gemma3Processor = None
|
||||
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
|
||||
self.dit: LTXModel = None
|
||||
self.video_vae_encoder: LTX2VideoEncoder = None
|
||||
self.video_vae_decoder: LTX2VideoDecoder = None
|
||||
self.audio_vae_encoder: LTX2AudioEncoder = None
|
||||
self.audio_vae_decoder: LTX2AudioDecoder = None
|
||||
self.audio_vocoder: LTX2Vocoder = None
|
||||
self.upsampler: LTX2LatentUpsampler = None
|
||||
|
||||
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
||||
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
LTX2AudioVideoUnit_PipelineChecker(),
|
||||
LTX2AudioVideoUnit_ShapeChecker(),
|
||||
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_ltx2
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config: Optional[ModelConfig] = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
|
||||
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
|
||||
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
|
||||
|
||||
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
|
||||
pipe.dit = model_pool.fetch_model("ltx2_dit")
|
||||
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
|
||||
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
|
||||
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
||||
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
||||
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
||||
|
||||
# Stage 2
|
||||
if stage2_lora_config is not None:
|
||||
stage2_lora_config.download_if_necessary()
|
||||
pipe.stage2_lora_path = stage2_lora_config.path
|
||||
# Optional, currently not used
|
||||
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
||||
if inputs_shared["use_two_stage_pipeline"]:
|
||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||
self.load_models_to_device('upsampler',)
|
||||
latent = self.upsampler(latent)
|
||||
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
|
||||
self.scheduler.set_timesteps(special_case="stage2")
|
||||
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
||||
denoise_mask_video = 1.0
|
||||
if inputs_shared.get("input_images", None) is not None:
|
||||
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
|
||||
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
|
||||
inputs_shared["input_images_strength"], latent.clone())
|
||||
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
||||
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
||||
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
|
||||
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
||||
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
||||
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
if not inputs_shared["use_distilled_pipeline"]:
|
||||
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
||||
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
|
||||
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_audio, **inputs_shared)
|
||||
return inputs_shared
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = "",
|
||||
# Image-to-video
|
||||
denoising_strength: float = 1.0,
|
||||
input_images: Optional[list[Image.Image]] = None,
|
||||
input_images_indexes: Optional[list[int]] = None,
|
||||
input_images_strength: Optional[float] = 1.0,
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
# Shape
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 768,
|
||||
num_frames=121,
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 3.0,
|
||||
cfg_merge: Optional[bool] = False,
|
||||
# Scheduler
|
||||
num_inference_steps: Optional[int] = 40,
|
||||
# VAE tiling
|
||||
tiled: Optional[bool] = True,
|
||||
tile_size_in_pixels: Optional[int] = 512,
|
||||
tile_overlap_in_pixels: Optional[int] = 128,
|
||||
tile_size_in_frames: Optional[int] = 128,
|
||||
tile_overlap_in_frames: Optional[int] = 24,
|
||||
# Special Pipelines
|
||||
use_two_stage_pipeline: Optional[bool] = False,
|
||||
use_distilled_pipeline: Optional[bool] = False,
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
|
||||
special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
||||
# Inputs
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"height": height, "width": width, "num_frames": num_frames,
|
||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
||||
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
||||
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise Stage 1
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
||||
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_audio, **inputs_shared)
|
||||
|
||||
# Denoise Stage 2
|
||||
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['video_vae_decoder'])
|
||||
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
||||
video = self.vae_output_to_video(video)
|
||||
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
||||
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
return video, decoded_audio
|
||||
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
|
||||
b, _, f, h, w = latents.shape
|
||||
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
|
||||
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
||||
for idx, input_latent in zip(input_indexes, input_latents):
|
||||
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
||||
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
||||
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
||||
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
||||
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
|
||||
return latents, denoise_mask, initial_latents
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
|
||||
output_params=("use_two_stage_pipeline", "cfg_scale")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("use_distilled_pipeline", False):
|
||||
inputs_shared["use_two_stage_pipeline"] = True
|
||||
inputs_shared["cfg_scale"] = 1.0
|
||||
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
|
||||
if inputs_shared.get("use_two_stage_pipeline", False):
|
||||
# distill pipeline also uses two-stage, but it does not needs lora
|
||||
if not inputs_shared.get("use_distilled_pipeline", False):
|
||||
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
|
||||
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
|
||||
"""
|
||||
For two-stage pipelines, the resolution must be divisible by 64.
|
||||
For one-stage pipelines, the resolution must be divisible by 32.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames"),
|
||||
output_params=("height", "width", "num_frames"),
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
|
||||
if use_two_stage_pipeline:
|
||||
self.width_division_factor = 64
|
||||
self.height_division_factor = 64
|
||||
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
|
||||
if use_two_stage_pipeline:
|
||||
self.width_division_factor = 32
|
||||
self.height_division_factor = 32
|
||||
return {"height": height, "width": width, "num_frames": num_frames}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("video_context", "audio_context"),
|
||||
onload_model_names=("text_encoder", "text_encoder_post_modules"),
|
||||
)
|
||||
|
||||
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
|
||||
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
|
||||
|
||||
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
|
||||
encoded_input,
|
||||
connector_attention_mask,
|
||||
)
|
||||
|
||||
# restore the mask values to int64
|
||||
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
|
||||
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * attention_mask
|
||||
|
||||
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
|
||||
encoded_input, connector_attention_mask)
|
||||
|
||||
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
self,
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
def _run_feature_extractor(self,
|
||||
pipe,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "right") -> torch.Tensor:
|
||||
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||
encoded_text_features_dtype = encoded_text_features.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||
sequence_lengths,
|
||||
padding_side=padding_side)
|
||||
|
||||
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||
|
||||
def _preprocess_text(
|
||||
self,
|
||||
pipe,
|
||||
text: str,
|
||||
padding_side: str = "left",
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Encode a given string into feature tensors suitable for downstream tasks.
|
||||
Args:
|
||||
text (str): Input string to encode.
|
||||
Returns:
|
||||
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||
"""
|
||||
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
|
||||
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
|
||||
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
projected = self._run_feature_extractor(pipe,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
padding_side=padding_side)
|
||||
return projected, attention_mask
|
||||
|
||||
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||
return video_encoding, audio_encoding, attention_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
|
||||
return {"video_context": video_context, "audio_context": audio_context}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
|
||||
output_params=("video_noise", "audio_noise",),
|
||||
)
|
||||
|
||||
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
||||
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
||||
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
|
||||
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
|
||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
||||
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
|
||||
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
|
||||
video_positions = video_positions.to(pipe.torch_dtype)
|
||||
|
||||
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
|
||||
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||
return {
|
||||
"video_noise": video_noise,
|
||||
"audio_noise": audio_noise,
|
||||
"video_positions": video_positions,
|
||||
"audio_positions": audio_positions,
|
||||
"video_latent_shape": video_latent_shape,
|
||||
"audio_latent_shape": audio_latent_shape
|
||||
}
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
|
||||
if use_two_stage_pipeline:
|
||||
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
|
||||
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
||||
initial_dict = stage1_dict
|
||||
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
|
||||
return initial_dict
|
||||
else:
|
||||
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
||||
|
||||
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("video_latents", "audio_latents"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
|
||||
if input_video is None:
|
||||
return {"video_latents": video_noise, "audio_latents": audio_noise}
|
||||
else:
|
||||
# TODO: implement video-to-video
|
||||
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
|
||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
||||
output_params=("video_latents"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
||||
image = ltx2_preprocess(np.array(input_image.resize((width, height))))
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
image = image / 127.5 - 1.0
|
||||
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
|
||||
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
||||
return latent
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
|
||||
if input_images is None or len(input_images) == 0:
|
||||
return {"video_latents": video_latents}
|
||||
else:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
output_dicts = {}
|
||||
stage1_height = height // 2 if use_two_stage_pipeline else height
|
||||
stage1_width = width // 2 if use_two_stage_pipeline else width
|
||||
stage1_latents = [
|
||||
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels) for img in input_images
|
||||
]
|
||||
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
|
||||
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
|
||||
if use_two_stage_pipeline:
|
||||
stage2_latents = [
|
||||
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels) for img in input_images
|
||||
]
|
||||
output_dicts.update({"stage2_input_latents": stage2_latents})
|
||||
return output_dicts
|
||||
|
||||
|
||||
def model_fn_ltx2(
|
||||
dit: LTXModel,
|
||||
video_latents=None,
|
||||
video_context=None,
|
||||
video_positions=None,
|
||||
video_patchifier=None,
|
||||
audio_latents=None,
|
||||
audio_context=None,
|
||||
audio_positions=None,
|
||||
audio_patchifier=None,
|
||||
timestep=None,
|
||||
denoise_mask_video=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
timestep = timestep.float() / 1000.
|
||||
|
||||
# patchify
|
||||
b, c_v, f, h, w = video_latents.shape
|
||||
video_latents = video_patchifier.patchify(video_latents)
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
if denoise_mask_video is not None:
|
||||
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||
#TODO: support gradient checkpointing in training
|
||||
vx, ax = dit(
|
||||
video_latents=video_latents,
|
||||
video_positions=video_positions,
|
||||
video_context=video_context,
|
||||
video_timesteps=video_timesteps,
|
||||
audio_latents=audio_latents,
|
||||
audio_positions=audio_positions,
|
||||
audio_context=audio_context,
|
||||
audio_timesteps=audio_timesteps,
|
||||
)
|
||||
# unpatchify
|
||||
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
||||
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
|
||||
return vx, ax
|
||||
149
diffsynth/utils/data/media_io_ltx2.py
Normal file
149
diffsynth/utils/data/media_io_ltx2.py
Normal file
@@ -0,0 +1,149 @@
|
||||
|
||||
from fractions import Fraction
|
||||
import torch
|
||||
import av
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from collections.abc import Generator, Iterator
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||
) -> None:
|
||||
cc = audio_stream.codec_context
|
||||
|
||||
# Use the encoder's format/layout/rate as the *target*
|
||||
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||
target_layout = cc.layout or "stereo"
|
||||
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=target_format,
|
||||
layout=target_layout,
|
||||
rate=target_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = frame_in.sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
samples = (samples * 32767.0).to(torch.int16)
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(
|
||||
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
)
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
_resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
return audio_stream
|
||||
|
||||
def write_video_audio_ltx2(
|
||||
video: list[Image.Image],
|
||||
audio: torch.Tensor | None,
|
||||
output_path: str,
|
||||
fps: int = 24,
|
||||
audio_sample_rate: int | None = 24000,
|
||||
) -> None:
|
||||
|
||||
width, height = video[0].size
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame in tqdm(video, total=len(video)):
|
||||
frame = av.VideoFrame.from_image(frame)
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
|
||||
|
||||
|
||||
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
|
||||
container = av.open(output_file, "w", format="mp4")
|
||||
try:
|
||||
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
|
||||
# Round to nearest multiple of 2 for compatibility with video codecs
|
||||
height = image_array.shape[0] // 2 * 2
|
||||
width = image_array.shape[1] // 2 * 2
|
||||
image_array = image_array[:height, :width]
|
||||
stream.height = height
|
||||
stream.width = width
|
||||
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
|
||||
container.mux(stream.encode(av_frame))
|
||||
container.mux(stream.encode())
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
|
||||
def decode_single_frame(video_file: str) -> np.array:
|
||||
container = av.open(video_file)
|
||||
try:
|
||||
stream = next(s for s in container.streams if s.type == "video")
|
||||
frame = next(container.decode(stream))
|
||||
finally:
|
||||
container.close()
|
||||
return frame.to_ndarray(format="rgb24")
|
||||
|
||||
|
||||
def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array:
|
||||
if crf == 0:
|
||||
return image
|
||||
|
||||
with BytesIO() as output_file:
|
||||
encode_single_frame(output_file, image, crf)
|
||||
video_bytes = output_file.getvalue()
|
||||
with BytesIO(video_bytes) as video_file:
|
||||
image_array = decode_single_frame(video_file)
|
||||
return image_array
|
||||
32
diffsynth/utils/state_dict_converters/ltx2_audio_vae.py
Normal file
32
diffsynth/utils/state_dict_converters/ltx2_audio_vae.py
Normal file
@@ -0,0 +1,32 @@
|
||||
def LTX2AudioEncoderStateDictConverter(state_dict):
|
||||
# Not used
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("audio_vae.encoder."):
|
||||
new_name = name.replace("audio_vae.encoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2AudioDecoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("audio_vae.decoder."):
|
||||
new_name = name.replace("audio_vae.decoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2VocoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vocoder."):
|
||||
new_name = name.replace("vocoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
9
diffsynth/utils/state_dict_converters/ltx2_dit.py
Normal file
9
diffsynth/utils/state_dict_converters/ltx2_dit.py
Normal file
@@ -0,0 +1,9 @@
|
||||
def LTXModelStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("model.diffusion_model."):
|
||||
new_name = name.replace("model.diffusion_model.", "")
|
||||
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
|
||||
continue
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
31
diffsynth/utils/state_dict_converters/ltx2_text_encoder.py
Normal file
31
diffsynth/utils/state_dict_converters/ltx2_text_encoder.py
Normal file
@@ -0,0 +1,31 @@
|
||||
def LTX2TextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("language_model.model."):
|
||||
new_key = key.replace("language_model.model.", "model.language_model.")
|
||||
elif key.startswith("vision_tower."):
|
||||
new_key = key.replace("vision_tower.", "model.vision_tower.")
|
||||
elif key.startswith("multi_modal_projector."):
|
||||
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
|
||||
elif key.startswith("language_model.lm_head."):
|
||||
new_key = key.replace("language_model.lm_head.", "lm_head.")
|
||||
else:
|
||||
continue
|
||||
state_dict_[new_key] = state_dict[key]
|
||||
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("text_embedding_projection."):
|
||||
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
|
||||
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
|
||||
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
|
||||
else:
|
||||
continue
|
||||
state_dict_[new_key] = state_dict[key]
|
||||
return state_dict_
|
||||
22
diffsynth/utils/state_dict_converters/ltx2_video_vae.py
Normal file
22
diffsynth/utils/state_dict_converters/ltx2_video_vae.py
Normal file
@@ -0,0 +1,22 @@
|
||||
def LTX2VideoEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vae.encoder."):
|
||||
new_name = name.replace("vae.encoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2VideoDecoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vae.decoder."):
|
||||
new_name = name.replace("vae.decoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
109
docs/en/Model_Details/LTX-2.md
Normal file
109
docs/en/Model_Details/LTX-2.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# LTX-2
|
||||
|
||||
LTX-2 is a series of audio-video generation models developed by Lightricks.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||
|
||||
Input parameters for `LTX2AudioVideoPipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the video.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.
|
||||
* `input_images`: List of input images for image-to-video generation.
|
||||
* `input_images_indexes`: Frame index list of input images in the video.
|
||||
* `input_images_strength`: Strength of input images, default value is 1.0.
|
||||
* `denoising_strength`: Denoising strength, range is 0~1, default value is 1.0.
|
||||
* `seed`: Random seed. Default is `None`, which means completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs.
|
||||
* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).
|
||||
* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).
|
||||
* `num_frames`: Number of video frames, default value is 121, must be a multiple of 8 + 1.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 40.
|
||||
* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, with slight errors and minor inference time extension.
|
||||
* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.
|
||||
* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.
|
||||
* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.
|
||||
* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.
|
||||
* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`.
|
||||
* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`.
|
||||
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the previous "Supported Inference Scripts" section.
|
||||
|
||||
## Model Training
|
||||
|
||||
The LTX-2 series models currently do not support training functionality. We will add related support as soon as possible.
|
||||
109
docs/zh/Model_Details/LTX-2.md
Normal file
109
docs/zh/Model_Details/LTX-2.md
Normal file
@@ -0,0 +1,109 @@
|
||||
# LTX-2
|
||||
|
||||
LTX-2 是由 Lightricks 开发的音视频生成模型系列。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `LTX2AudioVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`LTX2AudioVideoPipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述视频中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述视频中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 3.0。
|
||||
* `input_images`: 输入图像列表,用于图生视频。
|
||||
* `input_images_indexes`: 输入图像在视频中的帧索引列表。
|
||||
* `input_images_strength`: 输入图像的强度,默认值为 1.0。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1.0。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `height`: 视频高度,需保证高度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。
|
||||
* `width`: 视频宽度,需保证宽度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。
|
||||
* `num_frames`: 视频帧数,默认值为 121,需保证为 8 的倍数 + 1。
|
||||
* `num_inference_steps`: 推理次数,默认值为 40。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size_in_pixels`: VAE 编解码阶段的像素分块大小,默认为 512。
|
||||
* `tile_overlap_in_pixels`: VAE 编解码阶段的像素分块重叠大小,默认为 128。
|
||||
* `tile_size_in_frames`: VAE 编解码阶段的帧分块大小,默认为 128。
|
||||
* `tile_overlap_in_frames`: VAE 编解码阶段的帧分块重叠大小,默认为 24。
|
||||
* `use_two_stage_pipeline`: 是否使用两阶段管道,默认为 `False`。
|
||||
* `use_distilled_pipeline`: 是否使用蒸馏管道,默认为 `False`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"支持的推理脚本"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
LTX-2 系列模型目前暂不支持训练功能。我们将尽快添加相关支持。
|
||||
@@ -0,0 +1,69 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_distilled_pipeline=True,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_distilled_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
55
examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py
Normal file
55
examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=False,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
72
examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py
Normal file
72
examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=42,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
num_inference_steps=40,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_distilled_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_distilled.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
42
examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py
Normal file
42
examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
58
examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py
Normal file
58
examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,70 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_distilled_pipeline=True,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_distilled_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,56 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=False,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
num_inference_steps=40,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,72 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
from PIL import Image
|
||||
from modelscope import dataset_snapshot_download
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
dataset_snapshot_download(
|
||||
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||
local_dir="./",
|
||||
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||
)
|
||||
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||
# first frame
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=42,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
num_inference_steps=40,
|
||||
input_images=[image],
|
||||
input_images_indexes=[0],
|
||||
input_images_strength=1.0,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage_i2av_first.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,58 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_distilled_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_distilled.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
Reference in New Issue
Block a user