mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
Compare commits
43 Commits
zero3
...
examples-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53890bafa4 | ||
|
|
afd48cd706 | ||
|
|
24b68c2392 | ||
|
|
280ff7cca6 | ||
|
|
b4b62e2f7c | ||
|
|
6d1be405b9 | ||
|
|
25c3a3d3e2 | ||
|
|
49bc84f78e | ||
|
|
25a9e75030 | ||
|
|
2a7ac73eb5 | ||
|
|
f4f991d409 | ||
|
|
a781138413 | ||
|
|
91a5623976 | ||
|
|
28cd355aba | ||
|
|
005389fca7 | ||
|
|
a6282056eb | ||
|
|
21a6eb8e2f | ||
|
|
98ab238340 | ||
|
|
1c8a0f8317 | ||
|
|
9f07d65ebb | ||
|
|
5f1d5adfce | ||
|
|
4f23caa55f | ||
|
|
b4f6a4de6c | ||
|
|
53fe42af1b | ||
|
|
ee9a3b4405 | ||
|
|
b1a2782ad7 | ||
|
|
8d303b47e9 | ||
|
|
00da4b6c4f | ||
|
|
22695e9be0 | ||
|
|
98290190ec | ||
|
|
3f4de2cc7f | ||
|
|
8d0df403ca | ||
|
|
d12bf71bcc | ||
|
|
35e0776022 | ||
|
|
b3cc652dea | ||
|
|
d879d66c62 | ||
|
|
848bfd6993 | ||
|
|
269da09f6e | ||
|
|
e30514a00c | ||
|
|
c758769a02 | ||
|
|
a5935e973a | ||
|
|
9834d72e4d | ||
|
|
01234e59c0 |
101
README.md
101
README.md
@@ -33,6 +33,10 @@ We believe that a well-developed open-source code framework can lower the thresh
|
|||||||
|
|
||||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||||
|
|
||||||
|
- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
|
||||||
|
|
||||||
|
- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
|
||||||
|
|
||||||
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
|
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
|
||||||
|
|
||||||
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
|
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
|
||||||
@@ -269,9 +273,14 @@ image.save("image.jpg")
|
|||||||
|
|
||||||
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
|
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
|
||||||
|
|
||||||
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||||
|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -410,6 +419,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|
|||||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||||
|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
@@ -522,6 +532,95 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||||
|
|
||||||
|
#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Quick Start</summary>
|
||||||
|
|
||||||
|
Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>Examples</summary>
|
||||||
|
|
||||||
|
Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
|
||||||
|
|
||||||
|
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
|
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|||||||
99
README_zh.md
99
README_zh.md
@@ -33,6 +33,10 @@ DiffSynth 目前包括两个开源项目:
|
|||||||
|
|
||||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||||
|
|
||||||
|
- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。
|
||||||
|
|
||||||
|
- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
|
||||||
|
|
||||||
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
|
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
|
||||||
|
|
||||||
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
|
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
|
||||||
@@ -271,7 +275,12 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|
|||||||
|
|
||||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||||
|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
@@ -410,6 +419,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|
|||||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||||
|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
@@ -522,6 +532,95 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||||
|
|
||||||
|
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>快速开始</summary>
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
|
||||||
|
<summary>示例代码</summary>
|
||||||
|
|
||||||
|
LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
|
||||||
|
|
||||||
|
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
|
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
|
|||||||
@@ -589,6 +589,78 @@ z_image_series = [
|
|||||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||||
"extra_kwargs": {"compress_dim": 128},
|
"extra_kwargs": {"compress_dim": 128},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
|
||||||
|
"model_hash": "1392adecee344136041e70553f875f31",
|
||||||
|
"model_name": "z_image_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||||
|
"extra_kwargs": {"model_size": "0.6B"},
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
ltx2_series = [
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_dit",
|
||||||
|
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_video_vae_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_video_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_audio_vae_decoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_audio_vocoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||||
|
},
|
||||||
|
# { # not used currently
|
||||||
|
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
# "model_name": "ltx2_audio_vae_encoder",
|
||||||
|
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||||
|
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||||
|
# },
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||||
|
"model_name": "ltx2_text_encoder_post_modules",
|
||||||
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
|
||||||
|
"model_hash": "33917f31c4a79196171154cca39f165e",
|
||||||
|
"model_name": "ltx2_text_encoder",
|
||||||
|
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
|
||||||
|
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||||
|
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
|
||||||
|
"model_name": "ltx2_latent_upsampler",
|
||||||
|
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
|
||||||
|
|||||||
@@ -210,4 +210,37 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
|||||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
},
|
},
|
||||||
|
"diffsynth.models.ltx2_dit.LTXModel": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
|
||||||
|
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
|
||||||
|
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
|
||||||
|
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
|
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
|
||||||
|
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||||
|
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
|
|||||||
if k_pattern != required_in_pattern:
|
if k_pattern != required_in_pattern:
|
||||||
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
||||||
if v_pattern != required_in_pattern:
|
if v_pattern != required_in_pattern:
|
||||||
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
|
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
|
||||||
return q, k, v
|
return q, k, v
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import torch, glob, os
|
import torch, glob, os
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union, Dict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||||
@@ -23,13 +23,14 @@ class ModelConfig:
|
|||||||
computation_device: Optional[Union[str, torch.device]] = None
|
computation_device: Optional[Union[str, torch.device]] = None
|
||||||
computation_dtype: Optional[torch.dtype] = None
|
computation_dtype: Optional[torch.dtype] = None
|
||||||
clear_parameters: bool = False
|
clear_parameters: bool = False
|
||||||
|
state_dict: Dict[str, torch.Tensor] = None
|
||||||
|
|
||||||
def check_input(self):
|
def check_input(self):
|
||||||
if self.path is None and self.model_id is None:
|
if self.path is None and self.model_id is None:
|
||||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||||
|
|
||||||
def parse_original_file_pattern(self):
|
def parse_original_file_pattern(self):
|
||||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
if self.origin_file_pattern in [None, "", "./"]:
|
||||||
return "*"
|
return "*"
|
||||||
elif self.origin_file_pattern.endswith("/"):
|
elif self.origin_file_pattern.endswith("/"):
|
||||||
return self.origin_file_pattern + "*"
|
return self.origin_file_pattern + "*"
|
||||||
@@ -98,7 +99,7 @@ class ModelConfig:
|
|||||||
if self.require_downloading():
|
if self.require_downloading():
|
||||||
self.download()
|
self.download()
|
||||||
if self.path is None:
|
if self.path is None:
|
||||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
if self.origin_file_pattern in [None, "", "./"]:
|
||||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||||
else:
|
else:
|
||||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||||
|
|||||||
@@ -2,16 +2,25 @@ from safetensors import safe_open
|
|||||||
import torch, hashlib
|
import torch, hashlib
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||||
if isinstance(file_path, list):
|
if isinstance(file_path, list):
|
||||||
state_dict = {}
|
state_dict = {}
|
||||||
for file_path_ in file_path:
|
for file_path_ in file_path:
|
||||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
|
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||||
return state_dict
|
|
||||||
if file_path.endswith(".safetensors"):
|
|
||||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
|
||||||
else:
|
else:
|
||||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [started]: {file_path}")
|
||||||
|
if file_path.endswith(".safetensors"):
|
||||||
|
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||||
|
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||||
|
if pin_memory:
|
||||||
|
for i in state_dict:
|
||||||
|
state_dict[i] = state_dict[i].pin_memory()
|
||||||
|
if verbose >= 1:
|
||||||
|
print(f"Loading file [done]: {file_path}")
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from .file import load_state_dict
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||||
config = {} if config is None else config
|
config = {} if config is None else config
|
||||||
# Why do we use `skip_model_initialization`?
|
# Why do we use `skip_model_initialization`?
|
||||||
# It skips the random initialization of model parameters,
|
# It skips the random initialization of model parameters,
|
||||||
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||||
dtype = [d for d in dtypes if d != "disk"][0]
|
dtype = [d for d in dtypes if d != "disk"][0]
|
||||||
if vram_config["offload_device"] != "disk":
|
if vram_config["offload_device"] != "disk":
|
||||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||||
if state_dict_converter is not None:
|
if state_dict_converter is not None:
|
||||||
state_dict = state_dict_converter(state_dict)
|
state_dict = state_dict_converter(state_dict)
|
||||||
else:
|
else:
|
||||||
@@ -35,7 +35,9 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
|||||||
# Sometimes a model file contains multiple models,
|
# Sometimes a model file contains multiple models,
|
||||||
# and DiskMap can load only the parameters of a single model,
|
# and DiskMap can load only the parameters of a single model,
|
||||||
# avoiding the need to load all parameters in the file.
|
# avoiding the need to load all parameters in the file.
|
||||||
if use_disk_map:
|
if state_dict is not None:
|
||||||
|
pass
|
||||||
|
elif use_disk_map:
|
||||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||||
else:
|
else:
|
||||||
state_dict = load_state_dict(path, torch_dtype, device)
|
state_dict = load_state_dict(path, torch_dtype, device)
|
||||||
|
|||||||
@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
|
|||||||
vram_config=vram_config,
|
vram_config=vram_config,
|
||||||
vram_limit=vram_limit,
|
vram_limit=vram_limit,
|
||||||
clear_parameters=model_config.clear_parameters,
|
clear_parameters=model_config.clear_parameters,
|
||||||
|
state_dict=model_config.state_dict,
|
||||||
)
|
)
|
||||||
return model_pool
|
return model_pool
|
||||||
|
|
||||||
@@ -317,6 +318,13 @@ class BasePipeline(torch.nn.Module):
|
|||||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||||
self.clear_lora(verbose=0)
|
self.clear_lora(verbose=0)
|
||||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||||
|
if isinstance(noise_pred_posi, tuple):
|
||||||
|
# Separately handling different output types of latents, eg. video and audio latents.
|
||||||
|
noise_pred = tuple(
|
||||||
|
n_nega + cfg_scale * (n_posi - n_nega)
|
||||||
|
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||||
|
)
|
||||||
|
else:
|
||||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||||
else:
|
else:
|
||||||
noise_pred = noise_pred_posi
|
noise_pred = noise_pred_posi
|
||||||
|
|||||||
@@ -4,13 +4,15 @@ from typing_extensions import Literal
|
|||||||
|
|
||||||
class FlowMatchScheduler():
|
class FlowMatchScheduler():
|
||||||
|
|
||||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
|
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||||
self.set_timesteps_fn = {
|
self.set_timesteps_fn = {
|
||||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||||
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||||
|
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||||
|
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||||
self.num_train_timesteps = 1000
|
self.num_train_timesteps = 1000
|
||||||
|
|
||||||
@@ -70,6 +72,28 @@ class FlowMatchScheduler():
|
|||||||
timesteps = sigmas * num_train_timesteps
|
timesteps = sigmas * num_train_timesteps
|
||||||
return sigmas, timesteps
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
base_shift = math.log(3)
|
||||||
|
max_shift = math.log(3)
|
||||||
|
# Sigmas
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
# Mu
|
||||||
|
if exponential_shift_mu is not None:
|
||||||
|
mu = exponential_shift_mu
|
||||||
|
elif dynamic_shift_len is not None:
|
||||||
|
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
|
||||||
|
else:
|
||||||
|
mu = 0.8
|
||||||
|
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||||
|
# Timesteps
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def compute_empirical_mu(image_seq_len, num_steps):
|
def compute_empirical_mu(image_seq_len, num_steps):
|
||||||
a1, b1 = 8.73809524e-05, 1.89833333
|
a1, b1 = 8.73809524e-05, 1.89833333
|
||||||
@@ -122,6 +146,34 @@ class FlowMatchScheduler():
|
|||||||
timesteps[timestep_id] = timestep
|
timesteps[timestep_id] = timestep
|
||||||
return sigmas, timesteps
|
return sigmas, timesteps
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||||
|
num_train_timesteps = 1000
|
||||||
|
if special_case == "stage2":
|
||||||
|
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
|
||||||
|
elif special_case == "ditilled_stage1":
|
||||||
|
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
|
||||||
|
else:
|
||||||
|
dynamic_shift_len = dynamic_shift_len or 4096
|
||||||
|
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
|
||||||
|
image_seq_len=dynamic_shift_len,
|
||||||
|
base_seq_len=1024,
|
||||||
|
max_seq_len=4096,
|
||||||
|
base_shift=0.95,
|
||||||
|
max_shift=2.05,
|
||||||
|
)
|
||||||
|
sigma_min = 0.0
|
||||||
|
sigma_max = 1.0
|
||||||
|
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||||
|
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||||
|
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
|
||||||
|
# Shift terminal
|
||||||
|
one_minus_z = 1.0 - sigmas
|
||||||
|
scale_factor = one_minus_z[-1] / (1 - terminal)
|
||||||
|
sigmas = 1.0 - (one_minus_z / scale_factor)
|
||||||
|
timesteps = sigmas * num_train_timesteps
|
||||||
|
return sigmas, timesteps
|
||||||
|
|
||||||
def set_training_weight(self):
|
def set_training_weight(self):
|
||||||
steps = 1000
|
steps = 1000
|
||||||
x = self.timesteps
|
x = self.timesteps
|
||||||
|
|||||||
@@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
|||||||
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||||
|
|
||||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||||
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||||
|
|
||||||
|
if "first_frame_latents" in inputs:
|
||||||
|
noise_pred = noise_pred[:, :, 1:]
|
||||||
|
training_target = training_target[:, :, 1:]
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||||
return loss
|
return loss
|
||||||
|
|||||||
1351
diffsynth/models/ltx2_audio_vae.py
Normal file
1351
diffsynth/models/ltx2_audio_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
371
diffsynth/models/ltx2_common.py
Normal file
371
diffsynth/models/ltx2_common.py
Normal file
@@ -0,0 +1,371 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import NamedTuple, Protocol, Tuple
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class VideoPixelShape(NamedTuple):
|
||||||
|
"""
|
||||||
|
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
frames: int
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
fps: float
|
||||||
|
|
||||||
|
|
||||||
|
class SpatioTemporalScaleFactors(NamedTuple):
|
||||||
|
"""
|
||||||
|
Describes the spatiotemporal downscaling between decoded video space and
|
||||||
|
the corresponding VAE latent grid.
|
||||||
|
"""
|
||||||
|
|
||||||
|
time: int
|
||||||
|
width: int
|
||||||
|
height: int
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> "SpatioTemporalScaleFactors":
|
||||||
|
return cls(time=8, width=32, height=32)
|
||||||
|
|
||||||
|
|
||||||
|
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
|
||||||
|
|
||||||
|
|
||||||
|
class VideoLatentShape(NamedTuple):
|
||||||
|
"""
|
||||||
|
Shape of the tensor representing video in VAE latent space.
|
||||||
|
The latent representation is a 5D tensor with dimensions ordered as
|
||||||
|
(batch, channels, frames, height, width). Spatial and temporal dimensions
|
||||||
|
are downscaled relative to pixel space according to the VAE's scale factors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
channels: int
|
||||||
|
frames: int
|
||||||
|
height: int
|
||||||
|
width: int
|
||||||
|
|
||||||
|
def to_torch_shape(self) -> torch.Size:
|
||||||
|
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
|
||||||
|
return VideoLatentShape(
|
||||||
|
batch=shape[0],
|
||||||
|
channels=shape[1],
|
||||||
|
frames=shape[2],
|
||||||
|
height=shape[3],
|
||||||
|
width=shape[4],
|
||||||
|
)
|
||||||
|
|
||||||
|
def mask_shape(self) -> "VideoLatentShape":
|
||||||
|
return self._replace(channels=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pixel_shape(
|
||||||
|
shape: VideoPixelShape,
|
||||||
|
latent_channels: int = 128,
|
||||||
|
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
|
||||||
|
) -> "VideoLatentShape":
|
||||||
|
frames = (shape.frames - 1) // scale_factors[0] + 1
|
||||||
|
height = shape.height // scale_factors[1]
|
||||||
|
width = shape.width // scale_factors[2]
|
||||||
|
|
||||||
|
return VideoLatentShape(
|
||||||
|
batch=shape.batch,
|
||||||
|
channels=latent_channels,
|
||||||
|
frames=frames,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
)
|
||||||
|
|
||||||
|
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
|
||||||
|
return self._replace(
|
||||||
|
channels=3,
|
||||||
|
frames=(self.frames - 1) * scale_factors.time + 1,
|
||||||
|
height=self.height * scale_factors.height,
|
||||||
|
width=self.width * scale_factors.width,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioLatentShape(NamedTuple):
|
||||||
|
"""
|
||||||
|
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
|
||||||
|
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch: int
|
||||||
|
channels: int
|
||||||
|
frames: int
|
||||||
|
mel_bins: int
|
||||||
|
|
||||||
|
def to_torch_shape(self) -> torch.Size:
|
||||||
|
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
|
||||||
|
|
||||||
|
def mask_shape(self) -> "AudioLatentShape":
|
||||||
|
return self._replace(channels=1, mel_bins=1)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
|
||||||
|
return AudioLatentShape(
|
||||||
|
batch=shape[0],
|
||||||
|
channels=shape[1],
|
||||||
|
frames=shape[2],
|
||||||
|
mel_bins=shape[3],
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_duration(
|
||||||
|
batch: int,
|
||||||
|
duration: float,
|
||||||
|
channels: int = 8,
|
||||||
|
mel_bins: int = 16,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
hop_length: int = 160,
|
||||||
|
audio_latent_downsample_factor: int = 4,
|
||||||
|
) -> "AudioLatentShape":
|
||||||
|
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
|
||||||
|
|
||||||
|
return AudioLatentShape(
|
||||||
|
batch=batch,
|
||||||
|
channels=channels,
|
||||||
|
frames=round(duration * latents_per_second),
|
||||||
|
mel_bins=mel_bins,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_video_pixel_shape(
|
||||||
|
shape: VideoPixelShape,
|
||||||
|
channels: int = 8,
|
||||||
|
mel_bins: int = 16,
|
||||||
|
sample_rate: int = 16000,
|
||||||
|
hop_length: int = 160,
|
||||||
|
audio_latent_downsample_factor: int = 4,
|
||||||
|
) -> "AudioLatentShape":
|
||||||
|
return AudioLatentShape.from_duration(
|
||||||
|
batch=shape.batch,
|
||||||
|
duration=float(shape.frames) / float(shape.fps),
|
||||||
|
channels=channels,
|
||||||
|
mel_bins=mel_bins,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
hop_length=hop_length,
|
||||||
|
audio_latent_downsample_factor=audio_latent_downsample_factor,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LatentState:
|
||||||
|
"""
|
||||||
|
State of latents during the diffusion denoising process.
|
||||||
|
Attributes:
|
||||||
|
latent: The current noisy latent tensor being denoised.
|
||||||
|
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
|
||||||
|
positions: Positional indices for each latent element, used for positional embeddings.
|
||||||
|
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
latent: torch.Tensor
|
||||||
|
denoise_mask: torch.Tensor
|
||||||
|
positions: torch.Tensor
|
||||||
|
clean_latent: torch.Tensor
|
||||||
|
|
||||||
|
def clone(self) -> "LatentState":
|
||||||
|
return LatentState(
|
||||||
|
latent=self.latent.clone(),
|
||||||
|
denoise_mask=self.denoise_mask.clone(),
|
||||||
|
positions=self.positions.clone(),
|
||||||
|
clean_latent=self.clean_latent.clone(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NormType(Enum):
|
||||||
|
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||||
|
|
||||||
|
GROUP = "group"
|
||||||
|
PIXEL = "pixel"
|
||||||
|
|
||||||
|
|
||||||
|
class PixelNorm(nn.Module):
|
||||||
|
"""
|
||||||
|
Per-pixel (per-location) RMS normalization layer.
|
||||||
|
For each element along the chosen dimension, this layer normalizes the tensor
|
||||||
|
by the root-mean-square of its values across that dimension:
|
||||||
|
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
dim: Dimension along which to compute the RMS (typically channels).
|
||||||
|
eps: Small constant added for numerical stability.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply RMS normalization along the configured dimension.
|
||||||
|
"""
|
||||||
|
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
|
||||||
|
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||||
|
# Normalize by the root-mean-square (RMS).
|
||||||
|
rms = torch.sqrt(mean_sq + self.eps)
|
||||||
|
return x / rms
|
||||||
|
|
||||||
|
|
||||||
|
def build_normalization_layer(
|
||||||
|
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||||
|
) -> nn.Module:
|
||||||
|
"""
|
||||||
|
Create a normalization layer based on the normalization type.
|
||||||
|
Args:
|
||||||
|
in_channels: Number of input channels
|
||||||
|
num_groups: Number of groups for group normalization
|
||||||
|
normtype: Type of normalization: "group" or "pixel"
|
||||||
|
Returns:
|
||||||
|
A normalization layer
|
||||||
|
"""
|
||||||
|
if normtype == NormType.GROUP:
|
||||||
|
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||||
|
if normtype == NormType.PIXEL:
|
||||||
|
return PixelNorm(dim=1, eps=1e-6)
|
||||||
|
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
|
||||||
|
"""Root-mean-square (RMS) normalize `x` over its last dimension.
|
||||||
|
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
|
||||||
|
shape and forwards `weight` and `eps`.
|
||||||
|
"""
|
||||||
|
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Modality:
|
||||||
|
"""
|
||||||
|
Input data for a single modality (video or audio) in the transformer.
|
||||||
|
Bundles the latent tokens, timestep embeddings, positional information,
|
||||||
|
and text conditioning context for processing by the diffusion transformer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
latent: (
|
||||||
|
torch.Tensor
|
||||||
|
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||||
|
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||||
|
positions: (
|
||||||
|
torch.Tensor
|
||||||
|
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
|
||||||
|
context: torch.Tensor
|
||||||
|
enabled: bool = True
|
||||||
|
context_mask: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def to_denoised(
|
||||||
|
sample: torch.Tensor,
|
||||||
|
velocity: torch.Tensor,
|
||||||
|
sigma: float | torch.Tensor,
|
||||||
|
calc_dtype: torch.dtype = torch.float32,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Convert the sample and its denoising velocity to denoised sample.
|
||||||
|
Returns:
|
||||||
|
Denoised sample
|
||||||
|
"""
|
||||||
|
if isinstance(sigma, torch.Tensor):
|
||||||
|
sigma = sigma.to(calc_dtype)
|
||||||
|
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class Patchifier(Protocol):
|
||||||
|
"""
|
||||||
|
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def patchify(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
Convert latent tensors into flattened patch tokens.
|
||||||
|
Args:
|
||||||
|
latents: Latent tensor to patchify.
|
||||||
|
Returns:
|
||||||
|
Flattened patch tokens tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def unpatchify(
|
||||||
|
self,
|
||||||
|
latents: torch.Tensor,
|
||||||
|
output_shape: AudioLatentShape | VideoLatentShape,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
|
||||||
|
Args:
|
||||||
|
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
|
||||||
|
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
|
||||||
|
VideoLatentShape.
|
||||||
|
Returns:
|
||||||
|
Dense latent tensor restored from the flattened representation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def patch_size(self) -> Tuple[int, int, int]:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
Returns the patch size as a tuple of (temporal, height, width) dimensions
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_patch_grid_bounds(
|
||||||
|
self,
|
||||||
|
output_shape: AudioLatentShape | VideoLatentShape,
|
||||||
|
device: torch.device | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
Compute metadata describing where each latent patch resides within the
|
||||||
|
grid specified by `output_shape`.
|
||||||
|
Args:
|
||||||
|
output_shape: Target grid layout for the patches.
|
||||||
|
device: Target device for the returned tensor.
|
||||||
|
Returns:
|
||||||
|
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_pixel_coords(
|
||||||
|
latent_coords: torch.Tensor,
|
||||||
|
scale_factors: SpatioTemporalScaleFactors,
|
||||||
|
causal_fix: bool = False,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
|
||||||
|
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
|
||||||
|
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
|
||||||
|
Args:
|
||||||
|
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
|
||||||
|
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
|
||||||
|
per axis.
|
||||||
|
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
|
||||||
|
that treat frame zero differently still yield non-negative timestamps.
|
||||||
|
"""
|
||||||
|
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
|
||||||
|
broadcast_shape = [1] * latent_coords.ndim
|
||||||
|
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
|
||||||
|
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
|
||||||
|
|
||||||
|
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
|
||||||
|
pixel_coords = latent_coords * scale_tensor
|
||||||
|
|
||||||
|
if causal_fix:
|
||||||
|
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
|
||||||
|
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
|
||||||
|
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||||
|
|
||||||
|
return pixel_coords
|
||||||
1451
diffsynth/models/ltx2_dit.py
Normal file
1451
diffsynth/models/ltx2_dit.py
Normal file
File diff suppressed because it is too large
Load Diff
366
diffsynth/models/ltx2_text_encoder.py
Normal file
366
diffsynth/models/ltx2_text_encoder.py
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
import torch
|
||||||
|
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||||
|
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||||
|
FeedForward)
|
||||||
|
from .ltx2_common import rms_norm
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
|
||||||
|
def __init__(self):
|
||||||
|
config = Gemma3Config(
|
||||||
|
**{
|
||||||
|
"architectures": ["Gemma3ForConditionalGeneration"],
|
||||||
|
"boi_token_index": 255999,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"eoi_token_index": 256000,
|
||||||
|
"eos_token_id": [1, 106],
|
||||||
|
"image_token_index": 262144,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"mm_tokens_per_image": 256,
|
||||||
|
"model_type": "gemma3",
|
||||||
|
"text_config": {
|
||||||
|
"_sliding_window_pattern": 6,
|
||||||
|
"attention_bias": False,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"attn_logit_softcapping": None,
|
||||||
|
"cache_implementation": "hybrid",
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"final_logit_softcapping": None,
|
||||||
|
"head_dim": 256,
|
||||||
|
"hidden_activation": "gelu_pytorch_tanh",
|
||||||
|
"hidden_size": 3840,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 15360,
|
||||||
|
"layer_types": [
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||||
|
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
|
||||||
|
],
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"model_type": "gemma3_text",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 48,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"query_pre_attn_scalar": 256,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_local_base_freq": 10000,
|
||||||
|
"rope_scaling": {
|
||||||
|
"factor": 8.0,
|
||||||
|
"rope_type": "linear"
|
||||||
|
},
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"sliding_window": 1024,
|
||||||
|
"sliding_window_pattern": 6,
|
||||||
|
"use_bidirectional_attention": False,
|
||||||
|
"use_cache": True,
|
||||||
|
"vocab_size": 262208
|
||||||
|
},
|
||||||
|
"transformers_version": "4.57.3",
|
||||||
|
"vision_config": {
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"dtype": "bfloat16",
|
||||||
|
"hidden_act": "gelu_pytorch_tanh",
|
||||||
|
"hidden_size": 1152,
|
||||||
|
"image_size": 896,
|
||||||
|
"intermediate_size": 4304,
|
||||||
|
"layer_norm_eps": 1e-06,
|
||||||
|
"model_type": "siglip_vision_model",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_channels": 3,
|
||||||
|
"num_hidden_layers": 27,
|
||||||
|
"patch_size": 14,
|
||||||
|
"vision_use_head": False
|
||||||
|
}
|
||||||
|
})
|
||||||
|
super().__init__(config)
|
||||||
|
|
||||||
|
|
||||||
|
class LTXVGemmaTokenizer:
|
||||||
|
"""
|
||||||
|
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
||||||
|
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
||||||
|
ensuring correct settings and output formatting for downstream consumption.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, tokenizer_path: str, max_length: int = 1024):
|
||||||
|
"""
|
||||||
|
Initialize the tokenizer.
|
||||||
|
Args:
|
||||||
|
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
||||||
|
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
||||||
|
"""
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
tokenizer_path, local_files_only=True, model_max_length=max_length
|
||||||
|
)
|
||||||
|
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
||||||
|
self.tokenizer.padding_side = "left"
|
||||||
|
if self.tokenizer.pad_token is None:
|
||||||
|
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||||
|
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
||||||
|
"""
|
||||||
|
Tokenize the given text and return token IDs and attention weights.
|
||||||
|
Args:
|
||||||
|
text (str): The input string to tokenize.
|
||||||
|
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
||||||
|
If False (default), omits the indices.
|
||||||
|
Returns:
|
||||||
|
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
||||||
|
A dictionary with a "gemma" key mapping to:
|
||||||
|
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
||||||
|
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
||||||
|
Example:
|
||||||
|
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
||||||
|
>>> tokenizer.tokenize_with_weights("hello world")
|
||||||
|
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
||||||
|
"""
|
||||||
|
text = text.strip()
|
||||||
|
encoded = self.tokenizer(
|
||||||
|
text,
|
||||||
|
padding="max_length",
|
||||||
|
max_length=self.max_length,
|
||||||
|
truncation=True,
|
||||||
|
return_tensors="pt",
|
||||||
|
)
|
||||||
|
input_ids = encoded.input_ids
|
||||||
|
attention_mask = encoded.attention_mask
|
||||||
|
tuples = [
|
||||||
|
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
||||||
|
]
|
||||||
|
out = {"gemma": tuples}
|
||||||
|
|
||||||
|
if not return_word_ids:
|
||||||
|
# Return only (token_id, attention_mask) pairs, omitting token position
|
||||||
|
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Feature extractor module for Gemma models.
|
||||||
|
This module applies a single linear projection to the input tensor.
|
||||||
|
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||||
|
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||||
|
Attributes:
|
||||||
|
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""
|
||||||
|
Initialize the GemmaFeaturesExtractorProjLinear module.
|
||||||
|
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Forward pass for the feature extractor.
|
||||||
|
Args:
|
||||||
|
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||||
|
"""
|
||||||
|
return self.aggregate_embed(x)
|
||||||
|
|
||||||
|
|
||||||
|
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim: int,
|
||||||
|
heads: int,
|
||||||
|
dim_head: int,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.attn1 = Attention(
|
||||||
|
query_dim=dim,
|
||||||
|
heads=heads,
|
||||||
|
dim_head=dim_head,
|
||||||
|
rope_type=rope_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ff = FeedForward(
|
||||||
|
dim,
|
||||||
|
dim_out=dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
pe: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||||
|
|
||||||
|
# 1. Normalization Before Self-Attention
|
||||||
|
norm_hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 2. Self-Attention
|
||||||
|
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||||
|
|
||||||
|
hidden_states = attn_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
# 3. Normalization before Feed-Forward
|
||||||
|
norm_hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
# 4. Feed-forward
|
||||||
|
ff_output = self.ff(norm_hidden_states)
|
||||||
|
|
||||||
|
hidden_states = ff_output + hidden_states
|
||||||
|
if hidden_states.ndim == 4:
|
||||||
|
hidden_states = hidden_states.squeeze(1)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
class Embeddings1DConnector(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||||
|
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||||
|
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
|
||||||
|
layers, and register usage.
|
||||||
|
Args:
|
||||||
|
attention_head_dim (int): Dimension of each attention head (default=128).
|
||||||
|
num_attention_heads (int): Number of attention heads (default=30).
|
||||||
|
num_layers (int): Number of transformer layers (default=2).
|
||||||
|
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
|
||||||
|
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
|
||||||
|
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
|
||||||
|
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
|
||||||
|
register replacement. (default=128)
|
||||||
|
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
|
||||||
|
double_precision_rope (bool): Use double precision rope calculation (default=False).
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
attention_head_dim: int = 128,
|
||||||
|
num_attention_heads: int = 30,
|
||||||
|
num_layers: int = 2,
|
||||||
|
positional_embedding_theta: float = 10000.0,
|
||||||
|
positional_embedding_max_pos: list[int] | None = [4096],
|
||||||
|
causal_temporal_positioning: bool = False,
|
||||||
|
num_learnable_registers: int | None = 128,
|
||||||
|
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||||
|
double_precision_rope: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_attention_heads = num_attention_heads
|
||||||
|
self.inner_dim = num_attention_heads * attention_head_dim
|
||||||
|
self.causal_temporal_positioning = causal_temporal_positioning
|
||||||
|
self.positional_embedding_theta = positional_embedding_theta
|
||||||
|
self.positional_embedding_max_pos = (
|
||||||
|
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
|
||||||
|
)
|
||||||
|
self.rope_type = rope_type
|
||||||
|
self.double_precision_rope = double_precision_rope
|
||||||
|
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||||
|
[
|
||||||
|
_BasicTransformerBlock1D(
|
||||||
|
dim=self.inner_dim,
|
||||||
|
heads=num_attention_heads,
|
||||||
|
dim_head=attention_head_dim,
|
||||||
|
rope_type=rope_type,
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_learnable_registers = num_learnable_registers
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
self.learnable_registers = torch.nn.Parameter(
|
||||||
|
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def _replace_padded_with_learnable_registers(
|
||||||
|
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
|
||||||
|
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
|
||||||
|
f"{self.num_learnable_registers}."
|
||||||
|
)
|
||||||
|
|
||||||
|
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
|
||||||
|
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
|
||||||
|
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
|
||||||
|
|
||||||
|
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||||
|
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||||
|
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||||
|
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||||
|
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||||
|
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||||
|
|
||||||
|
attention_mask = torch.full_like(
|
||||||
|
attention_mask,
|
||||||
|
0.0,
|
||||||
|
dtype=attention_mask.dtype,
|
||||||
|
device=attention_mask.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Forward pass of Embeddings1DConnector.
|
||||||
|
Args:
|
||||||
|
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
|
||||||
|
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
|
||||||
|
"""
|
||||||
|
if self.num_learnable_registers:
|
||||||
|
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
|
||||||
|
|
||||||
|
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
|
||||||
|
indices_grid = indices_grid[None, None, :]
|
||||||
|
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
|
||||||
|
freqs_cis = precompute_freqs_cis(
|
||||||
|
indices_grid=indices_grid,
|
||||||
|
dim=self.inner_dim,
|
||||||
|
out_dtype=hidden_states.dtype,
|
||||||
|
theta=self.positional_embedding_theta,
|
||||||
|
max_pos=self.positional_embedding_max_pos,
|
||||||
|
num_attention_heads=self.num_attention_heads,
|
||||||
|
rope_type=self.rope_type,
|
||||||
|
freq_grid_generator=freq_grid_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
for block in self.transformer_1d_blocks:
|
||||||
|
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
|
||||||
|
|
||||||
|
hidden_states = rms_norm(hidden_states)
|
||||||
|
|
||||||
|
return hidden_states, attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2TextEncoderPostModules(torch.nn.Module):
|
||||||
|
def __init__(self,):
|
||||||
|
super().__init__()
|
||||||
|
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||||
|
self.embeddings_connector = Embeddings1DConnector()
|
||||||
|
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||||
313
diffsynth/models/ltx2_upsampler.py
Normal file
313
diffsynth/models/ltx2_upsampler.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
import math
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from .ltx2_video_vae import LTX2VideoEncoder
|
||||||
|
|
||||||
|
class PixelShuffleND(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
N-dimensional pixel shuffle operation for upsampling tensors.
|
||||||
|
Args:
|
||||||
|
dims (int): Number of dimensions to apply pixel shuffle to.
|
||||||
|
- 1: Temporal (e.g., frames)
|
||||||
|
- 2: Spatial (e.g., height and width)
|
||||||
|
- 3: Spatiotemporal (e.g., depth, height, width)
|
||||||
|
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
|
||||||
|
For dims=1, only the first value is used.
|
||||||
|
For dims=2, the first two values are used.
|
||||||
|
For dims=3, all three values are used.
|
||||||
|
The input tensor is rearranged so that the channel dimension is split into
|
||||||
|
smaller channels and upscaling factors, and the upscaling factors are moved
|
||||||
|
into the corresponding spatial/temporal dimensions.
|
||||||
|
Note:
|
||||||
|
This operation is equivalent to the patchifier operation in for the models. Consider
|
||||||
|
using this class instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
|
||||||
|
super().__init__()
|
||||||
|
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||||
|
self.dims = dims
|
||||||
|
self.upscale_factors = upscale_factors
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.dims == 3:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
p3=self.upscale_factors[2],
|
||||||
|
)
|
||||||
|
elif self.dims == 2:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
p2=self.upscale_factors[1],
|
||||||
|
)
|
||||||
|
elif self.dims == 1:
|
||||||
|
return rearrange(
|
||||||
|
x,
|
||||||
|
"b (c p1) f h w -> b c (f p1) h w",
|
||||||
|
p1=self.upscale_factors[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported dims: {self.dims}")
|
||||||
|
|
||||||
|
|
||||||
|
class ResBlock(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Residual block with two convolutional layers, group normalization, and SiLU activation.
|
||||||
|
Args:
|
||||||
|
channels (int): Number of input and output channels.
|
||||||
|
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
|
||||||
|
if not specified.
|
||||||
|
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||||
|
super().__init__()
|
||||||
|
if mid_channels is None:
|
||||||
|
mid_channels = channels
|
||||||
|
|
||||||
|
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||||
|
|
||||||
|
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||||
|
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||||
|
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||||
|
self.activation = torch.nn.SiLU()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
residual = x
|
||||||
|
x = self.conv1(x)
|
||||||
|
x = self.norm1(x)
|
||||||
|
x = self.activation(x)
|
||||||
|
x = self.conv2(x)
|
||||||
|
x = self.norm2(x)
|
||||||
|
x = self.activation(x + residual)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BlurDownsample(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||||
|
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||||
|
super().__init__()
|
||||||
|
assert dims in (2, 3)
|
||||||
|
assert isinstance(stride, int)
|
||||||
|
assert stride >= 1
|
||||||
|
assert kernel_size >= 3
|
||||||
|
assert kernel_size % 2 == 1
|
||||||
|
self.dims = dims
|
||||||
|
self.stride = stride
|
||||||
|
self.kernel_size = kernel_size
|
||||||
|
|
||||||
|
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||||
|
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||||
|
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||||
|
# The 2D kernel is constructed as the outer product and normalized.
|
||||||
|
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||||
|
k2d = k[:, None] @ k[None, :]
|
||||||
|
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||||
|
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if self.stride == 1:
|
||||||
|
return x
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
return self._apply_2d(x)
|
||||||
|
else:
|
||||||
|
# dims == 3: apply per-frame on H,W
|
||||||
|
b, _, f, _, _ = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self._apply_2d(x)
|
||||||
|
h2, w2 = x.shape[-2:]
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
|
||||||
|
c = x2d.shape[1]
|
||||||
|
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||||
|
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||||
|
return x2d
|
||||||
|
|
||||||
|
|
||||||
|
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||||
|
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||||
|
if float(scale) not in mapping:
|
||||||
|
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
|
||||||
|
return mapping[float(scale)]
|
||||||
|
|
||||||
|
|
||||||
|
class SpatialRationalResampler(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||||
|
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||||
|
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||||
|
Args:
|
||||||
|
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
||||||
|
scale (`float`): Spatial scaling factor. Supported values are:
|
||||||
|
- 0.75: Downsample by 3/4 (reduce spatial size)
|
||||||
|
- 1.5: Upsample by 3/2 (increase spatial size)
|
||||||
|
- 2.0: Upsample by 2x (double spatial size)
|
||||||
|
- 4.0: Upsample by 4x (quadruple spatial size)
|
||||||
|
Any other value will raise a ValueError.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, mid_channels: int, scale: float):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num, self.den = _rational_for_scale(self.scale)
|
||||||
|
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||||
|
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, _, f, _, _ = x.shape
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.pixel_shuffle(x)
|
||||||
|
x = self.blur_down(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2LatentUpsampler(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Model to upsample VAE latents spatially and/or temporally.
|
||||||
|
Args:
|
||||||
|
in_channels (`int`): Number of channels in the input latent
|
||||||
|
mid_channels (`int`): Number of channels in the middle layers
|
||||||
|
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||||
|
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||||
|
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||||
|
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||||
|
spatial_scale (`float`): Scale factor for spatial upsampling
|
||||||
|
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
|
||||||
|
"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 128,
|
||||||
|
mid_channels: int = 1024,
|
||||||
|
num_blocks_per_stage: int = 4,
|
||||||
|
dims: int = 3,
|
||||||
|
spatial_upsample: bool = True,
|
||||||
|
temporal_upsample: bool = False,
|
||||||
|
spatial_scale: float = 2.0,
|
||||||
|
rational_resampler: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.mid_channels = mid_channels
|
||||||
|
self.num_blocks_per_stage = num_blocks_per_stage
|
||||||
|
self.dims = dims
|
||||||
|
self.spatial_upsample = spatial_upsample
|
||||||
|
self.temporal_upsample = temporal_upsample
|
||||||
|
self.spatial_scale = float(spatial_scale)
|
||||||
|
self.rational_resampler = rational_resampler
|
||||||
|
|
||||||
|
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||||
|
|
||||||
|
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||||
|
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||||
|
self.initial_activation = torch.nn.SiLU()
|
||||||
|
|
||||||
|
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||||
|
|
||||||
|
if spatial_upsample and temporal_upsample:
|
||||||
|
self.upsampler = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(3),
|
||||||
|
)
|
||||||
|
elif spatial_upsample:
|
||||||
|
if rational_resampler:
|
||||||
|
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
|
||||||
|
else:
|
||||||
|
self.upsampler = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(2),
|
||||||
|
)
|
||||||
|
elif temporal_upsample:
|
||||||
|
self.upsampler = torch.nn.Sequential(
|
||||||
|
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||||
|
PixelShuffleND(1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||||
|
|
||||||
|
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||||
|
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||||
|
b, _, f, _, _ = latent.shape
|
||||||
|
|
||||||
|
if self.dims == 2:
|
||||||
|
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.initial_conv(x)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.upsampler(x)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
else:
|
||||||
|
x = self.initial_conv(latent)
|
||||||
|
x = self.initial_norm(x)
|
||||||
|
x = self.initial_activation(x)
|
||||||
|
|
||||||
|
for block in self.res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
if self.temporal_upsample:
|
||||||
|
x = self.upsampler(x)
|
||||||
|
# remove the first frame after upsampling.
|
||||||
|
# This is done because the first frame encodes one pixel frame.
|
||||||
|
x = x[:, :, 1:, :, :]
|
||||||
|
elif isinstance(self.upsampler, SpatialRationalResampler):
|
||||||
|
x = self.upsampler(x)
|
||||||
|
else:
|
||||||
|
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||||
|
x = self.upsampler(x)
|
||||||
|
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||||
|
|
||||||
|
for block in self.post_upsample_res_blocks:
|
||||||
|
x = block(x)
|
||||||
|
|
||||||
|
x = self.final_conv(x)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply upsampling to the latent representation using the provided upsampler,
|
||||||
|
with normalization and un-normalization based on the video encoder's per-channel statistics.
|
||||||
|
Args:
|
||||||
|
latent: Input latent tensor of shape [B, C, F, H, W].
|
||||||
|
video_encoder: VideoEncoder with per_channel_statistics for normalization.
|
||||||
|
upsampler: LTX2LatentUpsampler module to perform upsampling.
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: Upsampled and re-normalized latent tensor.
|
||||||
|
"""
|
||||||
|
latent = video_encoder.per_channel_statistics.un_normalize(latent)
|
||||||
|
latent = upsampler(latent)
|
||||||
|
latent = video_encoder.per_channel_statistics.normalize(latent)
|
||||||
|
return latent
|
||||||
2317
diffsynth/models/ltx2_video_vae.py
Normal file
2317
diffsynth/models/ltx2_video_vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ class ModelPool:
|
|||||||
module_map = None
|
module_map = None
|
||||||
return module_map
|
return module_map
|
||||||
|
|
||||||
def load_model_file(self, config, path, vram_config, vram_limit=None):
|
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
|
||||||
model_class = self.import_model_class(config["model_class"])
|
model_class = self.import_model_class(config["model_class"])
|
||||||
model_config = config.get("extra_kwargs", {})
|
model_config = config.get("extra_kwargs", {})
|
||||||
if "state_dict_converter" in config:
|
if "state_dict_converter" in config:
|
||||||
@@ -43,6 +43,7 @@ class ModelPool:
|
|||||||
state_dict_converter,
|
state_dict_converter,
|
||||||
use_disk_map=True,
|
use_disk_map=True,
|
||||||
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||||
|
state_dict=state_dict,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -59,7 +60,7 @@ class ModelPool:
|
|||||||
}
|
}
|
||||||
return vram_config
|
return vram_config
|
||||||
|
|
||||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
|
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
|
||||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||||
if vram_config is None:
|
if vram_config is None:
|
||||||
vram_config = self.default_vram_config()
|
vram_config = self.default_vram_config()
|
||||||
@@ -67,7 +68,7 @@ class ModelPool:
|
|||||||
loaded = False
|
loaded = False
|
||||||
for config in MODEL_CONFIGS:
|
for config in MODEL_CONFIGS:
|
||||||
if config["model_hash"] == model_hash:
|
if config["model_hash"] == model_hash:
|
||||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
|
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
|
||||||
if clear_parameters: self.clear_parameters(model)
|
if clear_parameters: self.clear_parameters(model)
|
||||||
self.model.append(model)
|
self.model.append(model)
|
||||||
model_name = config["model_name"]
|
model_name = config["model_name"]
|
||||||
|
|||||||
@@ -93,7 +93,7 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||||
freqs = freqs.to(torch.complex64) if freqs.device == "npu" else freqs
|
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
|
||||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,36 @@ class ZImageTextEncoder(torch.nn.Module):
|
|||||||
def __init__(self, model_size="4B"):
|
def __init__(self, model_size="4B"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config_dict = {
|
config_dict = {
|
||||||
|
"0.6B": Qwen3Config(**{
|
||||||
|
"architectures": [
|
||||||
|
"Qwen3ForCausalLM"
|
||||||
|
],
|
||||||
|
"attention_bias": False,
|
||||||
|
"attention_dropout": 0.0,
|
||||||
|
"bos_token_id": 151643,
|
||||||
|
"eos_token_id": 151645,
|
||||||
|
"head_dim": 128,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"hidden_size": 1024,
|
||||||
|
"initializer_range": 0.02,
|
||||||
|
"intermediate_size": 3072,
|
||||||
|
"max_position_embeddings": 40960,
|
||||||
|
"max_window_layers": 28,
|
||||||
|
"model_type": "qwen3",
|
||||||
|
"num_attention_heads": 16,
|
||||||
|
"num_hidden_layers": 28,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"rms_norm_eps": 1e-06,
|
||||||
|
"rope_scaling": None,
|
||||||
|
"rope_theta": 1000000,
|
||||||
|
"sliding_window": None,
|
||||||
|
"tie_word_embeddings": True,
|
||||||
|
"torch_dtype": "bfloat16",
|
||||||
|
"transformers_version": "4.51.0",
|
||||||
|
"use_cache": True,
|
||||||
|
"use_sliding_window": False,
|
||||||
|
"vocab_size": 151936
|
||||||
|
}),
|
||||||
"4B": Qwen3Config(**{
|
"4B": Qwen3Config(**{
|
||||||
"architectures": [
|
"architectures": [
|
||||||
"Qwen3ForCausalLM"
|
"Qwen3ForCausalLM"
|
||||||
|
|||||||
550
diffsynth/pipelines/ltx2_audio_video.py
Normal file
550
diffsynth/pipelines/ltx2_audio_video.py
Normal file
@@ -0,0 +1,550 @@
|
|||||||
|
import torch, types
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from einops import repeat
|
||||||
|
from typing import Optional, Union
|
||||||
|
from einops import rearrange
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from typing import Optional
|
||||||
|
from transformers import AutoImageProcessor, Gemma3Processor
|
||||||
|
|
||||||
|
from ..core.device.npu_compatible_device import get_device_type
|
||||||
|
from ..diffusion import FlowMatchScheduler
|
||||||
|
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||||
|
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
|
||||||
|
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||||
|
from ..models.ltx2_dit import LTXModel
|
||||||
|
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||||
|
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
||||||
|
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||||
|
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||||
|
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2AudioVideoPipeline(BasePipeline):
|
||||||
|
|
||||||
|
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=32,
|
||||||
|
width_division_factor=32,
|
||||||
|
time_division_factor=8,
|
||||||
|
time_division_remainder=1,
|
||||||
|
)
|
||||||
|
self.scheduler = FlowMatchScheduler("LTX-2")
|
||||||
|
self.text_encoder: LTX2TextEncoder = None
|
||||||
|
self.tokenizer: LTXVGemmaTokenizer = None
|
||||||
|
self.processor: Gemma3Processor = None
|
||||||
|
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
|
||||||
|
self.dit: LTXModel = None
|
||||||
|
self.video_vae_encoder: LTX2VideoEncoder = None
|
||||||
|
self.video_vae_decoder: LTX2VideoDecoder = None
|
||||||
|
self.audio_vae_encoder: LTX2AudioEncoder = None
|
||||||
|
self.audio_vae_decoder: LTX2AudioDecoder = None
|
||||||
|
self.audio_vocoder: LTX2Vocoder = None
|
||||||
|
self.upsampler: LTX2LatentUpsampler = None
|
||||||
|
|
||||||
|
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
||||||
|
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
||||||
|
|
||||||
|
self.in_iteration_models = ("dit",)
|
||||||
|
self.units = [
|
||||||
|
LTX2AudioVideoUnit_PipelineChecker(),
|
||||||
|
LTX2AudioVideoUnit_ShapeChecker(),
|
||||||
|
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||||
|
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||||
|
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_ltx2
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = get_device_type(),
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config: Optional[ModelConfig] = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
|
# Fetch models
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
|
||||||
|
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
|
||||||
|
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
|
||||||
|
|
||||||
|
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
|
||||||
|
pipe.dit = model_pool.fetch_model("ltx2_dit")
|
||||||
|
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
|
||||||
|
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
|
||||||
|
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
||||||
|
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
||||||
|
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
||||||
|
|
||||||
|
# Stage 2
|
||||||
|
if stage2_lora_config is not None:
|
||||||
|
stage2_lora_config.download_if_necessary()
|
||||||
|
pipe.stage2_lora_path = stage2_lora_config.path
|
||||||
|
# Optional, currently not used
|
||||||
|
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
||||||
|
if inputs_shared["use_two_stage_pipeline"]:
|
||||||
|
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||||
|
self.load_models_to_device('upsampler',)
|
||||||
|
latent = self.upsampler(latent)
|
||||||
|
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
|
||||||
|
self.scheduler.set_timesteps(special_case="stage2")
|
||||||
|
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
||||||
|
denoise_mask_video = 1.0
|
||||||
|
if inputs_shared.get("input_images", None) is not None:
|
||||||
|
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
|
||||||
|
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
|
||||||
|
inputs_shared["input_images_strength"], latent.clone())
|
||||||
|
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
||||||
|
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
||||||
|
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
|
||||||
|
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
||||||
|
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
||||||
|
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
if not inputs_shared["use_distilled_pipeline"]:
|
||||||
|
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
|
||||||
|
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
|
||||||
|
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||||
|
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||||
|
noise_pred=noise_pred_audio, **inputs_shared)
|
||||||
|
return inputs_shared
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: Optional[str] = "",
|
||||||
|
# Image-to-video
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
input_images: Optional[list[Image.Image]] = None,
|
||||||
|
input_images_indexes: Optional[list[int]] = None,
|
||||||
|
input_images_strength: Optional[float] = 1.0,
|
||||||
|
# Randomness
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
rand_device: Optional[str] = "cpu",
|
||||||
|
# Shape
|
||||||
|
height: Optional[int] = 512,
|
||||||
|
width: Optional[int] = 768,
|
||||||
|
num_frames=121,
|
||||||
|
# Classifier-free guidance
|
||||||
|
cfg_scale: Optional[float] = 3.0,
|
||||||
|
cfg_merge: Optional[bool] = False,
|
||||||
|
# Scheduler
|
||||||
|
num_inference_steps: Optional[int] = 40,
|
||||||
|
# VAE tiling
|
||||||
|
tiled: Optional[bool] = True,
|
||||||
|
tile_size_in_pixels: Optional[int] = 512,
|
||||||
|
tile_overlap_in_pixels: Optional[int] = 128,
|
||||||
|
tile_size_in_frames: Optional[int] = 128,
|
||||||
|
tile_overlap_in_frames: Optional[int] = 24,
|
||||||
|
# Special Pipelines
|
||||||
|
use_two_stage_pipeline: Optional[bool] = False,
|
||||||
|
use_distilled_pipeline: Optional[bool] = False,
|
||||||
|
# progress_bar
|
||||||
|
progress_bar_cmd=tqdm,
|
||||||
|
):
|
||||||
|
# Scheduler
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
|
||||||
|
special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
||||||
|
# Inputs
|
||||||
|
inputs_posi = {
|
||||||
|
"prompt": prompt,
|
||||||
|
}
|
||||||
|
inputs_nega = {
|
||||||
|
"negative_prompt": negative_prompt,
|
||||||
|
}
|
||||||
|
inputs_shared = {
|
||||||
|
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"height": height, "width": width, "num_frames": num_frames,
|
||||||
|
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
||||||
|
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||||
|
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||||
|
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
||||||
|
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
||||||
|
}
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
|
# Denoise Stage 1
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
||||||
|
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||||
|
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||||
|
noise_pred=noise_pred_audio, **inputs_shared)
|
||||||
|
|
||||||
|
# Denoise Stage 2
|
||||||
|
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['video_vae_decoder'])
|
||||||
|
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
|
||||||
|
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
||||||
|
video = self.vae_output_to_video(video)
|
||||||
|
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
||||||
|
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
||||||
|
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||||
|
return video, decoded_audio
|
||||||
|
|
||||||
|
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
|
||||||
|
b, _, f, h, w = latents.shape
|
||||||
|
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
|
||||||
|
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
||||||
|
for idx, input_latent in zip(input_indexes, input_latents):
|
||||||
|
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
||||||
|
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
||||||
|
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
||||||
|
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
||||||
|
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
|
||||||
|
return latents, denoise_mask, initial_latents
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
take_over=True,
|
||||||
|
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
|
||||||
|
output_params=("use_two_stage_pipeline", "cfg_scale")
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||||
|
if inputs_shared.get("use_distilled_pipeline", False):
|
||||||
|
inputs_shared["use_two_stage_pipeline"] = True
|
||||||
|
inputs_shared["cfg_scale"] = 1.0
|
||||||
|
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
|
||||||
|
if inputs_shared.get("use_two_stage_pipeline", False):
|
||||||
|
# distill pipeline also uses two-stage, but it does not needs lora
|
||||||
|
if not inputs_shared.get("use_distilled_pipeline", False):
|
||||||
|
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
|
||||||
|
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
|
||||||
|
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
|
||||||
|
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
|
||||||
|
return inputs_shared, inputs_posi, inputs_nega
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
|
||||||
|
"""
|
||||||
|
For two-stage pipelines, the resolution must be divisible by 64.
|
||||||
|
For one-stage pipelines, the resolution must be divisible by 32.
|
||||||
|
"""
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "num_frames"),
|
||||||
|
output_params=("height", "width", "num_frames"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
|
||||||
|
if use_two_stage_pipeline:
|
||||||
|
self.width_division_factor = 64
|
||||||
|
self.height_division_factor = 64
|
||||||
|
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
|
||||||
|
if use_two_stage_pipeline:
|
||||||
|
self.width_division_factor = 32
|
||||||
|
self.height_division_factor = 32
|
||||||
|
return {"height": height, "width": width, "num_frames": num_frames}
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("video_context", "audio_context"),
|
||||||
|
onload_model_names=("text_encoder", "text_encoder_post_modules"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||||
|
return (attention_mask - 1).to(dtype).reshape(
|
||||||
|
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||||
|
|
||||||
|
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
|
||||||
|
|
||||||
|
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
|
||||||
|
encoded_input,
|
||||||
|
connector_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
|
# restore the mask values to int64
|
||||||
|
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
|
||||||
|
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||||
|
encoded = encoded * attention_mask
|
||||||
|
|
||||||
|
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
|
||||||
|
encoded_input, connector_attention_mask)
|
||||||
|
|
||||||
|
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||||
|
|
||||||
|
def _norm_and_concat_padded_batch(
|
||||||
|
self,
|
||||||
|
encoded_text: torch.Tensor,
|
||||||
|
sequence_lengths: torch.Tensor,
|
||||||
|
padding_side: str = "right",
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||||
|
Performs per-batch, per-layer normalization using masked mean and range,
|
||||||
|
then concatenates across the layer dimension.
|
||||||
|
Args:
|
||||||
|
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||||
|
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||||
|
padding_side: Whether padding is on "left" or "right".
|
||||||
|
Returns:
|
||||||
|
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||||
|
with padded positions zeroed out.
|
||||||
|
"""
|
||||||
|
b, t, d, l = encoded_text.shape # noqa: E741
|
||||||
|
device = encoded_text.device
|
||||||
|
# Build mask: [B, T, 1, 1]
|
||||||
|
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||||
|
if padding_side == "right":
|
||||||
|
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||||
|
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||||
|
elif padding_side == "left":
|
||||||
|
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||||
|
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||||
|
mask = token_indices >= start_indices # [B, T]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||||
|
mask = rearrange(mask, "b t -> b t 1 1")
|
||||||
|
eps = 1e-6
|
||||||
|
# Compute masked mean: [B, 1, 1, L]
|
||||||
|
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||||
|
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||||
|
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||||
|
# Compute masked min/max: [B, 1, 1, L]
|
||||||
|
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||||
|
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||||
|
range_ = x_max - x_min
|
||||||
|
# Normalize only the valid tokens
|
||||||
|
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||||
|
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||||
|
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||||
|
# Apply mask to preserve original padding (set padded positions to 0)
|
||||||
|
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||||
|
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||||
|
|
||||||
|
return normed
|
||||||
|
|
||||||
|
def _run_feature_extractor(self,
|
||||||
|
pipe,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
padding_side: str = "right") -> torch.Tensor:
|
||||||
|
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||||
|
encoded_text_features_dtype = encoded_text_features.dtype
|
||||||
|
sequence_lengths = attention_mask.sum(dim=-1)
|
||||||
|
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||||
|
sequence_lengths,
|
||||||
|
padding_side=padding_side)
|
||||||
|
|
||||||
|
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||||
|
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||||
|
|
||||||
|
def _preprocess_text(
|
||||||
|
self,
|
||||||
|
pipe,
|
||||||
|
text: str,
|
||||||
|
padding_side: str = "left",
|
||||||
|
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Encode a given string into feature tensors suitable for downstream tasks.
|
||||||
|
Args:
|
||||||
|
text (str): Input string to encode.
|
||||||
|
Returns:
|
||||||
|
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||||
|
"""
|
||||||
|
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||||
|
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
|
||||||
|
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
|
||||||
|
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||||
|
projected = self._run_feature_extractor(pipe,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
padding_side=padding_side)
|
||||||
|
return projected, attention_mask
|
||||||
|
|
||||||
|
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||||
|
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||||
|
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||||
|
return video_encoding, audio_encoding, attention_mask
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
|
||||||
|
return {"video_context": video_context, "audio_context": audio_context}
|
||||||
|
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
|
||||||
|
output_params=("video_noise", "audio_noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
||||||
|
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
||||||
|
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
|
||||||
|
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||||
|
|
||||||
|
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
||||||
|
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
|
||||||
|
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
|
||||||
|
video_positions = video_positions.to(pipe.torch_dtype)
|
||||||
|
|
||||||
|
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
|
||||||
|
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||||
|
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||||
|
return {
|
||||||
|
"video_noise": video_noise,
|
||||||
|
"audio_noise": audio_noise,
|
||||||
|
"video_positions": video_positions,
|
||||||
|
"audio_positions": audio_positions,
|
||||||
|
"video_latent_shape": video_latent_shape,
|
||||||
|
"audio_latent_shape": audio_latent_shape
|
||||||
|
}
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
|
||||||
|
if use_two_stage_pipeline:
|
||||||
|
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
|
||||||
|
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
||||||
|
initial_dict = stage1_dict
|
||||||
|
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
|
||||||
|
return initial_dict
|
||||||
|
else:
|
||||||
|
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
|
||||||
|
output_params=("video_latents", "audio_latents"),
|
||||||
|
onload_model_names=("video_vae_encoder")
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
|
||||||
|
if input_video is None:
|
||||||
|
return {"video_latents": video_noise, "audio_latents": audio_noise}
|
||||||
|
else:
|
||||||
|
# TODO: implement video-to-video
|
||||||
|
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||||
|
|
||||||
|
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
||||||
|
output_params=("video_latents"),
|
||||||
|
onload_model_names=("video_vae_encoder")
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
||||||
|
image = ltx2_preprocess(np.array(input_image.resize((width, height))))
|
||||||
|
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||||
|
image = image / 127.5 - 1.0
|
||||||
|
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
|
||||||
|
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
||||||
|
return latent
|
||||||
|
|
||||||
|
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
|
||||||
|
if input_images is None or len(input_images) == 0:
|
||||||
|
return {"video_latents": video_latents}
|
||||||
|
else:
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
output_dicts = {}
|
||||||
|
stage1_height = height // 2 if use_two_stage_pipeline else height
|
||||||
|
stage1_width = width // 2 if use_two_stage_pipeline else width
|
||||||
|
stage1_latents = [
|
||||||
|
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
|
||||||
|
tile_overlap_in_pixels) for img in input_images
|
||||||
|
]
|
||||||
|
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
|
||||||
|
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
|
||||||
|
if use_two_stage_pipeline:
|
||||||
|
stage2_latents = [
|
||||||
|
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
|
||||||
|
tile_overlap_in_pixels) for img in input_images
|
||||||
|
]
|
||||||
|
output_dicts.update({"stage2_input_latents": stage2_latents})
|
||||||
|
return output_dicts
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_ltx2(
|
||||||
|
dit: LTXModel,
|
||||||
|
video_latents=None,
|
||||||
|
video_context=None,
|
||||||
|
video_positions=None,
|
||||||
|
video_patchifier=None,
|
||||||
|
audio_latents=None,
|
||||||
|
audio_context=None,
|
||||||
|
audio_positions=None,
|
||||||
|
audio_patchifier=None,
|
||||||
|
timestep=None,
|
||||||
|
denoise_mask_video=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
timestep = timestep.float() / 1000.
|
||||||
|
|
||||||
|
# patchify
|
||||||
|
b, c_v, f, h, w = video_latents.shape
|
||||||
|
video_latents = video_patchifier.patchify(video_latents)
|
||||||
|
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||||
|
if denoise_mask_video is not None:
|
||||||
|
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
||||||
|
_, c_a, _, mel_bins = audio_latents.shape
|
||||||
|
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||||
|
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||||
|
#TODO: support gradient checkpointing in training
|
||||||
|
vx, ax = dit(
|
||||||
|
video_latents=video_latents,
|
||||||
|
video_positions=video_positions,
|
||||||
|
video_context=video_context,
|
||||||
|
video_timesteps=video_timesteps,
|
||||||
|
audio_latents=audio_latents,
|
||||||
|
audio_positions=audio_positions,
|
||||||
|
audio_context=audio_context,
|
||||||
|
audio_timesteps=audio_timesteps,
|
||||||
|
)
|
||||||
|
# unpatchify
|
||||||
|
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
||||||
|
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
|
||||||
|
return vx, ax
|
||||||
149
diffsynth/utils/data/media_io_ltx2.py
Normal file
149
diffsynth/utils/data/media_io_ltx2.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
|
||||||
|
from fractions import Fraction
|
||||||
|
import torch
|
||||||
|
import av
|
||||||
|
from tqdm import tqdm
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
from io import BytesIO
|
||||||
|
from collections.abc import Generator, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
def _resample_audio(
|
||||||
|
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||||
|
) -> None:
|
||||||
|
cc = audio_stream.codec_context
|
||||||
|
|
||||||
|
# Use the encoder's format/layout/rate as the *target*
|
||||||
|
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||||
|
target_layout = cc.layout or "stereo"
|
||||||
|
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||||
|
|
||||||
|
audio_resampler = av.audio.resampler.AudioResampler(
|
||||||
|
format=target_format,
|
||||||
|
layout=target_layout,
|
||||||
|
rate=target_rate,
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_next_pts = 0
|
||||||
|
for rframe in audio_resampler.resample(frame_in):
|
||||||
|
if rframe.pts is None:
|
||||||
|
rframe.pts = audio_next_pts
|
||||||
|
audio_next_pts += rframe.samples
|
||||||
|
rframe.sample_rate = frame_in.sample_rate
|
||||||
|
container.mux(audio_stream.encode(rframe))
|
||||||
|
|
||||||
|
# flush audio encoder
|
||||||
|
for packet in audio_stream.encode():
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_audio(
|
||||||
|
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
|
||||||
|
) -> None:
|
||||||
|
if samples.ndim == 1:
|
||||||
|
samples = samples[:, None]
|
||||||
|
|
||||||
|
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||||
|
samples = samples.T
|
||||||
|
|
||||||
|
if samples.shape[1] != 2:
|
||||||
|
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||||
|
|
||||||
|
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||||
|
if samples.dtype != torch.int16:
|
||||||
|
samples = torch.clip(samples, -1.0, 1.0)
|
||||||
|
samples = (samples * 32767.0).to(torch.int16)
|
||||||
|
|
||||||
|
frame_in = av.AudioFrame.from_ndarray(
|
||||||
|
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||||
|
format="s16",
|
||||||
|
layout="stereo",
|
||||||
|
)
|
||||||
|
frame_in.sample_rate = audio_sample_rate
|
||||||
|
|
||||||
|
_resample_audio(container, audio_stream, frame_in)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||||
|
"""
|
||||||
|
Prepare the audio stream for writing.
|
||||||
|
"""
|
||||||
|
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||||
|
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||||
|
audio_stream.codec_context.layout = "stereo"
|
||||||
|
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||||
|
return audio_stream
|
||||||
|
|
||||||
|
def write_video_audio_ltx2(
|
||||||
|
video: list[Image.Image],
|
||||||
|
audio: torch.Tensor | None,
|
||||||
|
output_path: str,
|
||||||
|
fps: int = 24,
|
||||||
|
audio_sample_rate: int | None = 24000,
|
||||||
|
) -> None:
|
||||||
|
|
||||||
|
width, height = video[0].size
|
||||||
|
container = av.open(output_path, mode="w")
|
||||||
|
stream = container.add_stream("libx264", rate=int(fps))
|
||||||
|
stream.width = width
|
||||||
|
stream.height = height
|
||||||
|
stream.pix_fmt = "yuv420p"
|
||||||
|
|
||||||
|
if audio is not None:
|
||||||
|
if audio_sample_rate is None:
|
||||||
|
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||||
|
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||||
|
|
||||||
|
for frame in tqdm(video, total=len(video)):
|
||||||
|
frame = av.VideoFrame.from_image(frame)
|
||||||
|
for packet in stream.encode(frame):
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
# Flush encoder
|
||||||
|
for packet in stream.encode():
|
||||||
|
container.mux(packet)
|
||||||
|
|
||||||
|
if audio is not None:
|
||||||
|
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||||
|
|
||||||
|
container.close()
|
||||||
|
|
||||||
|
|
||||||
|
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
|
||||||
|
container = av.open(output_file, "w", format="mp4")
|
||||||
|
try:
|
||||||
|
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
|
||||||
|
# Round to nearest multiple of 2 for compatibility with video codecs
|
||||||
|
height = image_array.shape[0] // 2 * 2
|
||||||
|
width = image_array.shape[1] // 2 * 2
|
||||||
|
image_array = image_array[:height, :width]
|
||||||
|
stream.height = height
|
||||||
|
stream.width = width
|
||||||
|
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
|
||||||
|
container.mux(stream.encode(av_frame))
|
||||||
|
container.mux(stream.encode())
|
||||||
|
finally:
|
||||||
|
container.close()
|
||||||
|
|
||||||
|
|
||||||
|
def decode_single_frame(video_file: str) -> np.array:
|
||||||
|
container = av.open(video_file)
|
||||||
|
try:
|
||||||
|
stream = next(s for s in container.streams if s.type == "video")
|
||||||
|
frame = next(container.decode(stream))
|
||||||
|
finally:
|
||||||
|
container.close()
|
||||||
|
return frame.to_ndarray(format="rgb24")
|
||||||
|
|
||||||
|
|
||||||
|
def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array:
|
||||||
|
if crf == 0:
|
||||||
|
return image
|
||||||
|
|
||||||
|
with BytesIO() as output_file:
|
||||||
|
encode_single_frame(output_file, image, crf)
|
||||||
|
video_bytes = output_file.getvalue()
|
||||||
|
with BytesIO(video_bytes) as video_file:
|
||||||
|
image_array = decode_single_frame(video_file)
|
||||||
|
return image_array
|
||||||
32
diffsynth/utils/state_dict_converters/ltx2_audio_vae.py
Normal file
32
diffsynth/utils/state_dict_converters/ltx2_audio_vae.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
def LTX2AudioEncoderStateDictConverter(state_dict):
|
||||||
|
# Not used
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("audio_vae.encoder."):
|
||||||
|
new_name = name.replace("audio_vae.encoder.", "")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||||
|
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def LTX2AudioDecoderStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("audio_vae.decoder."):
|
||||||
|
new_name = name.replace("audio_vae.decoder.", "")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||||
|
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def LTX2VocoderStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("vocoder."):
|
||||||
|
new_name = name.replace("vocoder.", "")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
return state_dict_
|
||||||
9
diffsynth/utils/state_dict_converters/ltx2_dit.py
Normal file
9
diffsynth/utils/state_dict_converters/ltx2_dit.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
def LTXModelStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("model.diffusion_model."):
|
||||||
|
new_name = name.replace("model.diffusion_model.", "")
|
||||||
|
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
|
||||||
|
continue
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
return state_dict_
|
||||||
31
diffsynth/utils/state_dict_converters/ltx2_text_encoder.py
Normal file
31
diffsynth/utils/state_dict_converters/ltx2_text_encoder.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
def LTX2TextEncoderStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith("language_model.model."):
|
||||||
|
new_key = key.replace("language_model.model.", "model.language_model.")
|
||||||
|
elif key.startswith("vision_tower."):
|
||||||
|
new_key = key.replace("vision_tower.", "model.vision_tower.")
|
||||||
|
elif key.startswith("multi_modal_projector."):
|
||||||
|
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
|
||||||
|
elif key.startswith("language_model.lm_head."):
|
||||||
|
new_key = key.replace("language_model.lm_head.", "lm_head.")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
state_dict_[new_key] = state_dict[key]
|
||||||
|
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for key in state_dict:
|
||||||
|
if key.startswith("text_embedding_projection."):
|
||||||
|
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
|
||||||
|
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||||
|
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
|
||||||
|
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||||
|
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
state_dict_[new_key] = state_dict[key]
|
||||||
|
return state_dict_
|
||||||
22
diffsynth/utils/state_dict_converters/ltx2_video_vae.py
Normal file
22
diffsynth/utils/state_dict_converters/ltx2_video_vae.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
def LTX2VideoEncoderStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("vae.encoder."):
|
||||||
|
new_name = name.replace("vae.encoder.", "")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
elif name.startswith("vae.per_channel_statistics."):
|
||||||
|
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
return state_dict_
|
||||||
|
|
||||||
|
|
||||||
|
def LTX2VideoDecoderStateDictConverter(state_dict):
|
||||||
|
state_dict_ = {}
|
||||||
|
for name in state_dict:
|
||||||
|
if name.startswith("vae.decoder."):
|
||||||
|
new_name = name.replace("vae.decoder.", "")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
elif name.startswith("vae.per_channel_statistics."):
|
||||||
|
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||||
|
state_dict_[new_name] = state_dict[name]
|
||||||
|
return state_dict_
|
||||||
@@ -1,10 +1,13 @@
|
|||||||
import torch
|
import torch
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from yunchang.kernels import AttnType
|
||||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||||
get_sequence_parallel_world_size,
|
get_sequence_parallel_world_size,
|
||||||
get_sp_group)
|
get_sp_group)
|
||||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||||
|
|
||||||
|
from ... import IS_NPU_AVAILABLE
|
||||||
from ...core.device import parse_nccl_backend, parse_device_type
|
from ...core.device import parse_nccl_backend, parse_device_type
|
||||||
|
|
||||||
|
|
||||||
@@ -30,13 +33,16 @@ def sinusoidal_embedding_1d(dim, position):
|
|||||||
def pad_freqs(original_tensor, target_len):
|
def pad_freqs(original_tensor, target_len):
|
||||||
seq_len, s1, s2 = original_tensor.shape
|
seq_len, s1, s2 = original_tensor.shape
|
||||||
pad_size = target_len - seq_len
|
pad_size = target_len - seq_len
|
||||||
|
original_tensor_device = original_tensor.device
|
||||||
|
if original_tensor.device == "npu":
|
||||||
|
original_tensor = original_tensor.cpu()
|
||||||
padding_tensor = torch.ones(
|
padding_tensor = torch.ones(
|
||||||
pad_size,
|
pad_size,
|
||||||
s1,
|
s1,
|
||||||
s2,
|
s2,
|
||||||
dtype=original_tensor.dtype,
|
dtype=original_tensor.dtype,
|
||||||
device=original_tensor.device)
|
device=original_tensor.device)
|
||||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
|
||||||
return padded_tensor
|
return padded_tensor
|
||||||
|
|
||||||
def rope_apply(x, freqs, num_heads):
|
def rope_apply(x, freqs, num_heads):
|
||||||
@@ -50,7 +56,7 @@ def rope_apply(x, freqs, num_heads):
|
|||||||
sp_rank = get_sequence_parallel_rank()
|
sp_rank = get_sequence_parallel_rank()
|
||||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||||
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device == "npu" else freqs_rank
|
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
|
||||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||||
return x_out.to(x.dtype)
|
return x_out.to(x.dtype)
|
||||||
|
|
||||||
@@ -133,7 +139,12 @@ def usp_attn_forward(self, x, freqs):
|
|||||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||||
|
|
||||||
x = xFuserLongContextAttention()(
|
attn_type = AttnType.FA
|
||||||
|
ring_impl_type = "basic"
|
||||||
|
if IS_NPU_AVAILABLE:
|
||||||
|
attn_type = AttnType.NPU
|
||||||
|
ring_impl_type = "basic_npu"
|
||||||
|
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
|
||||||
None,
|
None,
|
||||||
query=q,
|
query=q,
|
||||||
key=k,
|
key=k,
|
||||||
|
|||||||
109
docs/en/Model_Details/LTX-2.md
Normal file
109
docs/en/Model_Details/LTX-2.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# LTX-2
|
||||||
|
|
||||||
|
LTX-2 is a series of audio-video generation models developed by Lightricks.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
For more information about installation, please refer to [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Model Overview
|
||||||
|
|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|
||||||
|
## Model Inference
|
||||||
|
|
||||||
|
Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||||
|
|
||||||
|
Input parameters for `LTX2AudioVideoPipeline` inference include:
|
||||||
|
|
||||||
|
* `prompt`: Prompt describing the content appearing in the video.
|
||||||
|
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
|
||||||
|
* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.
|
||||||
|
* `input_images`: List of input images for image-to-video generation.
|
||||||
|
* `input_images_indexes`: Frame index list of input images in the video.
|
||||||
|
* `input_images_strength`: Strength of input images, default value is 1.0.
|
||||||
|
* `denoising_strength`: Denoising strength, range is 0~1, default value is 1.0.
|
||||||
|
* `seed`: Random seed. Default is `None`, which means completely random.
|
||||||
|
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs.
|
||||||
|
* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).
|
||||||
|
* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).
|
||||||
|
* `num_frames`: Number of video frames, default value is 121, must be a multiple of 8 + 1.
|
||||||
|
* `num_inference_steps`: Number of inference steps, default value is 40.
|
||||||
|
* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, with slight errors and minor inference time extension.
|
||||||
|
* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.
|
||||||
|
* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.
|
||||||
|
* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.
|
||||||
|
* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.
|
||||||
|
* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`.
|
||||||
|
* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`.
|
||||||
|
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.
|
||||||
|
|
||||||
|
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the previous "Supported Inference Scripts" section.
|
||||||
|
|
||||||
|
## Model Training
|
||||||
|
|
||||||
|
The LTX-2 series models currently do not support training functionality. We will add related support as soon as possible.
|
||||||
@@ -85,6 +85,7 @@ graph LR;
|
|||||||
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
||||||
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||||
|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||||
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||||
|
|||||||
@@ -52,7 +52,12 @@ image.save("image.jpg")
|
|||||||
|
|
||||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||||
|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||||
|
|
||||||
Special Training Scripts:
|
Special Training Scripts:
|
||||||
|
|
||||||
@@ -75,6 +80,9 @@ Input parameters for `ZImagePipeline` inference include:
|
|||||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||||
* `num_inference_steps`: Number of inference steps, default value is 8.
|
* `num_inference_steps`: Number of inference steps, default value is 8.
|
||||||
|
* `controlnet_inputs`: Inputs for ControlNet models.
|
||||||
|
* `edit_image`: Edit images for image editing models, supporting multiple images.
|
||||||
|
* `positive_only_lora`: LoRA weights used only in positive prompts.
|
||||||
|
|
||||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,14 @@ video = pipe(
|
|||||||
save_video(video, "video.mp4", fps=15, quality=5)
|
save_video(video, "video.mp4", fps=15, quality=5)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### USP(Unified Sequence Parallel)
|
||||||
|
If you want to use this feature on NPU, please install additional third-party libraries as follows:
|
||||||
|
```shell
|
||||||
|
pip install git+https://github.com/feifeibear/long-context-attention.git
|
||||||
|
pip install git+https://github.com/xdit-project/xDiT.git
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### Training
|
### Training
|
||||||
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
|
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
|
|||||||
|
|
||||||
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
|
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
|
||||||
|
|
||||||
* Training models from scratch 【coming soon】
|
* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md)
|
||||||
* Inference improvement techniques 【coming soon】
|
* Inference improvement techniques 【coming soon】
|
||||||
* Designing controllable generation models 【coming soon】
|
* Designing controllable generation models 【coming soon】
|
||||||
* Creating new training paradigms 【coming soon】
|
* Creating new training paradigms 【coming soon】
|
||||||
|
|||||||
476
docs/en/Research_Tutorial/train_from_scratch.md
Normal file
476
docs/en/Research_Tutorial/train_from_scratch.md
Normal file
@@ -0,0 +1,476 @@
|
|||||||
|
# Training Models from Scratch
|
||||||
|
|
||||||
|
DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
|
||||||
|
|
||||||
|
## 1. Building Model Architecture
|
||||||
|
|
||||||
|
### 1.1 Diffusion Model
|
||||||
|
|
||||||
|
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include:
|
||||||
|
|
||||||
|
* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise
|
||||||
|
* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder
|
||||||
|
* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at
|
||||||
|
|
||||||
|
The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Model Architecture Code</summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch, accelerate
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, AutoTokenizer
|
||||||
|
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||||
|
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||||
|
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||||
|
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||||
|
from diffsynth.models.flux2_vae import Flux2VAE
|
||||||
|
|
||||||
|
|
||||||
|
class AAAPositionalEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, height=16, width=16, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||||
|
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||||
|
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||||
|
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||||
|
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||||
|
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||||
|
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAABlock(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024, num_heads=32):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.to_q = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_k = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_v = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_out = torch.nn.Linear(dim, dim)
|
||||||
|
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*3),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(dim*3, dim),
|
||||||
|
)
|
||||||
|
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def attention(self, emb, pos_emb):
|
||||||
|
emb = self.norm_attn(emb + pos_emb)
|
||||||
|
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||||
|
emb = attention_forward(
|
||||||
|
q, k, v,
|
||||||
|
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||||
|
dims={"n": self.num_heads},
|
||||||
|
)
|
||||||
|
emb = self.to_out(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def feed_forward(self, emb, pos_emb):
|
||||||
|
emb = self.norm_mlp(emb + pos_emb)
|
||||||
|
emb = self.ff(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def forward(self, emb, pos_emb, t_emb):
|
||||||
|
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||||
|
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||||
|
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAADiT(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||||
|
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||||
|
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||||
|
self.proj_out = torch.nn.Linear(dim, 128)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
):
|
||||||
|
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||||
|
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||||
|
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||||
|
text = self.text_embedder(prompt_embeds)
|
||||||
|
emb = torch.concat([image, text], dim=1)
|
||||||
|
for block_id, block in enumerate(self.blocks):
|
||||||
|
emb = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
emb=emb,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
t_emb=t_emb,
|
||||||
|
)
|
||||||
|
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||||
|
emb = self.proj_out(emb)
|
||||||
|
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||||
|
return emb
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 1.2 Encoder-Decoder Models
|
||||||
|
|
||||||
|
Besides the Diffusion model used for denoising, we also need two other models:
|
||||||
|
|
||||||
|
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
|
||||||
|
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
|
||||||
|
|
||||||
|
The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
|
||||||
|
|
||||||
|
## 2. Building Pipeline
|
||||||
|
|
||||||
|
We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Pipeline Code</summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AAAImagePipeline(BasePipeline):
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
)
|
||||||
|
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||||
|
self.text_encoder: ZImageTextEncoder = None
|
||||||
|
self.dit: AAADiT = None
|
||||||
|
self.vae: Flux2VAE = None
|
||||||
|
self.tokenizer: AutoProcessor = None
|
||||||
|
self.in_iteration_models = ("dit",)
|
||||||
|
self.units = [
|
||||||
|
AAAUnit_PromptEmbedder(),
|
||||||
|
AAAUnit_NoiseInitializer(),
|
||||||
|
AAAUnit_InputImageEmbedder(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_aaa
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
|
# Fetch models
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||||
|
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||||
|
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
# Image
|
||||||
|
input_image: Image.Image = None,
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Shape
|
||||||
|
height: int = 1024,
|
||||||
|
width: int = 1024,
|
||||||
|
# Randomness
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
# Steps
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
# Progress bar
|
||||||
|
progress_bar_cmd = tqdm,
|
||||||
|
):
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
inputs_posi = {"prompt": prompt}
|
||||||
|
inputs_nega = {"negative_prompt": negative_prompt}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
}
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
|
# Denoise
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
image = self.vae.decode(inputs_shared["latents"])
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("prompt_embeds",),
|
||||||
|
onload_model_names=("text_encoder",)
|
||||||
|
)
|
||||||
|
self.hidden_states_layers = (-1,)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
text = pipe.tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": prompt}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False,
|
||||||
|
)
|
||||||
|
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||||
|
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||||
|
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||||
|
return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise, "input_latents": None}
|
||||||
|
pipe.load_models_to_device(['vae'])
|
||||||
|
image = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(image)
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": noise, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_aaa(
|
||||||
|
dit: AAADiT,
|
||||||
|
latents=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
timestep=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model_output = dit(
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
return model_output
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 3. Preparing Dataset
|
||||||
|
|
||||||
|
To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Start Training
|
||||||
|
|
||||||
|
The training process can be quickly implemented using Pipeline. We have placed the complete code at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training.
|
||||||
|
|
||||||
|
To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training.
|
||||||
|
|
||||||
|
This training script has no stopping condition, please manually close it when needed. The model converges after training approximately 60,000 steps, requiring 10-20 hours for single GPU training.
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Training Code</summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AAATrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__()
|
||||||
|
self.pipe = AAAImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
)
|
||||||
|
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||||
|
self.pipe.freeze_except(["dit"])
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"use_gradient_checkpointing": False,
|
||||||
|
"use_gradient_checkpointing_offload": False,
|
||||||
|
}
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path="data/images",
|
||||||
|
metadata_path="data/metadata_merged.csv",
|
||||||
|
max_data_items=10000000,
|
||||||
|
data_file_keys=("image",),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||||
|
)
|
||||||
|
model = AAATrainingModule(device=accelerator.device)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
"models/AAA/v1",
|
||||||
|
remove_prefix_in_ckpt="pipe.dit.",
|
||||||
|
)
|
||||||
|
launch_training_task(
|
||||||
|
accelerator, dataset, model, model_logger,
|
||||||
|
learning_rate=2e-4,
|
||||||
|
num_workers=4,
|
||||||
|
save_steps=50000,
|
||||||
|
num_epochs=999999,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 5. Verifying Training Results
|
||||||
|
|
||||||
|
If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
|
||||||
|
```
|
||||||
|
|
||||||
|
Loading the model
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth import load_model
|
||||||
|
|
||||||
|
pipe = AAAImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
)
|
||||||
|
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
|
||||||
|
```
|
||||||
|
|
||||||
|
Model inference, generating the first-generation Pokemon "starter trio". At this point, the images generated by the model basically match the training data.
|
||||||
|
|
||||||
|
```python
|
||||||
|
for seed, prompt in enumerate([
|
||||||
|
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
|
||||||
|
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
|
||||||
|
"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
|
||||||
|
]):
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=" ",
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=10,
|
||||||
|
seed=seed,
|
||||||
|
height=256, width=256,
|
||||||
|
)
|
||||||
|
image.save(f"image_{seed}.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
||||
|
||||||
|
|-|-|-|
|
||||||
|
|
||||||
|
Model inference, generating Pokemon with "sharp claws". At this point, different random seeds can produce different image results.
|
||||||
|
|
||||||
|
```python
|
||||||
|
for seed, prompt in enumerate([
|
||||||
|
"sharp claws",
|
||||||
|
"sharp claws",
|
||||||
|
"sharp claws",
|
||||||
|
]):
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=" ",
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=10,
|
||||||
|
seed=seed+4,
|
||||||
|
height=256, width=256,
|
||||||
|
)
|
||||||
|
image.save(f"image_sharp_claws_{seed}.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
||||
|
||||||
|
|-|-|-|
|
||||||
|
|
||||||
|
Now, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. If you increase the amount of data, model parameters, and number of GPUs based on this, you can train a more powerful text-to-image model!
|
||||||
341
docs/en/Research_Tutorial/train_from_scratch.py
Normal file
341
docs/en/Research_Tutorial/train_from_scratch.py
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
import torch, accelerate
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, AutoTokenizer
|
||||||
|
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||||
|
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||||
|
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||||
|
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||||
|
from diffsynth.models.flux2_vae import Flux2VAE
|
||||||
|
|
||||||
|
|
||||||
|
class AAAPositionalEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, height=16, width=16, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||||
|
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||||
|
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||||
|
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||||
|
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||||
|
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||||
|
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAABlock(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024, num_heads=32):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.to_q = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_k = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_v = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_out = torch.nn.Linear(dim, dim)
|
||||||
|
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*3),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(dim*3, dim),
|
||||||
|
)
|
||||||
|
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def attention(self, emb, pos_emb):
|
||||||
|
emb = self.norm_attn(emb + pos_emb)
|
||||||
|
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||||
|
emb = attention_forward(
|
||||||
|
q, k, v,
|
||||||
|
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||||
|
dims={"n": self.num_heads},
|
||||||
|
)
|
||||||
|
emb = self.to_out(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def feed_forward(self, emb, pos_emb):
|
||||||
|
emb = self.norm_mlp(emb + pos_emb)
|
||||||
|
emb = self.ff(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def forward(self, emb, pos_emb, t_emb):
|
||||||
|
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||||
|
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||||
|
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAADiT(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||||
|
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||||
|
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||||
|
self.proj_out = torch.nn.Linear(dim, 128)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
):
|
||||||
|
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||||
|
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||||
|
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||||
|
text = self.text_embedder(prompt_embeds)
|
||||||
|
emb = torch.concat([image, text], dim=1)
|
||||||
|
for block_id, block in enumerate(self.blocks):
|
||||||
|
emb = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
emb=emb,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
t_emb=t_emb,
|
||||||
|
)
|
||||||
|
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||||
|
emb = self.proj_out(emb)
|
||||||
|
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAAImagePipeline(BasePipeline):
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
)
|
||||||
|
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||||
|
self.text_encoder: ZImageTextEncoder = None
|
||||||
|
self.dit: AAADiT = None
|
||||||
|
self.vae: Flux2VAE = None
|
||||||
|
self.tokenizer: AutoProcessor = None
|
||||||
|
self.in_iteration_models = ("dit",)
|
||||||
|
self.units = [
|
||||||
|
AAAUnit_PromptEmbedder(),
|
||||||
|
AAAUnit_NoiseInitializer(),
|
||||||
|
AAAUnit_InputImageEmbedder(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_aaa
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
|
# Fetch models
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||||
|
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||||
|
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
# Image
|
||||||
|
input_image: Image.Image = None,
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Shape
|
||||||
|
height: int = 1024,
|
||||||
|
width: int = 1024,
|
||||||
|
# Randomness
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
# Steps
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
# Progress bar
|
||||||
|
progress_bar_cmd = tqdm,
|
||||||
|
):
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
inputs_posi = {"prompt": prompt}
|
||||||
|
inputs_nega = {"negative_prompt": negative_prompt}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
}
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
|
# Denoise
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
image = self.vae.decode(inputs_shared["latents"])
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("prompt_embeds",),
|
||||||
|
onload_model_names=("text_encoder",)
|
||||||
|
)
|
||||||
|
self.hidden_states_layers = (-1,)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
text = pipe.tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": prompt}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False,
|
||||||
|
)
|
||||||
|
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||||
|
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||||
|
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||||
|
return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise, "input_latents": None}
|
||||||
|
pipe.load_models_to_device(['vae'])
|
||||||
|
image = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(image)
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": noise, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_aaa(
|
||||||
|
dit: AAADiT,
|
||||||
|
latents=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
timestep=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model_output = dit(
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
|
||||||
|
class AAATrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__()
|
||||||
|
self.pipe = AAAImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
)
|
||||||
|
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||||
|
self.pipe.freeze_except(["dit"])
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"use_gradient_checkpointing": False,
|
||||||
|
"use_gradient_checkpointing_offload": False,
|
||||||
|
}
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path="data/images",
|
||||||
|
metadata_path="data/metadata_merged.csv",
|
||||||
|
max_data_items=10000000,
|
||||||
|
data_file_keys=("image",),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||||
|
)
|
||||||
|
model = AAATrainingModule(device=accelerator.device)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
"models/AAA/v1",
|
||||||
|
remove_prefix_in_ckpt="pipe.dit.",
|
||||||
|
)
|
||||||
|
launch_training_task(
|
||||||
|
accelerator, dataset, model, model_logger,
|
||||||
|
learning_rate=2e-4,
|
||||||
|
num_workers=4,
|
||||||
|
save_steps=50000,
|
||||||
|
num_epochs=999999,
|
||||||
|
)
|
||||||
@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un
|
|||||||
|
|
||||||
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
|
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
|
||||||
|
|
||||||
(Figure)
|

|
||||||
|
|
||||||
This process is intuitive, but to understand the details, we need to answer several questions:
|
This process is intuitive, but to understand the details, we need to answer several questions:
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma
|
|||||||
|
|
||||||
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
|
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
|
||||||
|
|
||||||
(Figure)
|

|
||||||
|
|
||||||
## How is the iterative denoising computation performed?
|
## How is the iterative denoising computation performed?
|
||||||
|
|
||||||
@@ -40,8 +40,6 @@ Before understanding the iterative denoising computation, we need to clarify wha
|
|||||||
|
|
||||||
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
|
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
|
||||||
|
|
||||||
(Figure)
|
|
||||||
|
|
||||||
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
|
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
|
||||||
|
|
||||||
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
|
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
|
||||||
@@ -91,8 +89,6 @@ After understanding the iterative denoising process, we next consider how to tra
|
|||||||
|
|
||||||
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
|
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
|
||||||
|
|
||||||
(Figure)
|
|
||||||
|
|
||||||
The following is pseudocode for the training process:
|
The following is pseudocode for the training process:
|
||||||
|
|
||||||
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
|
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
|
||||||
@@ -113,7 +109,7 @@ The following is pseudocode for the training process:
|
|||||||
|
|
||||||
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
|
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
|
||||||
|
|
||||||
(Figure)
|

|
||||||
|
|
||||||
### Data Encoder-Decoder
|
### Data Encoder-Decoder
|
||||||
|
|
||||||
|
|||||||
109
docs/zh/Model_Details/LTX-2.md
Normal file
109
docs/zh/Model_Details/LTX-2.md
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# LTX-2
|
||||||
|
|
||||||
|
LTX-2 是由 Lightricks 开发的音视频生成模型系列。
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||||
|
cd DiffSynth-Studio
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 模型总览
|
||||||
|
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|
|-|-|-|-|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||||
|
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||||
|
|
||||||
|
## 模型推理
|
||||||
|
|
||||||
|
模型通过 `LTX2AudioVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||||
|
|
||||||
|
`LTX2AudioVideoPipeline` 推理的输入参数包括:
|
||||||
|
|
||||||
|
* `prompt`: 提示词,描述视频中出现的内容。
|
||||||
|
* `negative_prompt`: 负向提示词,描述视频中不应该出现的内容,默认值为 `""`。
|
||||||
|
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 3.0。
|
||||||
|
* `input_images`: 输入图像列表,用于图生视频。
|
||||||
|
* `input_images_indexes`: 输入图像在视频中的帧索引列表。
|
||||||
|
* `input_images_strength`: 输入图像的强度,默认值为 1.0。
|
||||||
|
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1.0。
|
||||||
|
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||||
|
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||||
|
* `height`: 视频高度,需保证高度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。
|
||||||
|
* `width`: 视频宽度,需保证宽度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。
|
||||||
|
* `num_frames`: 视频帧数,默认值为 121,需保证为 8 的倍数 + 1。
|
||||||
|
* `num_inference_steps`: 推理次数,默认值为 40。
|
||||||
|
* `tiled`: 是否启用 VAE 分块推理,默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||||
|
* `tile_size_in_pixels`: VAE 编解码阶段的像素分块大小,默认为 512。
|
||||||
|
* `tile_overlap_in_pixels`: VAE 编解码阶段的像素分块重叠大小,默认为 128。
|
||||||
|
* `tile_size_in_frames`: VAE 编解码阶段的帧分块大小,默认为 128。
|
||||||
|
* `tile_overlap_in_frames`: VAE 编解码阶段的帧分块重叠大小,默认为 24。
|
||||||
|
* `use_two_stage_pipeline`: 是否使用两阶段管道,默认为 `False`。
|
||||||
|
* `use_distilled_pipeline`: 是否使用蒸馏管道,默认为 `False`。
|
||||||
|
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||||
|
|
||||||
|
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"支持的推理脚本"中的表格。
|
||||||
|
|
||||||
|
## 模型训练
|
||||||
|
|
||||||
|
LTX-2 系列模型目前暂不支持训练功能。我们将尽快添加相关支持。
|
||||||
@@ -85,6 +85,7 @@ graph LR;
|
|||||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||||
|
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||||
|
|||||||
@@ -52,7 +52,12 @@ image.save("image.jpg")
|
|||||||
|
|
||||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||||
|-|-|-|-|-|-|-|
|
|-|-|-|-|-|-|-|
|
||||||
|
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||||
|
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||||
|
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||||
|
|
||||||
特殊训练脚本:
|
特殊训练脚本:
|
||||||
|
|
||||||
@@ -75,6 +80,9 @@ image.save("image.jpg")
|
|||||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||||
* `num_inference_steps`: 推理次数,默认值为 8。
|
* `num_inference_steps`: 推理次数,默认值为 8。
|
||||||
|
* `controlnet_inputs`: ControlNet 模型的输入。
|
||||||
|
* `edit_image`: 编辑模型的待编辑图像,支持多张图像。
|
||||||
|
* `positive_only_lora`: 仅在正向提示词中使用的 LoRA 权重。
|
||||||
|
|
||||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||||
|
|
||||||
|
|||||||
@@ -58,6 +58,13 @@ video = pipe(
|
|||||||
save_video(video, "video.mp4", fps=15, quality=5)
|
save_video(video, "video.mp4", fps=15, quality=5)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### USP(Unified Sequence Parallel)
|
||||||
|
如果想要在NPU上使用该特性,请通过如下方式安装额外的第三方库:
|
||||||
|
```shell
|
||||||
|
pip install git+https://github.com/feifeibear/long-context-attention.git
|
||||||
|
pip install git+https://github.com/xdit-project/xDiT.git
|
||||||
|
```
|
||||||
|
|
||||||
### 训练
|
### 训练
|
||||||
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。
|
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。
|
||||||
|
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ graph LR;
|
|||||||
|
|
||||||
本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。
|
本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。
|
||||||
|
|
||||||
* 从零开始训练模型【coming soon】
|
* [从零开始训练模型](/docs/zh/Research_Tutorial/train_from_scratch.md)
|
||||||
* 推理改进优化技术【coming soon】
|
* 推理改进优化技术【coming soon】
|
||||||
* 设计可控生成模型【coming soon】
|
* 设计可控生成模型【coming soon】
|
||||||
* 创建新的训练范式【coming soon】
|
* 创建新的训练范式【coming soon】
|
||||||
|
|||||||
477
docs/zh/Research_Tutorial/train_from_scratch.md
Normal file
477
docs/zh/Research_Tutorial/train_from_scratch.md
Normal file
@@ -0,0 +1,477 @@
|
|||||||
|
# 从零开始训练模型
|
||||||
|
|
||||||
|
DiffSynth-Studio 的训练引擎支持从零开始训练基础模型,本文介绍如何从零开始训练一个参数量仅为 0.1B 的小型文生图模型。
|
||||||
|
|
||||||
|
## 1. 构建模型结构
|
||||||
|
|
||||||
|
### 1.1 Diffusion 模型
|
||||||
|
|
||||||
|
从 UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) 到 DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206),Diffusion 的主流模型结构经历了多次演变。通常,一个 Diffusion 模型的输入包括:
|
||||||
|
|
||||||
|
* 图像张量(`latents`):图像的编码,由 VAE 模型产生,含有部分噪声
|
||||||
|
* 文本张量(`prompt_embeds`):文本的编码,由文本编码器产生
|
||||||
|
* 时间步(`timestep`):标量,用于标记当前处于 Diffusion 过程的哪个阶段
|
||||||
|
|
||||||
|
模型的输出是与图像张量形状相同的张量,表示模型预测的去噪方向,关于 Diffusion 模型理论的细节,请参考 [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)。在本文中,我们构建一个仅含 0.1B 参数的 DiT 模型:`AAADiT`。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>模型结构代码</summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch, accelerate
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, AutoTokenizer
|
||||||
|
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||||
|
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||||
|
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||||
|
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||||
|
from diffsynth.models.flux2_vae import Flux2VAE
|
||||||
|
|
||||||
|
|
||||||
|
class AAAPositionalEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, height=16, width=16, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||||
|
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||||
|
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||||
|
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||||
|
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||||
|
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||||
|
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAABlock(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024, num_heads=32):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.to_q = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_k = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_v = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_out = torch.nn.Linear(dim, dim)
|
||||||
|
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*3),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(dim*3, dim),
|
||||||
|
)
|
||||||
|
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def attention(self, emb, pos_emb):
|
||||||
|
emb = self.norm_attn(emb + pos_emb)
|
||||||
|
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||||
|
emb = attention_forward(
|
||||||
|
q, k, v,
|
||||||
|
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||||
|
dims={"n": self.num_heads},
|
||||||
|
)
|
||||||
|
emb = self.to_out(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def feed_forward(self, emb, pos_emb):
|
||||||
|
emb = self.norm_mlp(emb + pos_emb)
|
||||||
|
emb = self.ff(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def forward(self, emb, pos_emb, t_emb):
|
||||||
|
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||||
|
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||||
|
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAADiT(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||||
|
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||||
|
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||||
|
self.proj_out = torch.nn.Linear(dim, 128)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
):
|
||||||
|
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||||
|
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||||
|
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||||
|
text = self.text_embedder(prompt_embeds)
|
||||||
|
emb = torch.concat([image, text], dim=1)
|
||||||
|
for block_id, block in enumerate(self.blocks):
|
||||||
|
emb = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
emb=emb,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
t_emb=t_emb,
|
||||||
|
)
|
||||||
|
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||||
|
emb = self.proj_out(emb)
|
||||||
|
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||||
|
return emb
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### 1.2 编解码器模型
|
||||||
|
|
||||||
|
除了用于去噪的 Diffusion 模型以外,我们还需要另外两个模型:
|
||||||
|
|
||||||
|
* 文本编码器:用于将文本编码为张量。我们采用 [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) 模型。
|
||||||
|
* VAE 编解码器:编码器部分用于将图像编码为张量,解码器部分用于将图像张量解码为图像。我们采用 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 中的 VAE 模型。
|
||||||
|
|
||||||
|
这两个模型的结构都已集成在 DiffSynth-Studio 中,分别位于 [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) 和 [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py),因此我们不需要修改任何代码。
|
||||||
|
|
||||||
|
## 2. 构建 Pipeline
|
||||||
|
|
||||||
|
我们在文档 [接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) 中介绍了如何构建一个模型 Pipeline,对于本文中的模型,我们也需要构建一个 Pipeline,连接文本编码器、Diffusion 模型、VAE 编解码器。
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Pipeline 代码</summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AAAImagePipeline(BasePipeline):
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
)
|
||||||
|
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||||
|
self.text_encoder: ZImageTextEncoder = None
|
||||||
|
self.dit: AAADiT = None
|
||||||
|
self.vae: Flux2VAE = None
|
||||||
|
self.tokenizer: AutoProcessor = None
|
||||||
|
self.in_iteration_models = ("dit",)
|
||||||
|
self.units = [
|
||||||
|
AAAUnit_PromptEmbedder(),
|
||||||
|
AAAUnit_NoiseInitializer(),
|
||||||
|
AAAUnit_InputImageEmbedder(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_aaa
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
|
# Fetch models
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||||
|
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||||
|
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
# Image
|
||||||
|
input_image: Image.Image = None,
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Shape
|
||||||
|
height: int = 1024,
|
||||||
|
width: int = 1024,
|
||||||
|
# Randomness
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
# Steps
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
# Progress bar
|
||||||
|
progress_bar_cmd = tqdm,
|
||||||
|
):
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
inputs_posi = {"prompt": prompt}
|
||||||
|
inputs_nega = {"negative_prompt": negative_prompt}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
}
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
|
# Denoise
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
image = self.vae.decode(inputs_shared["latents"])
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("prompt_embeds",),
|
||||||
|
onload_model_names=("text_encoder",)
|
||||||
|
)
|
||||||
|
self.hidden_states_layers = (-1,)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
text = pipe.tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": prompt}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False,
|
||||||
|
)
|
||||||
|
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||||
|
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||||
|
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||||
|
return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise, "input_latents": None}
|
||||||
|
pipe.load_models_to_device(['vae'])
|
||||||
|
image = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(image)
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": noise, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_aaa(
|
||||||
|
dit: AAADiT,
|
||||||
|
latents=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
timestep=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model_output = dit(
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
return model_output
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 3. 准备数据集
|
||||||
|
|
||||||
|
为了快速验证训练效果,我们使用数据集 [宝可梦-第一世代](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1),这个数据集转载自开源项目 [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh),包含从妙蛙种子到梦幻的 151 个第一世代宝可梦。如果你想使用其他数据集,请参考文档 [准备数据集](/docs/zh/Pipeline_Usage/Model_Training.md#准备数据集) 和 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 开始训练
|
||||||
|
|
||||||
|
训练过程可使用 Pipeline 快速实现,我们已将完整的代码放在 [/docs/zh/Research_Tutorial/train_from_scratch.py](/docs/zh/Research_Tutorial/train_from_scratch.py),可直接通过 `python docs/zh/Research_Tutorial/train_from_scratch.py` 开始单 GPU 训练。
|
||||||
|
|
||||||
|
如需开启多 GPU 并行训练,请运行 `accelerate config` 设置相关参数,然后使用命令 `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py` 开始训练。
|
||||||
|
|
||||||
|
这个训练脚本没有设置停止条件,请在需要时手动关闭。模型在训练大约 6 万步后收敛,单 GPU 训练需要 10~20 小时。
|
||||||
|
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>训练代码</summary>
|
||||||
|
|
||||||
|
```python
|
||||||
|
class AAATrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__()
|
||||||
|
self.pipe = AAAImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
)
|
||||||
|
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||||
|
self.pipe.freeze_except(["dit"])
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"use_gradient_checkpointing": False,
|
||||||
|
"use_gradient_checkpointing_offload": False,
|
||||||
|
}
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path="data/images",
|
||||||
|
metadata_path="data/metadata_merged.csv",
|
||||||
|
max_data_items=10000000,
|
||||||
|
data_file_keys=("image",),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||||
|
)
|
||||||
|
model = AAATrainingModule(device=accelerator.device)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
"models/AAA/v1",
|
||||||
|
remove_prefix_in_ckpt="pipe.dit.",
|
||||||
|
)
|
||||||
|
launch_training_task(
|
||||||
|
accelerator, dataset, model, model_logger,
|
||||||
|
learning_rate=2e-4,
|
||||||
|
num_workers=4,
|
||||||
|
save_steps=50000,
|
||||||
|
num_epochs=999999,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
## 5. 验证训练效果
|
||||||
|
|
||||||
|
如果你不想等待模型训练完成,可以直接下载[我们预先训练好的模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel)。
|
||||||
|
|
||||||
|
```shell
|
||||||
|
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
|
||||||
|
```
|
||||||
|
|
||||||
|
加载模型
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffsynth import load_model
|
||||||
|
|
||||||
|
pipe = AAAImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
)
|
||||||
|
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
|
||||||
|
```
|
||||||
|
|
||||||
|
模型推理,生成第一世代宝可梦“御三家”,此时模型生成的图像内容与训练数据基本一致。
|
||||||
|
|
||||||
|
```python
|
||||||
|
for seed, prompt in enumerate([
|
||||||
|
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
|
||||||
|
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
|
||||||
|
"蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴",
|
||||||
|
]):
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=" ",
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=10,
|
||||||
|
seed=seed,
|
||||||
|
height=256, width=256,
|
||||||
|
)
|
||||||
|
image.save(f"image_{seed}.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
||||
|
||||||
|
|-|-|-|
|
||||||
|
|
||||||
|
模型推理,生成具有“锐利爪子”的宝可梦,此时不同的随机种子能够产生不同的图像结果。
|
||||||
|
|
||||||
|
```python
|
||||||
|
for seed, prompt in enumerate([
|
||||||
|
"sharp claws",
|
||||||
|
"sharp claws",
|
||||||
|
"sharp claws",
|
||||||
|
]):
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=" ",
|
||||||
|
num_inference_steps=30,
|
||||||
|
cfg_scale=10,
|
||||||
|
seed=seed+4,
|
||||||
|
height=256, width=256,
|
||||||
|
)
|
||||||
|
image.save(f"image_sharp_claws_{seed}.jpg")
|
||||||
|
```
|
||||||
|
|
||||||
|
||||
|
||||||
|
|-|-|-|
|
||||||
|
|
||||||
|
现在,我们获得了一个 0.1B 的小型文生图模型,这个模型已经能够生成 151 个宝可梦,但无法生成其他图像内容。如果在此基础上增加数据量、模型参数量、GPU 数量,你就可以训练出一个更强大的文生图模型!
|
||||||
341
docs/zh/Research_Tutorial/train_from_scratch.py
Normal file
341
docs/zh/Research_Tutorial/train_from_scratch.py
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
import torch, accelerate
|
||||||
|
from PIL import Image
|
||||||
|
from typing import Union
|
||||||
|
from tqdm import tqdm
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
from transformers import AutoProcessor, AutoTokenizer
|
||||||
|
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||||
|
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||||
|
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||||
|
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||||
|
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||||
|
from diffsynth.models.flux2_vae import Flux2VAE
|
||||||
|
|
||||||
|
|
||||||
|
class AAAPositionalEmbedding(torch.nn.Module):
|
||||||
|
def __init__(self, height=16, width=16, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||||
|
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||||
|
|
||||||
|
def forward(self, image, text):
|
||||||
|
height, width = image.shape[-2:]
|
||||||
|
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||||
|
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||||
|
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||||
|
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||||
|
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||||
|
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAABlock(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024, num_heads=32):
|
||||||
|
super().__init__()
|
||||||
|
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.to_q = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_k = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_v = torch.nn.Linear(dim, dim)
|
||||||
|
self.to_out = torch.nn.Linear(dim, dim)
|
||||||
|
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||||
|
self.ff = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(dim, dim*3),
|
||||||
|
torch.nn.SiLU(),
|
||||||
|
torch.nn.Linear(dim*3, dim),
|
||||||
|
)
|
||||||
|
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||||
|
self.num_heads = num_heads
|
||||||
|
|
||||||
|
def attention(self, emb, pos_emb):
|
||||||
|
emb = self.norm_attn(emb + pos_emb)
|
||||||
|
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||||
|
emb = attention_forward(
|
||||||
|
q, k, v,
|
||||||
|
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||||
|
dims={"n": self.num_heads},
|
||||||
|
)
|
||||||
|
emb = self.to_out(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def feed_forward(self, emb, pos_emb):
|
||||||
|
emb = self.norm_mlp(emb + pos_emb)
|
||||||
|
emb = self.ff(emb)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
def forward(self, emb, pos_emb, t_emb):
|
||||||
|
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||||
|
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||||
|
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAADiT(torch.nn.Module):
|
||||||
|
def __init__(self, dim=1024):
|
||||||
|
super().__init__()
|
||||||
|
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||||
|
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||||
|
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||||
|
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||||
|
self.proj_out = torch.nn.Linear(dim, 128)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
):
|
||||||
|
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||||
|
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||||
|
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||||
|
text = self.text_embedder(prompt_embeds)
|
||||||
|
emb = torch.concat([image, text], dim=1)
|
||||||
|
for block_id, block in enumerate(self.blocks):
|
||||||
|
emb = gradient_checkpoint_forward(
|
||||||
|
block,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
emb=emb,
|
||||||
|
pos_emb=pos_emb,
|
||||||
|
t_emb=t_emb,
|
||||||
|
)
|
||||||
|
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||||
|
emb = self.proj_out(emb)
|
||||||
|
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class AAAImagePipeline(BasePipeline):
|
||||||
|
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||||
|
super().__init__(
|
||||||
|
device=device, torch_dtype=torch_dtype,
|
||||||
|
height_division_factor=16, width_division_factor=16,
|
||||||
|
)
|
||||||
|
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||||
|
self.text_encoder: ZImageTextEncoder = None
|
||||||
|
self.dit: AAADiT = None
|
||||||
|
self.vae: Flux2VAE = None
|
||||||
|
self.tokenizer: AutoProcessor = None
|
||||||
|
self.in_iteration_models = ("dit",)
|
||||||
|
self.units = [
|
||||||
|
AAAUnit_PromptEmbedder(),
|
||||||
|
AAAUnit_NoiseInitializer(),
|
||||||
|
AAAUnit_InputImageEmbedder(),
|
||||||
|
]
|
||||||
|
self.model_fn = model_fn_aaa
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_pretrained(
|
||||||
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
device: Union[str, torch.device] = "cuda",
|
||||||
|
model_configs: list[ModelConfig] = [],
|
||||||
|
tokenizer_config: ModelConfig = None,
|
||||||
|
vram_limit: float = None,
|
||||||
|
):
|
||||||
|
# Initialize pipeline
|
||||||
|
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||||
|
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||||
|
|
||||||
|
# Fetch models
|
||||||
|
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||||
|
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||||
|
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||||
|
if tokenizer_config is not None:
|
||||||
|
tokenizer_config.download_if_necessary()
|
||||||
|
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||||
|
|
||||||
|
# VRAM Management
|
||||||
|
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
# Prompt
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str = "",
|
||||||
|
cfg_scale: float = 1.0,
|
||||||
|
# Image
|
||||||
|
input_image: Image.Image = None,
|
||||||
|
denoising_strength: float = 1.0,
|
||||||
|
# Shape
|
||||||
|
height: int = 1024,
|
||||||
|
width: int = 1024,
|
||||||
|
# Randomness
|
||||||
|
seed: int = None,
|
||||||
|
rand_device: str = "cpu",
|
||||||
|
# Steps
|
||||||
|
num_inference_steps: int = 30,
|
||||||
|
# Progress bar
|
||||||
|
progress_bar_cmd = tqdm,
|
||||||
|
):
|
||||||
|
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||||
|
|
||||||
|
# Parameters
|
||||||
|
inputs_posi = {"prompt": prompt}
|
||||||
|
inputs_nega = {"negative_prompt": negative_prompt}
|
||||||
|
inputs_shared = {
|
||||||
|
"cfg_scale": cfg_scale,
|
||||||
|
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||||
|
"height": height, "width": width,
|
||||||
|
"seed": seed, "rand_device": rand_device,
|
||||||
|
"num_inference_steps": num_inference_steps,
|
||||||
|
}
|
||||||
|
for unit in self.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
|
||||||
|
# Denoise
|
||||||
|
self.load_models_to_device(self.in_iteration_models)
|
||||||
|
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||||
|
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||||
|
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||||
|
noise_pred = self.cfg_guided_model_fn(
|
||||||
|
self.model_fn, cfg_scale,
|
||||||
|
inputs_shared, inputs_posi, inputs_nega,
|
||||||
|
**models, timestep=timestep, progress_id=progress_id
|
||||||
|
)
|
||||||
|
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||||
|
|
||||||
|
# Decode
|
||||||
|
self.load_models_to_device(['vae'])
|
||||||
|
image = self.vae.decode(inputs_shared["latents"])
|
||||||
|
image = self.vae_output_to_image(image)
|
||||||
|
self.load_models_to_device([])
|
||||||
|
|
||||||
|
return image
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
seperate_cfg=True,
|
||||||
|
input_params_posi={"prompt": "prompt"},
|
||||||
|
input_params_nega={"prompt": "negative_prompt"},
|
||||||
|
output_params=("prompt_embeds",),
|
||||||
|
onload_model_names=("text_encoder",)
|
||||||
|
)
|
||||||
|
self.hidden_states_layers = (-1,)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, prompt):
|
||||||
|
pipe.load_models_to_device(self.onload_model_names)
|
||||||
|
text = pipe.tokenizer.apply_chat_template(
|
||||||
|
[{"role": "user", "content": prompt}],
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False,
|
||||||
|
)
|
||||||
|
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||||
|
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||||
|
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||||
|
return {"prompt_embeds": prompt_embeds}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("height", "width", "seed", "rand_device"),
|
||||||
|
output_params=("noise",),
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||||
|
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||||
|
return {"noise": noise}
|
||||||
|
|
||||||
|
|
||||||
|
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
input_params=("input_image", "noise"),
|
||||||
|
output_params=("latents", "input_latents"),
|
||||||
|
onload_model_names=("vae",)
|
||||||
|
)
|
||||||
|
|
||||||
|
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||||
|
if input_image is None:
|
||||||
|
return {"latents": noise, "input_latents": None}
|
||||||
|
pipe.load_models_to_device(['vae'])
|
||||||
|
image = pipe.preprocess_image(input_image)
|
||||||
|
input_latents = pipe.vae.encode(image)
|
||||||
|
if pipe.scheduler.training:
|
||||||
|
return {"latents": noise, "input_latents": input_latents}
|
||||||
|
else:
|
||||||
|
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||||
|
return {"latents": latents, "input_latents": input_latents}
|
||||||
|
|
||||||
|
|
||||||
|
def model_fn_aaa(
|
||||||
|
dit: AAADiT,
|
||||||
|
latents=None,
|
||||||
|
prompt_embeds=None,
|
||||||
|
timestep=None,
|
||||||
|
use_gradient_checkpointing=False,
|
||||||
|
use_gradient_checkpointing_offload=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
model_output = dit(
|
||||||
|
latents,
|
||||||
|
prompt_embeds,
|
||||||
|
timestep,
|
||||||
|
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||||
|
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||||
|
)
|
||||||
|
return model_output
|
||||||
|
|
||||||
|
|
||||||
|
class AAATrainingModule(DiffusionTrainingModule):
|
||||||
|
def __init__(self, device):
|
||||||
|
super().__init__()
|
||||||
|
self.pipe = AAAImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device=device,
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||||
|
)
|
||||||
|
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||||
|
self.pipe.freeze_except(["dit"])
|
||||||
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
inputs_posi = {"prompt": data["prompt"]}
|
||||||
|
inputs_nega = {"negative_prompt": ""}
|
||||||
|
inputs_shared = {
|
||||||
|
"input_image": data["image"],
|
||||||
|
"height": data["image"].size[1],
|
||||||
|
"width": data["image"].size[0],
|
||||||
|
"cfg_scale": 1,
|
||||||
|
"use_gradient_checkpointing": False,
|
||||||
|
"use_gradient_checkpointing_offload": False,
|
||||||
|
}
|
||||||
|
for unit in self.pipe.units:
|
||||||
|
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||||
|
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||||
|
dataset = UnifiedDataset(
|
||||||
|
base_path="data/images",
|
||||||
|
metadata_path="data/metadata_merged.csv",
|
||||||
|
max_data_items=10000000,
|
||||||
|
data_file_keys=("image",),
|
||||||
|
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||||
|
)
|
||||||
|
model = AAATrainingModule(device=accelerator.device)
|
||||||
|
model_logger = ModelLogger(
|
||||||
|
"models/AAA/v1",
|
||||||
|
remove_prefix_in_ckpt="pipe.dit.",
|
||||||
|
)
|
||||||
|
launch_training_task(
|
||||||
|
accelerator, dataset, model, model_logger,
|
||||||
|
learning_rate=2e-4,
|
||||||
|
num_workers=4,
|
||||||
|
save_steps=50000,
|
||||||
|
num_epochs=999999,
|
||||||
|
)
|
||||||
@@ -6,7 +6,7 @@
|
|||||||
|
|
||||||
Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。
|
Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。
|
||||||
|
|
||||||
(图)
|

|
||||||
|
|
||||||
这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题:
|
这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题:
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像
|
|||||||
|
|
||||||
那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。
|
那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。
|
||||||
|
|
||||||
(图)
|

|
||||||
|
|
||||||
## 迭代去噪的计算是如何进行的?
|
## 迭代去噪的计算是如何进行的?
|
||||||
|
|
||||||
@@ -40,8 +40,6 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像
|
|||||||
|
|
||||||
其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。
|
其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。
|
||||||
|
|
||||||
(图)
|
|
||||||
|
|
||||||
而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。
|
而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。
|
||||||
|
|
||||||
接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$:
|
接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$:
|
||||||
@@ -89,8 +87,6 @@ $$
|
|||||||
|
|
||||||
训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。
|
训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。
|
||||||
|
|
||||||
(图)
|
|
||||||
|
|
||||||
以下是训练过程的伪代码
|
以下是训练过程的伪代码
|
||||||
|
|
||||||
> 从数据集获取数据样本 $x_0$ 和引导条件 $c$
|
> 从数据集获取数据样本 $x_0$ 和引导条件 $c$
|
||||||
@@ -111,7 +107,7 @@ $$
|
|||||||
|
|
||||||
从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。
|
从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。
|
||||||
|
|
||||||
(图)
|

|
||||||
|
|
||||||
### 数据编解码器
|
### 数据编解码器
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ pipe = Flux2ImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||||
|
)
|
||||||
|
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||||
|
# first frame
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_distilled_pipeline=True,
|
||||||
|
input_images=[image],
|
||||||
|
input_images_indexes=[0],
|
||||||
|
input_images_strength=1.0,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_distilled_i2av_first.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
55
examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py
Normal file
55
examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||||
|
)
|
||||||
|
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||||
|
# first frame
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=False,
|
||||||
|
input_images=[image],
|
||||||
|
input_images_indexes=[0],
|
||||||
|
input_images_strength=1.0,
|
||||||
|
num_inference_steps=40,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage_i2av_first.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
72
examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py
Normal file
72
examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||||
|
)
|
||||||
|
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||||
|
# first frame
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=42,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
num_inference_steps=40,
|
||||||
|
input_images=[image],
|
||||||
|
input_images_indexes=[0],
|
||||||
|
input_images_strength=1.0,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage_i2av_first.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_distilled_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_distilled.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
42
examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py
Normal file
42
examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
)
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
58
examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py
Normal file
58
examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||||
|
)
|
||||||
|
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||||
|
# first frame
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_distilled_pipeline=True,
|
||||||
|
input_images=[image],
|
||||||
|
input_images_indexes=[0],
|
||||||
|
input_images_strength=1.0,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_distilled_i2av_first.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,56 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||||
|
)
|
||||||
|
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||||
|
# first frame
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=False,
|
||||||
|
input_images=[image],
|
||||||
|
input_images_indexes=[0],
|
||||||
|
input_images_strength=1.0,
|
||||||
|
num_inference_steps=40,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage_i2av_first.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/examples_in_diffsynth",
|
||||||
|
local_dir="./",
|
||||||
|
allow_file_pattern=["data/examples/ltx-2/first_frame.jpg"]
|
||||||
|
)
|
||||||
|
image = Image.open("data/examples/ltx-2/first_frame.jpg").convert("RGB").resize((width, height))
|
||||||
|
# first frame
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=42,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
num_inference_steps=40,
|
||||||
|
input_images=[image],
|
||||||
|
input_images_indexes=[0],
|
||||||
|
input_images_strength=1.0,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage_i2av_first.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,58 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_distilled_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_distilled.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
height, width, num_frames = 512, 768, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_onestage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||||
|
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.float8_e5m2,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.float8_e5m2,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e5m2,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||||
|
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||||
|
negative_prompt = (
|
||||||
|
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||||
|
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||||
|
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||||
|
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||||
|
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||||
|
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||||
|
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||||
|
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||||
|
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||||
|
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||||
|
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||||
|
)
|
||||||
|
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||||
|
video, audio = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=43,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
num_frames=num_frames,
|
||||||
|
tiled=True,
|
||||||
|
use_two_stage_pipeline=True,
|
||||||
|
)
|
||||||
|
write_video_audio_ltx2(
|
||||||
|
video=video,
|
||||||
|
audio=audio,
|
||||||
|
output_path='ltx2_twostage.mp4',
|
||||||
|
fps=24,
|
||||||
|
audio_sample_rate=24000,
|
||||||
|
)
|
||||||
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||||
|
|
||||||
|
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||||
|
local_dir="./data/example_image_dataset",
|
||||||
|
allow_file_pattern="canny/*.jpg"
|
||||||
|
)
|
||||||
|
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||||
|
|
||||||
|
controlnet_canny_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||||
|
|
||||||
|
controlnet_inpaint_image = Image.open("./data/example_image_dataset/canny/image_2.jpg").convert("RGB").resize((1328, 1328))
|
||||||
|
# generate a centered square mask
|
||||||
|
inpaint_mask = Image.new("L", controlnet_inpaint_image.size, 0)
|
||||||
|
mask_size = 512
|
||||||
|
left = (controlnet_inpaint_image.width - mask_size) // 2
|
||||||
|
top = (controlnet_inpaint_image.height - mask_size) // 2
|
||||||
|
right = left + mask_size
|
||||||
|
bottom = top + mask_size
|
||||||
|
inpaint_mask.paste(255, (left, top, right, bottom))
|
||||||
|
inpaint_mask = inpaint_mask.resize((1328, 1328)).convert("RGB")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt, seed=0,
|
||||||
|
input_image=controlnet_inpaint_image, inpaint_mask=inpaint_mask,
|
||||||
|
blockwise_controlnet_inputs=[
|
||||||
|
ControlNetInput(image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, controlnet_id=0),
|
||||||
|
ControlNetInput(image=controlnet_canny_image, controlnet_id=1),
|
||||||
|
],
|
||||||
|
num_inference_steps=40,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
lora = ModelConfig(
|
||||||
|
model_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
|
||||||
|
origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors"
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, lora, alpha=8/64)
|
||||||
|
pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning")
|
||||||
|
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
"DiffSynth-Studio/example_image_dataset",
|
||||||
|
allow_file_pattern="qwen_image_edit/*",
|
||||||
|
local_dir="data/example_image_dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "生成这两个人的合影"
|
||||||
|
edit_image = [
|
||||||
|
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
|
||||||
|
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
|
||||||
|
]
|
||||||
|
image = pipe(
|
||||||
|
prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
seed=1,
|
||||||
|
num_inference_steps=4,
|
||||||
|
height=1152,
|
||||||
|
width=896,
|
||||||
|
edit_image_auto_resize=True,
|
||||||
|
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
|
||||||
|
cfg_scale=1.0,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
|
||||||
|
# Qwen-Image-Edit-2511 is a multi-image editing model.
|
||||||
|
# Please use a list to input `edit_image`, even if the input contains only one image.
|
||||||
|
# edit_image = [Image.open("image.jpg")]
|
||||||
|
# Please do not input the image directly.
|
||||||
|
# edit_image = Image.open("image.jpg")
|
||||||
@@ -0,0 +1,59 @@
|
|||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, ControlNetInput
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint", origin_file_pattern="model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny", origin_file_pattern="model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
dataset_id="DiffSynth-Studio/example_image_dataset",
|
||||||
|
local_dir="./data/example_image_dataset",
|
||||||
|
allow_file_pattern="canny/*.jpg"
|
||||||
|
)
|
||||||
|
prompt = "一只小狗,毛发光洁柔顺,眼神灵动,背景是樱花纷飞的春日庭院,唯美温馨。"
|
||||||
|
|
||||||
|
controlnet_canny_image = Image.open("data/example_image_dataset/canny/image_1.jpg").resize((1328, 1328))
|
||||||
|
|
||||||
|
controlnet_inpaint_image = Image.open("./data/example_image_dataset/canny/image_2.jpg").convert("RGB").resize((1328, 1328))
|
||||||
|
# generate a centered square mask
|
||||||
|
inpaint_mask = Image.new("L", controlnet_inpaint_image.size, 0)
|
||||||
|
mask_size = 512
|
||||||
|
left = (controlnet_inpaint_image.width - mask_size) // 2
|
||||||
|
top = (controlnet_inpaint_image.height - mask_size) // 2
|
||||||
|
right = left + mask_size
|
||||||
|
bottom = top + mask_size
|
||||||
|
inpaint_mask.paste(255, (left, top, right, bottom))
|
||||||
|
inpaint_mask = inpaint_mask.resize((1328, 1328)).convert("RGB")
|
||||||
|
|
||||||
|
image = pipe(
|
||||||
|
prompt, seed=0,
|
||||||
|
input_image=controlnet_inpaint_image, inpaint_mask=inpaint_mask,
|
||||||
|
blockwise_controlnet_inputs=[
|
||||||
|
ControlNetInput(image=controlnet_inpaint_image, inpaint_mask=inpaint_mask, controlnet_id=0),
|
||||||
|
ControlNetInput(image=controlnet_canny_image, controlnet_id=1),
|
||||||
|
],
|
||||||
|
num_inference_steps=40,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig, FlowMatchScheduler
|
||||||
|
from modelscope import dataset_snapshot_download
|
||||||
|
from PIL import Image
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": "disk",
|
||||||
|
"offload_device": "disk",
|
||||||
|
"onload_dtype": torch.float8_e4m3fn,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.float8_e4m3fn,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = QwenImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image-Edit-2511", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
processor_config=ModelConfig(model_id="Qwen/Qwen-Image-Edit", origin_file_pattern="processor/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
lora = ModelConfig(
|
||||||
|
model_id="lightx2v/Qwen-Image-Edit-2511-Lightning",
|
||||||
|
origin_file_pattern="Qwen-Image-Edit-2511-Lightning-4steps-V1.0-bf16.safetensors"
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, lora, alpha=8/64)
|
||||||
|
pipe.scheduler = FlowMatchScheduler("Qwen-Image-Lightning")
|
||||||
|
|
||||||
|
|
||||||
|
dataset_snapshot_download(
|
||||||
|
"DiffSynth-Studio/example_image_dataset",
|
||||||
|
allow_file_pattern="qwen_image_edit/*",
|
||||||
|
local_dir="data/example_image_dataset",
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt = "生成这两个人的合影"
|
||||||
|
edit_image = [
|
||||||
|
Image.open("data/example_image_dataset/qwen_image_edit/image1.jpg"),
|
||||||
|
Image.open("data/example_image_dataset/qwen_image_edit/image2.jpg"),
|
||||||
|
]
|
||||||
|
image = pipe(
|
||||||
|
prompt,
|
||||||
|
edit_image=edit_image,
|
||||||
|
seed=1,
|
||||||
|
num_inference_steps=4,
|
||||||
|
height=1152,
|
||||||
|
width=896,
|
||||||
|
edit_image_auto_resize=True,
|
||||||
|
zero_cond_t=True, # This is a special parameter introduced by Qwen-Image-Edit-2511
|
||||||
|
cfg_scale=1.0,
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
|
|
||||||
|
# Qwen-Image-Edit-2511 is a multi-image editing model.
|
||||||
|
# Please use a list to input `edit_image`, even if the input contains only one image.
|
||||||
|
# edit_image = [Image.open("image.jpg")]
|
||||||
|
# Please do not input the image directly.
|
||||||
|
# edit_image = Image.open("image.jpg")
|
||||||
61
examples/z_image/model_inference/Z-Image-i2L.py
Normal file
61
examples/z_image/model_inference/Z-Image-i2L.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from diffsynth.pipelines.z_image import (
|
||||||
|
ZImagePipeline, ModelConfig,
|
||||||
|
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
||||||
|
)
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Use `vram_config` to enable LoRA hot-loading
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cuda",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cuda",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors"),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
snapshot_download(
|
||||||
|
model_id="DiffSynth-Studio/Z-Image-i2L",
|
||||||
|
allow_file_pattern="assets/style/*",
|
||||||
|
local_dir="data/Z-Image-i2L_style_input"
|
||||||
|
)
|
||||||
|
images = [Image.open(f"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg") for i in range(4)]
|
||||||
|
|
||||||
|
# Image to LoRA
|
||||||
|
with torch.no_grad():
|
||||||
|
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
||||||
|
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
||||||
|
save_file(lora, "lora.safetensors")
|
||||||
|
|
||||||
|
# Generate images
|
||||||
|
prompt = "a cat"
|
||||||
|
negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=0, cfg_scale=4, num_inference_steps=50,
|
||||||
|
positive_only_lora=lora,
|
||||||
|
sigma_shift=8
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
17
examples/z_image/model_inference/Z-Image.py
Normal file
17
examples/z_image/model_inference/Z-Image.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||||
|
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||||
|
image.save("image_Z-Image.jpg")
|
||||||
@@ -33,6 +33,7 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors", **vram_config),
|
ModelConfig(model_id="DiffSynth-Studio/Z-Image-Omni-Base-i2L", origin_file_pattern="model.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load images
|
# Load images
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_snapshot_download(
|
dataset_snapshot_download(
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Control
|
# Control
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ pipe = ZImagePipeline.from_pretrained(
|
|||||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
],
|
],
|
||||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Control
|
# Control
|
||||||
|
|||||||
62
examples/z_image/model_inference_low_vram/Z-Image-i2L.py
Normal file
62
examples/z_image/model_inference_low_vram/Z-Image-i2L.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from diffsynth.pipelines.z_image import (
|
||||||
|
ZImagePipeline, ModelConfig,
|
||||||
|
ZImageUnit_Image2LoRAEncode, ZImageUnit_Image2LoRADecode
|
||||||
|
)
|
||||||
|
from modelscope import snapshot_download
|
||||||
|
from safetensors.torch import save_file
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
# Use `vram_config` to enable LoRA hot-loading
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Load models
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="SigLIP2-G384/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/General-Image-Encoders", origin_file_pattern="DINOv3-7B/model.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="DiffSynth-Studio/Z-Image-i2L", origin_file_pattern="model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load images
|
||||||
|
snapshot_download(
|
||||||
|
model_id="DiffSynth-Studio/Z-Image-i2L",
|
||||||
|
allow_file_pattern="assets/style/*",
|
||||||
|
local_dir="data/Z-Image-i2L_style_input"
|
||||||
|
)
|
||||||
|
images = [Image.open(f"data/Z-Image-i2L_style_input/assets/style/1/{i}.jpg") for i in range(4)]
|
||||||
|
|
||||||
|
# Image to LoRA
|
||||||
|
with torch.no_grad():
|
||||||
|
embs = ZImageUnit_Image2LoRAEncode().process(pipe, image2lora_images=images)
|
||||||
|
lora = ZImageUnit_Image2LoRADecode().process(pipe, **embs)["lora"]
|
||||||
|
save_file(lora, "lora.safetensors")
|
||||||
|
|
||||||
|
# Generate images
|
||||||
|
prompt = "a cat"
|
||||||
|
negative_prompt = "泛黄,发绿,模糊,低分辨率,低质量图像,扭曲的肢体,诡异的外观,丑陋,AI感,噪点,网格感,JPEG压缩条纹,异常的肢体,水印,乱码,意义不明的字符"
|
||||||
|
image = pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
seed=0, cfg_scale=4, num_inference_steps=50,
|
||||||
|
positive_only_lora=lora,
|
||||||
|
sigma_shift=8
|
||||||
|
)
|
||||||
|
image.save("image.jpg")
|
||||||
27
examples/z_image/model_inference_low_vram/Z-Image.py
Normal file
27
examples/z_image/model_inference_low_vram/Z-Image.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
vram_config = {
|
||||||
|
"offload_dtype": torch.bfloat16,
|
||||||
|
"offload_device": "cpu",
|
||||||
|
"onload_dtype": torch.bfloat16,
|
||||||
|
"onload_device": "cpu",
|
||||||
|
"preparing_dtype": torch.bfloat16,
|
||||||
|
"preparing_device": "cuda",
|
||||||
|
"computation_dtype": torch.bfloat16,
|
||||||
|
"computation_device": "cuda",
|
||||||
|
}
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||||
|
)
|
||||||
|
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||||
|
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||||
|
image.save("image_Z-Image.jpg")
|
||||||
14
examples/z_image/model_training/full/Z-Image.sh
Normal file
14
examples/z_image/model_training/full/Z-Image.sh
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# This example is tested on 8*A100
|
||||||
|
accelerate launch --config_file examples/z_image/model_training/full/accelerate_config.yaml examples/z_image/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 400 \
|
||||||
|
--model_id_with_origin_paths "Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-5 \
|
||||||
|
--num_epochs 2 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Z-Image_full" \
|
||||||
|
--trainable_models "dit" \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8
|
||||||
15
examples/z_image/model_training/lora/Z-Image.sh
Normal file
15
examples/z_image/model_training/lora/Z-Image.sh
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
accelerate launch examples/z_image/model_training/train.py \
|
||||||
|
--dataset_base_path data/example_image_dataset \
|
||||||
|
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||||
|
--max_pixels 1048576 \
|
||||||
|
--dataset_repeat 50 \
|
||||||
|
--model_id_with_origin_paths "Tongyi-MAI/Z-Image:transformer/*.safetensors,Tongyi-MAI/Z-Image-Turbo:text_encoder/*.safetensors,Tongyi-MAI/Z-Image-Turbo:vae/diffusion_pytorch_model.safetensors" \
|
||||||
|
--learning_rate 1e-4 \
|
||||||
|
--num_epochs 5 \
|
||||||
|
--remove_prefix_in_ckpt "pipe.dit." \
|
||||||
|
--output_path "./models/train/Z-Image_lora" \
|
||||||
|
--lora_base_model "dit" \
|
||||||
|
--lora_target_modules "to_q,to_k,to_v,to_out.0,w1,w2,w3" \
|
||||||
|
--lora_rank 32 \
|
||||||
|
--use_gradient_checkpointing \
|
||||||
|
--dataset_num_workers 8
|
||||||
20
examples/z_image/model_training/validate_full/Z-Image.py
Normal file
20
examples/z_image/model_training/validate_full/Z-Image.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
from diffsynth.core import load_state_dict
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict("./models/train/Z-Image_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||||
|
pipe.dit.load_state_dict(state_dict)
|
||||||
|
prompt = "a dog"
|
||||||
|
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||||
|
image.save("image.jpg")
|
||||||
18
examples/z_image/model_training/validate_lora/Z-Image.py
Normal file
18
examples/z_image/model_training/validate_lora/Z-Image.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
pipe = ZImagePipeline.from_pretrained(
|
||||||
|
torch_dtype=torch.bfloat16,
|
||||||
|
device="cuda",
|
||||||
|
model_configs=[
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image", origin_file_pattern="transformer/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors"),
|
||||||
|
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||||
|
],
|
||||||
|
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||||
|
)
|
||||||
|
pipe.load_lora(pipe.dit, "./models/train/Z-Image_lora/epoch-4.safetensors")
|
||||||
|
prompt = "a dog"
|
||||||
|
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||||
|
image.save("image.jpg")
|
||||||
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "diffsynth"
|
name = "diffsynth"
|
||||||
version = "2.0.3"
|
version = "2.0.4"
|
||||||
description = "Enjoy the magic of Diffusion models!"
|
description = "Enjoy the magic of Diffusion models!"
|
||||||
authors = [{name = "ModelScope Team"}]
|
authors = [{name = "ModelScope Team"}]
|
||||||
license = {text = "Apache-2.0"}
|
license = {text = "Apache-2.0"}
|
||||||
|
|||||||
Reference in New Issue
Block a user