mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
Compare commits
1 Commits
examples-u
...
z-imgae-om
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
63559a3ad6 |
2
.github/workflows/publish.yaml
vendored
2
.github/workflows/publish.yaml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- name: Install wheel
|
||||
run: pip install wheel==0.44.0 && pip install -r requirements.txt
|
||||
- name: Build DiffSynth
|
||||
run: python -m build
|
||||
run: python setup.py sdist bdist_wheel
|
||||
- name: Publish package to PyPI
|
||||
run: |
|
||||
pip install twine
|
||||
|
||||
119
README.md
119
README.md
@@ -33,16 +33,6 @@ We believe that a well-developed open-source code framework can lower the thresh
|
||||
|
||||
> Currently, the development personnel of this project are limited, with most of the work handled by [Artiprocher](https://github.com/Artiprocher). Therefore, the progress of new feature development will be relatively slow, and the speed of responding to and resolving issues is limited. We apologize for this and ask developers to understand.
|
||||
|
||||
- **February 2, 2026** The first document of the Research Tutorial series is now available, guiding you through training a small 0.1B text-to-image model from scratch. For details, see the [documentation](/docs/en/Research_Tutorial/train_from_scratch.md) and [model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel). We hope DiffSynth-Studio can evolve into a more powerful training framework for Diffusion models.
|
||||
|
||||
- **January 27, 2026**: [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) is released, and our [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) model is released concurrently. You can use it in [ModelScope Studios](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L). For details, see the [documentation](/docs/zh/Model_Details/Z-Image.md).
|
||||
|
||||
- **January 19, 2026**: Added support for [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) and [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) models, including training and inference capabilities. [Documentation](/docs/en/Model_Details/FLUX2.md) and [example code](/examples/flux2/) are now available.
|
||||
|
||||
- **January 12, 2026**: We trained and open-sourced a text-guided image layer separation model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)). Given an input image and a textual description, the model isolates the image layer corresponding to the described content. For more details, please refer to our blog post ([Chinese version](https://modelscope.cn/learn/4938), [English version](https://huggingface.co/blog/kelseye/qwen-image-layered-control)).
|
||||
|
||||
- **December 24, 2025**: Based on Qwen-Image-Edit-2511, we trained an In-Context Editing LoRA model ([Model Link](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)). This model takes three images as input (Image A, Image B, and Image C), and automatically analyzes the transformation from Image A to Image B, then applies the same transformation to Image C to generate Image D. For more details, please refer to our blog post ([Chinese version](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g), [English version](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora)).
|
||||
|
||||
- **December 9, 2025** We release a wild model based on DiffSynth-Studio 2.0: [Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L) (Image-to-LoRA). This model takes an image as input and outputs a LoRA. Although this version still has significant room for improvement in terms of generalization, detail preservation, and other aspects, we are open-sourcing these models to inspire more innovative research. For more details, please refer to our [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l).
|
||||
|
||||
- **December 4, 2025** DiffSynth-Studio 2.0 released! Many new features online
|
||||
@@ -273,14 +263,9 @@ image.save("image.jpg")
|
||||
|
||||
Example code for Z-Image is available at: [/examples/z_image/](/examples/z_image/)
|
||||
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -330,13 +315,9 @@ image.save("image.jpg")
|
||||
|
||||
Example code for FLUX.2 is available at: [/examples/flux2/](/examples/flux2/)
|
||||
|
||||
| Model ID | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
| Model ID | Inference | Low-VRAM Inference | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -419,9 +400,7 @@ Example code for Qwen-Image is available at: [/examples/qwen_image/](/examples/q
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
@@ -532,95 +511,6 @@ Example code for FLUX.1 is available at: [/examples/flux/](/examples/flux/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
#### LTX-2: [/docs/en/Model_Details/LTX-2.md](/docs/en/Model_Details/LTX-2.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
Running the following code will quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model for inference. VRAM management is enabled, and the framework automatically adjusts model parameter loading based on available GPU memory. The model can run with as little as 8GB of VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Examples</summary>
|
||||
|
||||
Example code for LTX-2 is available at: [/examples/ltx2/](/examples/ltx2/)
|
||||
|
||||
| Model ID | Extra Args | Inference | Low-VRAM Inference | Full Training | Full Training Validation | LoRA Training | LoRA Training Validation |
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
</details>
|
||||
|
||||
#### Wan: [/docs/en/Model_Details/Wan.md](/docs/en/Model_Details/Wan.md)
|
||||
|
||||
<details>
|
||||
@@ -879,3 +769,4 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/b54c05c5-d747-47
|
||||
https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-4481-b79f-0c3a7361a1ea
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
116
README_zh.md
116
README_zh.md
@@ -33,16 +33,6 @@ DiffSynth 目前包括两个开源项目:
|
||||
|
||||
> 目前本项目的开发人员有限,大部分工作由 [Artiprocher](https://github.com/Artiprocher) 负责,因此新功能的开发进展会比较缓慢,issue 的回复和解决速度有限,我们对此感到非常抱歉,请各位开发者理解。
|
||||
|
||||
- **2026年2月2日** Research Tutorial 的第一篇文档上线,带你从零开始训练一个 0.1B 的小型文生图模型,详见[文档](/docs/zh/Research_Tutorial/train_from_scratch.md)、[模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel),我们希望 DiffSynth-Studio 能够成为一个更强大的 Diffusion 模型训练框架。
|
||||
|
||||
- **2026年1月27日** [Z-Image](https://modelscope.cn/models/Tongyi-MAI/Z-Image) 发布,我们的 [Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L) 模型同步发布,在[魔搭创空间](https://modelscope.cn/studios/DiffSynth-Studio/Z-Image-i2L)可直接体验,详见[文档](/docs/zh/Model_Details/Z-Image.md)。
|
||||
|
||||
- **2026年1月19日** 新增对 [FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 和 [FLUX.2-klein-9B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B) 模型的支持,包括完整的训练和推理功能。[文档](/docs/zh/Model_Details/FLUX2.md)和[示例代码](/examples/flux2/)现已可用。
|
||||
|
||||
- **2026年1月12日** 我们训练并开源了一个文本引导的图层拆分模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)),这一模型输入一张图与一段文本描述,模型会将图像中与文本描述相关的图层拆分出来。更多细节请阅读我们的 blog([中文版](https://modelscope.cn/learn/4938)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-layered-control))。
|
||||
|
||||
- **2025年12月24日** 我们基于 Qwen-Image-Edit-2511 训练了一个 In-Context Editing LoRA 模型([模型链接](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-2511-ICEdit-LoRA)),这个模型可以输入三张图:图A、图B、图C,模型会自行分析图A到图B的变化,并将这样的变化应用到图C,生成图D。更多细节请阅读我们的 blog([中文版](https://mp.weixin.qq.com/s/41aEiN3lXKGCJs1-we4Q2g)、[英文版](https://huggingface.co/blog/kelseye/qwen-image-edit-2511-icedit-lora))。
|
||||
|
||||
- **2025年12月9日** 我们基于 DiffSynth-Studio 2.0 训练了一个疯狂的模型:[Qwen-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-i2L)(Image to LoRA)。这一模型以图像为输入,以 LoRA 为输出。尽管这个版本的模型在泛化能力、细节保持能力等方面还有很大改进空间,我们将这些模型开源,以启发更多创新性的研究工作。更多细节,请参考我们的 [blog](https://huggingface.co/blog/kelseye/qwen-image-i2l)。
|
||||
|
||||
- **2025年12月4日** DiffSynth-Studio 2.0 发布!众多新功能上线
|
||||
@@ -275,12 +265,7 @@ Z-Image 的示例代码位于:[/examples/z_image/](/examples/z_image/)
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -330,13 +315,9 @@ image.save("image.jpg")
|
||||
|
||||
FLUX.2 的示例代码位于:[/examples/flux2/](/examples/flux2/)
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|
||||
</details>
|
||||
|
||||
@@ -419,9 +400,7 @@ Qwen-Image 的示例代码位于:[/examples/qwen_image/](/examples/qwen_image/
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
@@ -532,95 +511,6 @@ FLUX.1 的示例代码位于:[/examples/flux/](/examples/flux/)
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
#### LTX-2: [/docs/zh/Model_Details/LTX-2.md](/docs/zh/Model_Details/LTX-2.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-spatial-upscaler-x2-1.0.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config=ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-distilled-lora-384.safetensors"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = (
|
||||
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
|
||||
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
|
||||
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
|
||||
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
|
||||
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
|
||||
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
|
||||
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
|
||||
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
|
||||
"off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
|
||||
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
|
||||
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
)
|
||||
height, width, num_frames = 512 * 2, 768 * 2, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
use_two_stage_pipeline=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_twostage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>示例代码</summary>
|
||||
|
||||
LTX-2 的示例代码位于:[/examples/ltx2/](/examples/ltx2/)
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
</details>
|
||||
|
||||
#### Wan: [/docs/zh/Model_Details/Wan.md](/docs/zh/Model_Details/Wan.md)
|
||||
|
||||
<details>
|
||||
|
||||
@@ -317,13 +317,6 @@ flux_series = [
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Supported due to historical reasons.
|
||||
"model_hash": "605c56eab23e9e2af863ad8f0813a25d",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverterFromDiffusers",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors")
|
||||
"model_hash": "94eefa3dac9cec93cb1ebaf1747d7b78",
|
||||
@@ -481,13 +474,6 @@ flux_series = [
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
"extra_kwargs": {"disable_guidance_embedder": True},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="MAILAND/majicflus_v1", origin_file_pattern="majicflus_v134.safetensors")
|
||||
"model_hash": "3394f306c4cbf04334b712bf5aaed95f",
|
||||
"model_name": "flux_dit",
|
||||
"model_class": "diffsynth.models.flux_dit.FluxDiT",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.flux_dit.FluxDiTStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
flux2_series = [
|
||||
@@ -510,28 +496,6 @@ flux2_series = [
|
||||
"model_name": "flux2_vae",
|
||||
"model_class": "diffsynth.models.flux2_vae.Flux2VAE",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "3bde7b817fec8143028b6825a63180df",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 7680, "num_attention_heads": 24, "num_layers": 5, "num_single_layers": 20}
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors")
|
||||
"model_hash": "9195f3ea256fcd0ae6d929c203470754",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "8B"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors")
|
||||
"model_hash": "39c6fc48f07bebecedbbaa971ff466c8",
|
||||
"model_name": "flux2_dit",
|
||||
"model_class": "diffsynth.models.flux2_dit.Flux2DiT",
|
||||
"extra_kwargs": {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
|
||||
},
|
||||
]
|
||||
|
||||
z_image_series = [
|
||||
@@ -576,91 +540,6 @@ z_image_series = [
|
||||
"model_name": "siglip_vision_model_428m",
|
||||
"model_class": "diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1", origin_file_pattern="Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.safetensors")
|
||||
"model_hash": "1677708d40029ab380a95f6c731a57d7",
|
||||
"model_name": "z_image_controlnet",
|
||||
"model_class": "diffsynth.models.z_image_controlnet.ZImageControlNet",
|
||||
},
|
||||
{
|
||||
# Example: ???
|
||||
"model_hash": "9510cb8cd1dd34ee0e4f111c24905510",
|
||||
"model_name": "z_image_image2lora_style",
|
||||
"model_class": "diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel",
|
||||
"extra_kwargs": {"compress_dim": 128},
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors")
|
||||
"model_hash": "1392adecee344136041e70553f875f31",
|
||||
"model_name": "z_image_text_encoder",
|
||||
"model_class": "diffsynth.models.z_image_text_encoder.ZImageTextEncoder",
|
||||
"extra_kwargs": {"model_size": "0.6B"},
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.z_image_text_encoder.ZImageTextEncoderStateDictConverter",
|
||||
},
|
||||
]
|
||||
|
||||
ltx2_series = [
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_dit",
|
||||
"model_class": "diffsynth.models.ltx2_dit.LTXModel",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_dit.LTXModelStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_video_vae_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_video_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_video_vae.LTX2VideoDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_video_vae.LTX2VideoDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vae_decoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioDecoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_audio_vocoder",
|
||||
"model_class": "diffsynth.models.ltx2_audio_vae.LTX2Vocoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2VocoderStateDictConverter",
|
||||
},
|
||||
# { # not used currently
|
||||
# # Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
# "model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
# "model_name": "ltx2_audio_vae_encoder",
|
||||
# "model_class": "diffsynth.models.ltx2_audio_vae.LTX2AudioEncoder",
|
||||
# "state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_audio_vae.LTX2AudioEncoderStateDictConverter",
|
||||
# },
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "aca7b0bbf8415e9c98360750268915fc",
|
||||
"model_name": "ltx2_text_encoder_post_modules",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderPostModulesStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors")
|
||||
"model_hash": "33917f31c4a79196171154cca39f165e",
|
||||
"model_name": "ltx2_text_encoder",
|
||||
"model_class": "diffsynth.models.ltx2_text_encoder.LTX2TextEncoder",
|
||||
"state_dict_converter": "diffsynth.utils.state_dict_converters.ltx2_text_encoder.LTX2TextEncoderStateDictConverter",
|
||||
},
|
||||
{
|
||||
# Example: ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors")
|
||||
"model_hash": "c79c458c6e99e0e14d47e676761732d2",
|
||||
"model_name": "ltx2_latent_upsampler",
|
||||
"model_class": "diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler",
|
||||
},
|
||||
]
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series + ltx2_series
|
||||
MODEL_CONFIGS = qwen_image_series + wan_series + flux_series + flux2_series + z_image_series
|
||||
|
||||
@@ -195,52 +195,4 @@ VRAM_MANAGEMENT_MODULE_MAPS = {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_controlnet.ZImageControlNet": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"diffsynth.models.z_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.z_image_image2lora.ZImageImage2LoRAModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.siglip2_image_encoder.Siglip2ImageEncoder428M": {
|
||||
"transformers.models.siglip2.modeling_siglip2.Siglip2VisionEmbeddings": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.siglip2.modeling_siglip2.Siglip2MultiheadAttentionPoolingHead": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Embedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.LayerNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
},
|
||||
"diffsynth.models.ltx2_dit.LTXModel": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_upsampler.LTX2LatentUpsampler": {
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.GroupNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_video_vae.LTX2VideoEncoder": {
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_video_vae.LTX2VideoDecoder": {
|
||||
"torch.nn.Conv3d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_audio_vae.LTX2AudioDecoder": {
|
||||
"torch.nn.Conv2d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_audio_vae.LTX2Vocoder": {
|
||||
"torch.nn.Conv1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.ConvTranspose1d": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoderPostModules": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"torch.nn.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"diffsynth.models.ltx2_text_encoder.Embeddings1DConnector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
"diffsynth.models.ltx2_text_encoder.LTX2TextEncoder": {
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3MultiModalProjector": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"transformers.models.gemma3.modeling_gemma3.Gemma3TextScaledWordEmbedding": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ def rearrange_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, q_pattern="
|
||||
if k_pattern != required_in_pattern:
|
||||
k = rearrange(k, f"{k_pattern} -> {required_in_pattern}", **dims)
|
||||
if v_pattern != required_in_pattern:
|
||||
v = rearrange(v, f"{v_pattern} -> {required_in_pattern}", **dims)
|
||||
v = rearrange(v, f"{q_pattern} -> {required_in_pattern}", **dims)
|
||||
return q, k, v
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ class UnifiedDataset(torch.utils.data.Dataset):
|
||||
data_file_keys=tuple(),
|
||||
main_data_operator=lambda x: x,
|
||||
special_operator_map=None,
|
||||
max_data_items=None,
|
||||
):
|
||||
self.base_path = base_path
|
||||
self.metadata_path = metadata_path
|
||||
@@ -19,7 +18,6 @@ class UnifiedDataset(torch.utils.data.Dataset):
|
||||
self.main_data_operator = main_data_operator
|
||||
self.cached_data_operator = LoadTorchPickle()
|
||||
self.special_operator_map = {} if special_operator_map is None else special_operator_map
|
||||
self.max_data_items = max_data_items
|
||||
self.data = []
|
||||
self.cached_data = []
|
||||
self.load_from_cache = metadata_path is None
|
||||
@@ -99,9 +97,7 @@ class UnifiedDataset(torch.utils.data.Dataset):
|
||||
return data
|
||||
|
||||
def __len__(self):
|
||||
if self.max_data_items is not None:
|
||||
return self.max_data_items
|
||||
elif self.load_from_cache:
|
||||
if self.load_from_cache:
|
||||
return len(self.cached_data) * self.repeat
|
||||
else:
|
||||
return len(self.data) * self.repeat
|
||||
|
||||
@@ -1,2 +1 @@
|
||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type, get_device_name
|
||||
from .npu_compatible_device import IS_NPU_AVAILABLE, IS_CUDA_AVAILABLE
|
||||
from .npu_compatible_device import parse_device_type, parse_nccl_backend, get_available_device_type
|
||||
@@ -1,5 +1,5 @@
|
||||
import torch, glob, os
|
||||
from typing import Optional, Union, Dict
|
||||
from typing import Optional, Union
|
||||
from dataclasses import dataclass
|
||||
from modelscope import snapshot_download
|
||||
from huggingface_hub import snapshot_download as hf_snapshot_download
|
||||
@@ -23,14 +23,13 @@ class ModelConfig:
|
||||
computation_device: Optional[Union[str, torch.device]] = None
|
||||
computation_dtype: Optional[torch.dtype] = None
|
||||
clear_parameters: bool = False
|
||||
state_dict: Dict[str, torch.Tensor] = None
|
||||
|
||||
def check_input(self):
|
||||
if self.path is None and self.model_id is None:
|
||||
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
|
||||
|
||||
def parse_original_file_pattern(self):
|
||||
if self.origin_file_pattern in [None, "", "./"]:
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
return "*"
|
||||
elif self.origin_file_pattern.endswith("/"):
|
||||
return self.origin_file_pattern + "*"
|
||||
@@ -98,8 +97,7 @@ class ModelConfig:
|
||||
self.reset_local_model_path()
|
||||
if self.require_downloading():
|
||||
self.download()
|
||||
if self.path is None:
|
||||
if self.origin_file_pattern in [None, "", "./"]:
|
||||
if self.origin_file_pattern is None or self.origin_file_pattern == "":
|
||||
self.path = os.path.join(self.local_model_path, self.model_id)
|
||||
else:
|
||||
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
|
||||
|
||||
@@ -2,25 +2,16 @@ from safetensors import safe_open
|
||||
import torch, hashlib
|
||||
|
||||
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu", pin_memory=False, verbose=0):
|
||||
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
|
||||
if isinstance(file_path, list):
|
||||
state_dict = {}
|
||||
for file_path_ in file_path:
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device, pin_memory=pin_memory, verbose=verbose))
|
||||
state_dict.update(load_state_dict(file_path_, torch_dtype, device))
|
||||
return state_dict
|
||||
if file_path.endswith(".safetensors"):
|
||||
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [started]: {file_path}")
|
||||
if file_path.endswith(".safetensors"):
|
||||
state_dict = load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
|
||||
else:
|
||||
state_dict = load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
# If load state dict in CPU memory, `pin_memory=True` will make `model.to("cuda")` faster.
|
||||
if pin_memory:
|
||||
for i in state_dict:
|
||||
state_dict[i] = state_dict[i].pin_memory()
|
||||
if verbose >= 1:
|
||||
print(f"Loading file [done]: {file_path}")
|
||||
return state_dict
|
||||
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
|
||||
|
||||
|
||||
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
|
||||
|
||||
@@ -5,7 +5,7 @@ from .file import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None, state_dict=None):
|
||||
def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, device="cpu", state_dict_converter=None, use_disk_map=False, module_map=None, vram_config=None, vram_limit=None):
|
||||
config = {} if config is None else config
|
||||
# Why do we use `skip_model_initialization`?
|
||||
# It skips the random initialization of model parameters,
|
||||
@@ -20,7 +20,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
dtypes = [vram_config["offload_dtype"], vram_config["onload_dtype"], vram_config["preparing_dtype"], vram_config["computation_dtype"]]
|
||||
dtype = [d for d in dtypes if d != "disk"][0]
|
||||
if vram_config["offload_device"] != "disk":
|
||||
if state_dict is None: state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
state_dict = DiskMap(path, device, torch_dtype=dtype)
|
||||
if state_dict_converter is not None:
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
else:
|
||||
@@ -35,9 +35,7 @@ def load_model(model_class, path, config=None, torch_dtype=torch.bfloat16, devic
|
||||
# Sometimes a model file contains multiple models,
|
||||
# and DiskMap can load only the parameters of a single model,
|
||||
# avoiding the need to load all parameters in the file.
|
||||
if state_dict is not None:
|
||||
pass
|
||||
elif use_disk_map:
|
||||
if use_disk_map:
|
||||
state_dict = DiskMap(path, device, torch_dtype=torch_dtype)
|
||||
else:
|
||||
state_dict = load_state_dict(path, torch_dtype, device)
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch, copy
|
||||
from typing import Union
|
||||
from .initialization import skip_model_initialization
|
||||
from .disk_map import DiskMap
|
||||
from ..device import parse_device_type, get_device_name, IS_NPU_AVAILABLE
|
||||
from ..device import parse_device_type
|
||||
|
||||
|
||||
class AutoTorchModule(torch.nn.Module):
|
||||
@@ -63,7 +63,7 @@ class AutoTorchModule(torch.nn.Module):
|
||||
return r
|
||||
|
||||
def check_free_vram(self):
|
||||
device = self.computation_device if not IS_NPU_AVAILABLE else get_device_name()
|
||||
device = self.computation_device if self.computation_device != "npu" else "npu:0"
|
||||
gpu_mem_state = getattr(torch, self.computation_device_type).mem_get_info(device)
|
||||
used_memory = (gpu_mem_state[1] - gpu_mem_state[0]) / (1024**3)
|
||||
return used_memory < self.vram_limit
|
||||
|
||||
@@ -4,11 +4,9 @@ import numpy as np
|
||||
from einops import repeat, reduce
|
||||
from typing import Union
|
||||
from ..core import AutoTorchModule, AutoWrappedLinear, load_state_dict, ModelConfig, parse_device_type
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..utils.lora import GeneralLoRALoader
|
||||
from ..models.model_loader import ModelPool
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from ..core.device import get_device_name, IS_NPU_AVAILABLE
|
||||
|
||||
|
||||
class PipelineUnit:
|
||||
@@ -62,7 +60,7 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device=get_device_type(), torch_dtype=torch.float16,
|
||||
device="cuda", torch_dtype=torch.float16,
|
||||
height_division_factor=64, width_division_factor=64,
|
||||
time_division_factor=None, time_division_remainder=None,
|
||||
):
|
||||
@@ -179,7 +177,7 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def get_vram(self):
|
||||
device = self.device if not IS_NPU_AVAILABLE else get_device_name()
|
||||
device = self.device if self.device != "npu" else "npu:0"
|
||||
return getattr(torch, self.device_type).mem_get_info(device)[1] / (1024 ** 3)
|
||||
|
||||
def get_module(self, model, name):
|
||||
@@ -237,7 +235,6 @@ class BasePipeline(torch.nn.Module):
|
||||
alpha=1,
|
||||
hotload=None,
|
||||
state_dict=None,
|
||||
verbose=1,
|
||||
):
|
||||
if state_dict is None:
|
||||
if isinstance(lora_config, str):
|
||||
@@ -264,13 +261,12 @@ class BasePipeline(torch.nn.Module):
|
||||
updated_num += 1
|
||||
module.lora_A_weights.append(lora[lora_a_name] * alpha)
|
||||
module.lora_B_weights.append(lora[lora_b_name])
|
||||
if verbose >= 1:
|
||||
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||
print(f"{updated_num} tensors are patched by LoRA. You can use `pipe.clear_lora()` to clear all LoRA layers.")
|
||||
else:
|
||||
lora_loader.fuse_lora_to_base_model(module, lora, alpha=alpha)
|
||||
|
||||
|
||||
def clear_lora(self, verbose=1):
|
||||
def clear_lora(self):
|
||||
cleared_num = 0
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, AutoWrappedLinear):
|
||||
@@ -280,8 +276,7 @@ class BasePipeline(torch.nn.Module):
|
||||
module.lora_A_weights.clear()
|
||||
if hasattr(module, "lora_B_weights"):
|
||||
module.lora_B_weights.clear()
|
||||
if verbose >= 1:
|
||||
print(f"{cleared_num} LoRA layers are cleared.")
|
||||
print(f"{cleared_num} LoRA layers are cleared.")
|
||||
|
||||
|
||||
def download_and_load_models(self, model_configs: list[ModelConfig] = [], vram_limit: float = None):
|
||||
@@ -296,7 +291,6 @@ class BasePipeline(torch.nn.Module):
|
||||
vram_config=vram_config,
|
||||
vram_limit=vram_limit,
|
||||
clear_parameters=model_config.clear_parameters,
|
||||
state_dict=model_config.state_dict,
|
||||
)
|
||||
return model_pool
|
||||
|
||||
@@ -310,22 +304,10 @@ class BasePipeline(torch.nn.Module):
|
||||
|
||||
|
||||
def cfg_guided_model_fn(self, model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega, **inputs_others):
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
self.load_lora(self.dit, state_dict=inputs_shared["positive_only_lora"], verbose=0)
|
||||
noise_pred_posi = model_fn(**inputs_posi, **inputs_shared, **inputs_others)
|
||||
if cfg_scale != 1.0:
|
||||
if inputs_shared.get("positive_only_lora", None) is not None:
|
||||
self.clear_lora(verbose=0)
|
||||
noise_pred_nega = model_fn(**inputs_nega, **inputs_shared, **inputs_others)
|
||||
if isinstance(noise_pred_posi, tuple):
|
||||
# Separately handling different output types of latents, eg. video and audio latents.
|
||||
noise_pred = tuple(
|
||||
n_nega + cfg_scale * (n_posi - n_nega)
|
||||
for n_posi, n_nega in zip(noise_pred_posi, noise_pred_nega)
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
return noise_pred
|
||||
|
||||
@@ -4,15 +4,13 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
"Qwen-Image": FlowMatchScheduler.set_timesteps_qwen_image,
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@@ -72,28 +70,6 @@ class FlowMatchScheduler():
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
base_shift = math.log(3)
|
||||
max_shift = math.log(3)
|
||||
# Sigmas
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
# Mu
|
||||
if exponential_shift_mu is not None:
|
||||
mu = exponential_shift_mu
|
||||
elif dynamic_shift_len is not None:
|
||||
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
|
||||
else:
|
||||
mu = 0.8
|
||||
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||
# Timesteps
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def compute_empirical_mu(image_seq_len, num_steps):
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
@@ -113,18 +89,13 @@ class FlowMatchScheduler():
|
||||
return float(mu)
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None):
|
||||
def set_timesteps_flux2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=1024//16*1024//16):
|
||||
sigma_min = 1 / num_inference_steps
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps)
|
||||
if dynamic_shift_len is None:
|
||||
# If you ask me why I set mu=0.8,
|
||||
# I can only say that it yields better training results.
|
||||
mu = 0.8
|
||||
else:
|
||||
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
|
||||
mu = FlowMatchScheduler.compute_empirical_mu(dynamic_shift_len, num_inference_steps)
|
||||
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
@@ -145,35 +116,7 @@ class FlowMatchScheduler():
|
||||
timestep_id = torch.argmin((timesteps - timestep).abs())
|
||||
timesteps[timestep_id] = timestep
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ltx2(num_inference_steps=100, denoising_strength=1.0, dynamic_shift_len=None, terminal=0.1, special_case=None):
|
||||
num_train_timesteps = 1000
|
||||
if special_case == "stage2":
|
||||
sigmas = torch.Tensor([0.909375, 0.725, 0.421875])
|
||||
elif special_case == "ditilled_stage1":
|
||||
sigmas = torch.Tensor([1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875])
|
||||
else:
|
||||
dynamic_shift_len = dynamic_shift_len or 4096
|
||||
sigma_shift = FlowMatchScheduler._calculate_shift_qwen_image(
|
||||
image_seq_len=dynamic_shift_len,
|
||||
base_seq_len=1024,
|
||||
max_seq_len=4096,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
)
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
sigmas = math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1))
|
||||
# Shift terminal
|
||||
one_minus_z = 1.0 - sigmas
|
||||
scale_factor = one_minus_z[-1] / (1 - terminal)
|
||||
sigmas = 1.0 - (one_minus_z / scale_factor)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
|
||||
@@ -10,7 +10,7 @@ class ModelLogger:
|
||||
self.num_steps = 0
|
||||
|
||||
|
||||
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None, **kwargs):
|
||||
def on_step_end(self, accelerator: Accelerator, model: torch.nn.Module, save_steps=None):
|
||||
self.num_steps += 1
|
||||
if save_steps is not None and self.num_steps % save_steps == 0:
|
||||
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
|
||||
|
||||
@@ -13,16 +13,9 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
if "first_frame_latents" in inputs:
|
||||
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||
|
||||
if "first_frame_latents" in inputs:
|
||||
noise_pred = noise_pred[:, :, 1:]
|
||||
training_target = training_target[:, :, 1:]
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
@@ -40,7 +40,7 @@ def launch_training_task(
|
||||
loss = model(data)
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
model_logger.on_step_end(accelerator, model, save_steps, loss=loss)
|
||||
model_logger.on_step_end(accelerator, model, save_steps)
|
||||
scheduler.step()
|
||||
if save_steps is None:
|
||||
model_logger.on_epoch_end(accelerator, model, epoch_id)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, json, os
|
||||
import torch, json
|
||||
from ..core import ModelConfig, load_state_dict
|
||||
from ..utils.controlnet import ControlNetInput
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
@@ -127,67 +127,16 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
if model_id_with_origin_paths is not None:
|
||||
model_id_with_origin_paths = model_id_with_origin_paths.split(",")
|
||||
for model_id_with_origin_path in model_id_with_origin_paths:
|
||||
model_id, origin_file_pattern = model_id_with_origin_path.split(":")
|
||||
vram_config = self.parse_vram_config(
|
||||
fp8=model_id_with_origin_path in fp8_models,
|
||||
offload=model_id_with_origin_path in offload_models,
|
||||
device=device
|
||||
)
|
||||
config = self.parse_path_or_model_id(model_id_with_origin_path)
|
||||
model_configs.append(ModelConfig(model_id=config.model_id, origin_file_pattern=config.origin_file_pattern, **vram_config))
|
||||
model_configs.append(ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern, **vram_config))
|
||||
return model_configs
|
||||
|
||||
|
||||
def parse_path_or_model_id(self, model_id_with_origin_path, default_value=None):
|
||||
if model_id_with_origin_path is None:
|
||||
return default_value
|
||||
elif os.path.exists(model_id_with_origin_path):
|
||||
return ModelConfig(path=model_id_with_origin_path)
|
||||
else:
|
||||
if ":" not in model_id_with_origin_path:
|
||||
raise ValueError(f"Failed to parse model config: {model_id_with_origin_path}. This is neither a valid path nor in the format of `model_id/origin_file_pattern`.")
|
||||
split_id = model_id_with_origin_path.rfind(":")
|
||||
model_id = model_id_with_origin_path[:split_id]
|
||||
origin_file_pattern = model_id_with_origin_path[split_id + 1:]
|
||||
return ModelConfig(model_id=model_id, origin_file_pattern=origin_file_pattern)
|
||||
|
||||
|
||||
def auto_detect_lora_target_modules(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
search_for_linear=False,
|
||||
linear_detector=lambda x: min(x.weight.shape) >= 512,
|
||||
block_list_detector=lambda x: isinstance(x, torch.nn.ModuleList) and len(x) > 1,
|
||||
name_prefix="",
|
||||
):
|
||||
lora_target_modules = []
|
||||
if search_for_linear:
|
||||
for name, module in model.named_modules():
|
||||
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||
if isinstance(module, torch.nn.Linear) and linear_detector(module):
|
||||
lora_target_modules.append(module_name)
|
||||
else:
|
||||
for name, module in model.named_children():
|
||||
module_name = name_prefix + ["", "."][name_prefix != ""] + name
|
||||
lora_target_modules += self.auto_detect_lora_target_modules(
|
||||
module,
|
||||
search_for_linear=block_list_detector(module),
|
||||
linear_detector=linear_detector,
|
||||
block_list_detector=block_list_detector,
|
||||
name_prefix=module_name,
|
||||
)
|
||||
return lora_target_modules
|
||||
|
||||
|
||||
def parse_lora_target_modules(self, model, lora_target_modules):
|
||||
if lora_target_modules == "":
|
||||
print("No LoRA target modules specified. The framework will automatically search for them.")
|
||||
lora_target_modules = self.auto_detect_lora_target_modules(model)
|
||||
print(f"LoRA will be patched at {lora_target_modules}.")
|
||||
else:
|
||||
lora_target_modules = lora_target_modules.split(",")
|
||||
return lora_target_modules
|
||||
|
||||
|
||||
def switch_pipe_to_training_mode(
|
||||
self,
|
||||
pipe,
|
||||
@@ -217,7 +166,7 @@ class DiffusionTrainingModule(torch.nn.Module):
|
||||
return
|
||||
model = self.add_lora_to_model(
|
||||
getattr(pipe, lora_base_model),
|
||||
target_modules=self.parse_lora_target_modules(getattr(pipe, lora_base_model), lora_target_modules),
|
||||
target_modules=lora_target_modules.split(","),
|
||||
lora_rank=lora_rank,
|
||||
upcast_dtype=pipe.torch_dtype,
|
||||
)
|
||||
|
||||
@@ -2,8 +2,6 @@ from transformers import DINOv3ViTModel, DINOv3ViTImageProcessorFast
|
||||
from transformers.models.dinov3_vit.modeling_dinov3_vit import DINOv3ViTConfig
|
||||
import torch
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
|
||||
|
||||
class DINOv3ImageEncoder(DINOv3ViTModel):
|
||||
def __init__(self):
|
||||
@@ -72,7 +70,7 @@ class DINOv3ImageEncoder(DINOv3ViTModel):
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
inputs = self.processor(images=image, return_tensors="pt")
|
||||
pixel_values = inputs["pixel_values"].to(dtype=torch_dtype, device=device)
|
||||
bool_masked_pos = None
|
||||
|
||||
@@ -823,13 +823,7 @@ class Flux2PosEmbed(nn.Module):
|
||||
|
||||
|
||||
class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 256,
|
||||
embedding_dim: int = 6144,
|
||||
bias: bool = False,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
@@ -837,24 +831,20 @@ class Flux2TimestepGuidanceEmbeddings(nn.Module):
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
if guidance_embeds:
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
else:
|
||||
self.guidance_embedder = None
|
||||
self.guidance_embedder = TimestepEmbedding(
|
||||
in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
|
||||
)
|
||||
|
||||
def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
|
||||
|
||||
if guidance is not None and self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
return time_guidance_emb
|
||||
else:
|
||||
return timesteps_emb
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
|
||||
|
||||
time_guidance_emb = timesteps_emb + guidance_emb
|
||||
|
||||
return time_guidance_emb
|
||||
|
||||
|
||||
class Flux2Modulation(nn.Module):
|
||||
@@ -892,7 +882,6 @@ class Flux2DiT(torch.nn.Module):
|
||||
axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
|
||||
rope_theta: int = 2000,
|
||||
eps: float = 1e-6,
|
||||
guidance_embeds: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = out_channels or in_channels
|
||||
@@ -903,10 +892,7 @@ class Flux2DiT(torch.nn.Module):
|
||||
|
||||
# 2. Combined timestep + guidance embedding
|
||||
self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
|
||||
in_channels=timestep_guidance_channels,
|
||||
embedding_dim=self.inner_dim,
|
||||
bias=False,
|
||||
guidance_embeds=guidance_embeds,
|
||||
in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
|
||||
)
|
||||
|
||||
# 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
|
||||
@@ -967,9 +953,34 @@ class Flux2DiT(torch.nn.Module):
|
||||
txt_ids: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
) -> Union[torch.Tensor]:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
@@ -981,9 +992,7 @@ class Flux2DiT(torch.nn.Module):
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
timestep = timestep.to(hidden_states.dtype) * 1000
|
||||
|
||||
if guidance is not None:
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
guidance = guidance.to(hidden_states.dtype) * 1000
|
||||
|
||||
temb = self.time_guidance_embed(timestep, guidance)
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import numpy as np
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from .wan_video_dit import flash_attention
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -374,7 +373,7 @@ class FinalLayer_FP32(nn.Module):
|
||||
B, N, C = x.shape
|
||||
T, _, _ = latent_shape
|
||||
|
||||
with amp.autocast(get_device_type(), dtype=torch.float32):
|
||||
with amp.autocast('cuda', dtype=torch.float32):
|
||||
shift, scale = self.adaLN_modulation(t).unsqueeze(2).chunk(2, dim=-1) # [B, T, 1, C]
|
||||
x = modulate_fp32(self.norm_final, x.view(B, T, -1, C), shift, scale).view(B, N, C)
|
||||
x = self.linear(x)
|
||||
@@ -584,7 +583,7 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
T, _, _ = latent_shape # S != T*H*W in case of CP split on H*W.
|
||||
|
||||
# compute modulation params in fp32
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
shift_msa, scale_msa, gate_msa, \
|
||||
shift_mlp, scale_mlp, gate_mlp = \
|
||||
self.adaLN_modulation(t).unsqueeze(2).chunk(6, dim=-1) # [B, T, 1, C]
|
||||
@@ -603,7 +602,7 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
else:
|
||||
x_s = attn_outputs
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
x = x + (gate_msa * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
@@ -616,7 +615,7 @@ class LongCatSingleStreamBlock(nn.Module):
|
||||
# ffn with modulation
|
||||
x_m = modulate_fp32(self.mod_norm_ffn, x.view(B, -1, N//T, C), shift_mlp, scale_mlp).view(B, -1, C)
|
||||
x_s = self.ffn(x_m)
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
x = x + (gate_mlp * x_s.view(B, -1, N//T, C)).view(B, -1, C) # [B, N, C]
|
||||
x = x.to(x_dtype)
|
||||
|
||||
@@ -798,7 +797,7 @@ class LongCatVideoTransformer3DModel(torch.nn.Module):
|
||||
|
||||
hidden_states = self.x_embedder(hidden_states) # [B, N, C]
|
||||
|
||||
with amp.autocast(device_type=get_device_type(), dtype=torch.float32):
|
||||
with amp.autocast(device_type='cuda', dtype=torch.float32):
|
||||
t = self.t_embedder(timestep.float().flatten(), dtype=torch.float32).reshape(B, N_t, -1) # [B, T, C_t]
|
||||
|
||||
encoder_hidden_states = self.y_embedder(encoder_hidden_states) # [B, 1, N_token, C]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,371 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Protocol, Tuple
|
||||
import torch
|
||||
from torch import nn
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class VideoPixelShape(NamedTuple):
|
||||
"""
|
||||
Shape of the tensor representing the video pixel array. Assumes BGR channel format.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
frames: int
|
||||
height: int
|
||||
width: int
|
||||
fps: float
|
||||
|
||||
|
||||
class SpatioTemporalScaleFactors(NamedTuple):
|
||||
"""
|
||||
Describes the spatiotemporal downscaling between decoded video space and
|
||||
the corresponding VAE latent grid.
|
||||
"""
|
||||
|
||||
time: int
|
||||
width: int
|
||||
height: int
|
||||
|
||||
@classmethod
|
||||
def default(cls) -> "SpatioTemporalScaleFactors":
|
||||
return cls(time=8, width=32, height=32)
|
||||
|
||||
|
||||
VIDEO_SCALE_FACTORS = SpatioTemporalScaleFactors.default()
|
||||
|
||||
|
||||
class VideoLatentShape(NamedTuple):
|
||||
"""
|
||||
Shape of the tensor representing video in VAE latent space.
|
||||
The latent representation is a 5D tensor with dimensions ordered as
|
||||
(batch, channels, frames, height, width). Spatial and temporal dimensions
|
||||
are downscaled relative to pixel space according to the VAE's scale factors.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
height: int
|
||||
width: int
|
||||
|
||||
def to_torch_shape(self) -> torch.Size:
|
||||
return torch.Size([self.batch, self.channels, self.frames, self.height, self.width])
|
||||
|
||||
@staticmethod
|
||||
def from_torch_shape(shape: torch.Size) -> "VideoLatentShape":
|
||||
return VideoLatentShape(
|
||||
batch=shape[0],
|
||||
channels=shape[1],
|
||||
frames=shape[2],
|
||||
height=shape[3],
|
||||
width=shape[4],
|
||||
)
|
||||
|
||||
def mask_shape(self) -> "VideoLatentShape":
|
||||
return self._replace(channels=1)
|
||||
|
||||
@staticmethod
|
||||
def from_pixel_shape(
|
||||
shape: VideoPixelShape,
|
||||
latent_channels: int = 128,
|
||||
scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS,
|
||||
) -> "VideoLatentShape":
|
||||
frames = (shape.frames - 1) // scale_factors[0] + 1
|
||||
height = shape.height // scale_factors[1]
|
||||
width = shape.width // scale_factors[2]
|
||||
|
||||
return VideoLatentShape(
|
||||
batch=shape.batch,
|
||||
channels=latent_channels,
|
||||
frames=frames,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
def upscale(self, scale_factors: SpatioTemporalScaleFactors = VIDEO_SCALE_FACTORS) -> "VideoLatentShape":
|
||||
return self._replace(
|
||||
channels=3,
|
||||
frames=(self.frames - 1) * scale_factors.time + 1,
|
||||
height=self.height * scale_factors.height,
|
||||
width=self.width * scale_factors.width,
|
||||
)
|
||||
|
||||
|
||||
class AudioLatentShape(NamedTuple):
|
||||
"""
|
||||
Shape of audio in VAE latent space: (batch, channels, frames, mel_bins).
|
||||
mel_bins is the number of frequency bins from the mel-spectrogram encoding.
|
||||
"""
|
||||
|
||||
batch: int
|
||||
channels: int
|
||||
frames: int
|
||||
mel_bins: int
|
||||
|
||||
def to_torch_shape(self) -> torch.Size:
|
||||
return torch.Size([self.batch, self.channels, self.frames, self.mel_bins])
|
||||
|
||||
def mask_shape(self) -> "AudioLatentShape":
|
||||
return self._replace(channels=1, mel_bins=1)
|
||||
|
||||
@staticmethod
|
||||
def from_torch_shape(shape: torch.Size) -> "AudioLatentShape":
|
||||
return AudioLatentShape(
|
||||
batch=shape[0],
|
||||
channels=shape[1],
|
||||
frames=shape[2],
|
||||
mel_bins=shape[3],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_duration(
|
||||
batch: int,
|
||||
duration: float,
|
||||
channels: int = 8,
|
||||
mel_bins: int = 16,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
) -> "AudioLatentShape":
|
||||
latents_per_second = float(sample_rate) / float(hop_length) / float(audio_latent_downsample_factor)
|
||||
|
||||
return AudioLatentShape(
|
||||
batch=batch,
|
||||
channels=channels,
|
||||
frames=round(duration * latents_per_second),
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_video_pixel_shape(
|
||||
shape: VideoPixelShape,
|
||||
channels: int = 8,
|
||||
mel_bins: int = 16,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
) -> "AudioLatentShape":
|
||||
return AudioLatentShape.from_duration(
|
||||
batch=shape.batch,
|
||||
duration=float(shape.frames) / float(shape.fps),
|
||||
channels=channels,
|
||||
mel_bins=mel_bins,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=hop_length,
|
||||
audio_latent_downsample_factor=audio_latent_downsample_factor,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LatentState:
|
||||
"""
|
||||
State of latents during the diffusion denoising process.
|
||||
Attributes:
|
||||
latent: The current noisy latent tensor being denoised.
|
||||
denoise_mask: Mask encoding the denoising strength for each token (1 = full denoising, 0 = no denoising).
|
||||
positions: Positional indices for each latent element, used for positional embeddings.
|
||||
clean_latent: Initial state of the latent before denoising, may include conditioning latents.
|
||||
"""
|
||||
|
||||
latent: torch.Tensor
|
||||
denoise_mask: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
clean_latent: torch.Tensor
|
||||
|
||||
def clone(self) -> "LatentState":
|
||||
return LatentState(
|
||||
latent=self.latent.clone(),
|
||||
denoise_mask=self.denoise_mask.clone(),
|
||||
positions=self.positions.clone(),
|
||||
clean_latent=self.clean_latent.clone(),
|
||||
)
|
||||
|
||||
|
||||
class NormType(Enum):
|
||||
"""Normalization layer types: GROUP (GroupNorm) or PIXEL (per-location RMS norm)."""
|
||||
|
||||
GROUP = "group"
|
||||
PIXEL = "pixel"
|
||||
|
||||
|
||||
class PixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
For each element along the chosen dimension, this layer normalizes the tensor
|
||||
by the root-mean-square of its values across that dimension:
|
||||
y = x / sqrt(mean(x^2, dim=dim, keepdim=True) + eps)
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
"""
|
||||
Args:
|
||||
dim: Dimension along which to compute the RMS (typically channels).
|
||||
eps: Small constant added for numerical stability.
|
||||
"""
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply RMS normalization along the configured dimension.
|
||||
"""
|
||||
# Compute mean of squared values along `dim`, keep dimensions for broadcasting.
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
# Normalize by the root-mean-square (RMS).
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
def build_normalization_layer(
|
||||
in_channels: int, *, num_groups: int = 32, normtype: NormType = NormType.GROUP
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Create a normalization layer based on the normalization type.
|
||||
Args:
|
||||
in_channels: Number of input channels
|
||||
num_groups: Number of groups for group normalization
|
||||
normtype: Type of normalization: "group" or "pixel"
|
||||
Returns:
|
||||
A normalization layer
|
||||
"""
|
||||
if normtype == NormType.GROUP:
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
if normtype == NormType.PIXEL:
|
||||
return PixelNorm(dim=1, eps=1e-6)
|
||||
raise ValueError(f"Invalid normalization type: {normtype}")
|
||||
|
||||
|
||||
def rms_norm(x: torch.Tensor, weight: torch.Tensor | None = None, eps: float = 1e-6) -> torch.Tensor:
|
||||
"""Root-mean-square (RMS) normalize `x` over its last dimension.
|
||||
Thin wrapper around `torch.nn.functional.rms_norm` that infers the normalized
|
||||
shape and forwards `weight` and `eps`.
|
||||
"""
|
||||
return torch.nn.functional.rms_norm(x, (x.shape[-1],), weight=weight, eps=eps)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Modality:
|
||||
"""
|
||||
Input data for a single modality (video or audio) in the transformer.
|
||||
Bundles the latent tokens, timestep embeddings, positional information,
|
||||
and text conditioning context for processing by the diffusion transformer.
|
||||
"""
|
||||
|
||||
latent: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, T, D) where B is the batch size, T is the number of tokens, and D is input dimension
|
||||
timesteps: torch.Tensor # Shape: (B, T) where T is the number of timesteps
|
||||
positions: (
|
||||
torch.Tensor
|
||||
) # Shape: (B, 3, T) for video, where 3 is the number of dimensions and T is the number of tokens
|
||||
context: torch.Tensor
|
||||
enabled: bool = True
|
||||
context_mask: torch.Tensor | None = None
|
||||
|
||||
|
||||
def to_denoised(
|
||||
sample: torch.Tensor,
|
||||
velocity: torch.Tensor,
|
||||
sigma: float | torch.Tensor,
|
||||
calc_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert the sample and its denoising velocity to denoised sample.
|
||||
Returns:
|
||||
Denoised sample
|
||||
"""
|
||||
if isinstance(sigma, torch.Tensor):
|
||||
sigma = sigma.to(calc_dtype)
|
||||
return (sample.to(calc_dtype) - velocity.to(calc_dtype) * sigma).to(sample.dtype)
|
||||
|
||||
|
||||
|
||||
class Patchifier(Protocol):
|
||||
"""
|
||||
Protocol for patchifiers that convert latent tensors into patches and assemble them back.
|
||||
"""
|
||||
|
||||
def patchify(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
"""
|
||||
Convert latent tensors into flattened patch tokens.
|
||||
Args:
|
||||
latents: Latent tensor to patchify.
|
||||
Returns:
|
||||
Flattened patch tokens tensor.
|
||||
"""
|
||||
|
||||
def unpatchify(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts latent tensors between spatio-temporal formats and flattened sequence representations.
|
||||
Args:
|
||||
latents: Patch tokens that must be rearranged back into the latent grid constructed by `patchify`.
|
||||
output_shape: Shape of the output tensor. Note that output_shape is either AudioLatentShape or
|
||||
VideoLatentShape.
|
||||
Returns:
|
||||
Dense latent tensor restored from the flattened representation.
|
||||
"""
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
...
|
||||
"""
|
||||
Returns the patch size as a tuple of (temporal, height, width) dimensions
|
||||
"""
|
||||
|
||||
def get_patch_grid_bounds(
|
||||
self,
|
||||
output_shape: AudioLatentShape | VideoLatentShape,
|
||||
device: torch.device | None = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
"""
|
||||
Compute metadata describing where each latent patch resides within the
|
||||
grid specified by `output_shape`.
|
||||
Args:
|
||||
output_shape: Target grid layout for the patches.
|
||||
device: Target device for the returned tensor.
|
||||
Returns:
|
||||
Tensor containing patch coordinate metadata such as spatial or temporal intervals.
|
||||
"""
|
||||
|
||||
|
||||
def get_pixel_coords(
|
||||
latent_coords: torch.Tensor,
|
||||
scale_factors: SpatioTemporalScaleFactors,
|
||||
causal_fix: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Map latent-space `[start, end)` coordinates to their pixel-space equivalents by scaling
|
||||
each axis (frame/time, height, width) with the corresponding VAE downsampling factors.
|
||||
Optionally compensate for causal encoding that keeps the first frame at unit temporal scale.
|
||||
Args:
|
||||
latent_coords: Tensor of latent bounds shaped `(batch, 3, num_patches, 2)`.
|
||||
scale_factors: SpatioTemporalScaleFactors tuple `(temporal, height, width)` with integer scale factors applied
|
||||
per axis.
|
||||
causal_fix: When True, rewrites the temporal axis of the first frame so causal VAEs
|
||||
that treat frame zero differently still yield non-negative timestamps.
|
||||
"""
|
||||
# Broadcast the VAE scale factors so they align with the `(batch, axis, patch, bound)` layout.
|
||||
broadcast_shape = [1] * latent_coords.ndim
|
||||
broadcast_shape[1] = -1 # axis dimension corresponds to (frame/time, height, width)
|
||||
scale_tensor = torch.tensor(scale_factors, device=latent_coords.device).view(*broadcast_shape)
|
||||
|
||||
# Apply per-axis scaling to convert latent bounds into pixel-space coordinates.
|
||||
pixel_coords = latent_coords * scale_tensor
|
||||
|
||||
if causal_fix:
|
||||
# VAE temporal stride for the very first frame is 1 instead of `scale_factors[0]`.
|
||||
# Shift and clamp to keep the first-frame timestamps causal and non-negative.
|
||||
pixel_coords[:, 0, ...] = (pixel_coords[:, 0, ...] + 1 - scale_factors[0]).clamp(min=0)
|
||||
|
||||
return pixel_coords
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,366 +0,0 @@
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Config, AutoTokenizer
|
||||
from .ltx2_dit import (LTXRopeType, generate_freq_grid_np, generate_freq_grid_pytorch, precompute_freqs_cis, Attention,
|
||||
FeedForward)
|
||||
from .ltx2_common import rms_norm
|
||||
|
||||
|
||||
class LTX2TextEncoder(Gemma3ForConditionalGeneration):
|
||||
def __init__(self):
|
||||
config = Gemma3Config(
|
||||
**{
|
||||
"architectures": ["Gemma3ForConditionalGeneration"],
|
||||
"boi_token_index": 255999,
|
||||
"dtype": "bfloat16",
|
||||
"eoi_token_index": 256000,
|
||||
"eos_token_id": [1, 106],
|
||||
"image_token_index": 262144,
|
||||
"initializer_range": 0.02,
|
||||
"mm_tokens_per_image": 256,
|
||||
"model_type": "gemma3",
|
||||
"text_config": {
|
||||
"_sliding_window_pattern": 6,
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"attn_logit_softcapping": None,
|
||||
"cache_implementation": "hybrid",
|
||||
"dtype": "bfloat16",
|
||||
"final_logit_softcapping": None,
|
||||
"head_dim": 256,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"hidden_size": 3840,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 15360,
|
||||
"layer_types": [
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "full_attention", "sliding_attention", "sliding_attention",
|
||||
"sliding_attention", "sliding_attention", "sliding_attention", "full_attention"
|
||||
],
|
||||
"max_position_embeddings": 131072,
|
||||
"model_type": "gemma3_text",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 48,
|
||||
"num_key_value_heads": 8,
|
||||
"query_pre_attn_scalar": 256,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_local_base_freq": 10000,
|
||||
"rope_scaling": {
|
||||
"factor": 8.0,
|
||||
"rope_type": "linear"
|
||||
},
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": 1024,
|
||||
"sliding_window_pattern": 6,
|
||||
"use_bidirectional_attention": False,
|
||||
"use_cache": True,
|
||||
"vocab_size": 262208
|
||||
},
|
||||
"transformers_version": "4.57.3",
|
||||
"vision_config": {
|
||||
"attention_dropout": 0.0,
|
||||
"dtype": "bfloat16",
|
||||
"hidden_act": "gelu_pytorch_tanh",
|
||||
"hidden_size": 1152,
|
||||
"image_size": 896,
|
||||
"intermediate_size": 4304,
|
||||
"layer_norm_eps": 1e-06,
|
||||
"model_type": "siglip_vision_model",
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 27,
|
||||
"patch_size": 14,
|
||||
"vision_use_head": False
|
||||
}
|
||||
})
|
||||
super().__init__(config)
|
||||
|
||||
|
||||
class LTXVGemmaTokenizer:
|
||||
"""
|
||||
Tokenizer wrapper for Gemma models compatible with LTXV processes.
|
||||
This class wraps HuggingFace's `AutoTokenizer` for use with Gemma text encoders,
|
||||
ensuring correct settings and output formatting for downstream consumption.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer_path: str, max_length: int = 1024):
|
||||
"""
|
||||
Initialize the tokenizer.
|
||||
Args:
|
||||
tokenizer_path (str): Path to the pretrained tokenizer files or model directory.
|
||||
max_length (int, optional): Max sequence length for encoding. Defaults to 256.
|
||||
"""
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, local_files_only=True, model_max_length=max_length
|
||||
)
|
||||
# Gemma expects left padding for chat-style prompts; for plain text it doesn't matter much.
|
||||
self.tokenizer.padding_side = "left"
|
||||
if self.tokenizer.pad_token is None:
|
||||
self.tokenizer.pad_token = self.tokenizer.eos_token
|
||||
|
||||
self.max_length = max_length
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids: bool = False) -> dict[str, list[tuple[int, int]]]:
|
||||
"""
|
||||
Tokenize the given text and return token IDs and attention weights.
|
||||
Args:
|
||||
text (str): The input string to tokenize.
|
||||
return_word_ids (bool, optional): If True, includes the token's position (index) in the output tuples.
|
||||
If False (default), omits the indices.
|
||||
Returns:
|
||||
dict[str, list[tuple[int, int]]] OR dict[str, list[tuple[int, int, int]]]:
|
||||
A dictionary with a "gemma" key mapping to:
|
||||
- a list of (token_id, attention_mask) tuples if return_word_ids is False;
|
||||
- a list of (token_id, attention_mask, index) tuples if return_word_ids is True.
|
||||
Example:
|
||||
>>> tokenizer = LTXVGemmaTokenizer("path/to/tokenizer", max_length=8)
|
||||
>>> tokenizer.tokenize_with_weights("hello world")
|
||||
{'gemma': [(1234, 1), (5678, 1), (2, 0), ...]}
|
||||
"""
|
||||
text = text.strip()
|
||||
encoded = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
input_ids = encoded.input_ids
|
||||
attention_mask = encoded.attention_mask
|
||||
tuples = [
|
||||
(token_id, attn, i) for i, (token_id, attn) in enumerate(zip(input_ids[0], attention_mask[0], strict=True))
|
||||
]
|
||||
out = {"gemma": tuples}
|
||||
|
||||
if not return_word_ids:
|
||||
# Return only (token_id, attention_mask) pairs, omitting token position
|
||||
out = {k: [(t, w) for t, w, _ in v] for k, v in out.items()}
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class GemmaFeaturesExtractorProjLinear(torch.nn.Module):
|
||||
"""
|
||||
Feature extractor module for Gemma models.
|
||||
This module applies a single linear projection to the input tensor.
|
||||
It expects a flattened feature tensor of shape (batch_size, 3840*49).
|
||||
The linear layer maps this to a (batch_size, 3840) embedding.
|
||||
Attributes:
|
||||
aggregate_embed (torch.nn.Linear): Linear projection layer.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""
|
||||
Initialize the GemmaFeaturesExtractorProjLinear module.
|
||||
The input dimension is expected to be 3840 * 49, and the output is 3840.
|
||||
"""
|
||||
super().__init__()
|
||||
self.aggregate_embed = torch.nn.Linear(3840 * 49, 3840, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass for the feature extractor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor of shape (batch_size, 3840 * 49).
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (batch_size, 3840).
|
||||
"""
|
||||
return self.aggregate_embed(x)
|
||||
|
||||
|
||||
class _BasicTransformerBlock1D(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
rope_type: LTXRopeType = LTXRopeType.INTERLEAVED,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(
|
||||
dim,
|
||||
dim_out=dim,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
pe: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# Notice that normalization is always applied before the real computation in the following blocks.
|
||||
|
||||
# 1. Normalization Before Self-Attention
|
||||
norm_hidden_states = rms_norm(hidden_states)
|
||||
|
||||
norm_hidden_states = norm_hidden_states.squeeze(1)
|
||||
|
||||
# 2. Self-Attention
|
||||
attn_output = self.attn1(norm_hidden_states, mask=attention_mask, pe=pe)
|
||||
|
||||
hidden_states = attn_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
# 3. Normalization before Feed-Forward
|
||||
norm_hidden_states = rms_norm(hidden_states)
|
||||
|
||||
# 4. Feed-forward
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
if hidden_states.ndim == 4:
|
||||
hidden_states = hidden_states.squeeze(1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Embeddings1DConnector(torch.nn.Module):
|
||||
"""
|
||||
Embeddings1DConnector applies a 1D transformer-based processing to sequential embeddings (e.g., for video, audio, or
|
||||
other modalities). It supports rotary positional encoding (rope), optional causal temporal positioning, and can
|
||||
substitute padded positions with learnable registers. The module is highly configurable for head size, number of
|
||||
layers, and register usage.
|
||||
Args:
|
||||
attention_head_dim (int): Dimension of each attention head (default=128).
|
||||
num_attention_heads (int): Number of attention heads (default=30).
|
||||
num_layers (int): Number of transformer layers (default=2).
|
||||
positional_embedding_theta (float): Scaling factor for position embedding (default=10000.0).
|
||||
positional_embedding_max_pos (list[int] | None): Max positions for positional embeddings (default=[1]).
|
||||
causal_temporal_positioning (bool): If True, uses causal attention (default=False).
|
||||
num_learnable_registers (int | None): Number of learnable registers to replace padded tokens. If None, disables
|
||||
register replacement. (default=128)
|
||||
rope_type (LTXRopeType): The RoPE variant to use (default=DEFAULT_ROPE_TYPE).
|
||||
double_precision_rope (bool): Use double precision rope calculation (default=False).
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 30,
|
||||
num_layers: int = 2,
|
||||
positional_embedding_theta: float = 10000.0,
|
||||
positional_embedding_max_pos: list[int] | None = [4096],
|
||||
causal_temporal_positioning: bool = False,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_type: LTXRopeType = LTXRopeType.SPLIT,
|
||||
double_precision_rope: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
self.positional_embedding_theta = positional_embedding_theta
|
||||
self.positional_embedding_max_pos = (
|
||||
positional_embedding_max_pos if positional_embedding_max_pos is not None else [1]
|
||||
)
|
||||
self.rope_type = rope_type
|
||||
self.double_precision_rope = double_precision_rope
|
||||
self.transformer_1d_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
_BasicTransformerBlock1D(
|
||||
dim=self.inner_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
if self.num_learnable_registers:
|
||||
self.learnable_registers = torch.nn.Parameter(
|
||||
torch.rand(self.num_learnable_registers, self.inner_dim, dtype=torch.bfloat16) * 2.0 - 1.0
|
||||
)
|
||||
|
||||
def _replace_padded_with_learnable_registers(
|
||||
self, hidden_states: torch.Tensor, attention_mask: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert hidden_states.shape[1] % self.num_learnable_registers == 0, (
|
||||
f"Hidden states sequence length {hidden_states.shape[1]} must be divisible by num_learnable_registers "
|
||||
f"{self.num_learnable_registers}."
|
||||
)
|
||||
|
||||
num_registers_duplications = hidden_states.shape[1] // self.num_learnable_registers
|
||||
learnable_registers = torch.tile(self.learnable_registers, (num_registers_duplications, 1))
|
||||
attention_mask_binary = (attention_mask.squeeze(1).squeeze(1).unsqueeze(-1) >= -9000.0).int()
|
||||
|
||||
non_zero_hidden_states = hidden_states[:, attention_mask_binary.squeeze().bool(), :]
|
||||
non_zero_nums = non_zero_hidden_states.shape[1]
|
||||
pad_length = hidden_states.shape[1] - non_zero_nums
|
||||
adjusted_hidden_states = torch.nn.functional.pad(non_zero_hidden_states, pad=(0, 0, 0, pad_length), value=0)
|
||||
flipped_mask = torch.flip(attention_mask_binary, dims=[1])
|
||||
hidden_states = flipped_mask * adjusted_hidden_states + (1 - flipped_mask) * learnable_registers
|
||||
|
||||
attention_mask = torch.full_like(
|
||||
attention_mask,
|
||||
0.0,
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass of Embeddings1DConnector.
|
||||
Args:
|
||||
hidden_states (torch.Tensor): Input tensor of embeddings (shape [batch, seq_len, feature_dim]).
|
||||
attention_mask (torch.Tensor|None): Optional mask for valid tokens (shape compatible with hidden_states).
|
||||
Returns:
|
||||
tuple[torch.Tensor, torch.Tensor]: Processed features and the corresponding (possibly modified) mask.
|
||||
"""
|
||||
if self.num_learnable_registers:
|
||||
hidden_states, attention_mask = self._replace_padded_with_learnable_registers(hidden_states, attention_mask)
|
||||
|
||||
indices_grid = torch.arange(hidden_states.shape[1], dtype=torch.float32, device=hidden_states.device)
|
||||
indices_grid = indices_grid[None, None, :]
|
||||
freq_grid_generator = generate_freq_grid_np if self.double_precision_rope else generate_freq_grid_pytorch
|
||||
freqs_cis = precompute_freqs_cis(
|
||||
indices_grid=indices_grid,
|
||||
dim=self.inner_dim,
|
||||
out_dtype=hidden_states.dtype,
|
||||
theta=self.positional_embedding_theta,
|
||||
max_pos=self.positional_embedding_max_pos,
|
||||
num_attention_heads=self.num_attention_heads,
|
||||
rope_type=self.rope_type,
|
||||
freq_grid_generator=freq_grid_generator,
|
||||
)
|
||||
|
||||
for block in self.transformer_1d_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, pe=freqs_cis)
|
||||
|
||||
hidden_states = rms_norm(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextEncoderPostModules(torch.nn.Module):
|
||||
def __init__(self,):
|
||||
super().__init__()
|
||||
self.feature_extractor_linear = GemmaFeaturesExtractorProjLinear()
|
||||
self.embeddings_connector = Embeddings1DConnector()
|
||||
self.audio_embeddings_connector = Embeddings1DConnector()
|
||||
@@ -1,313 +0,0 @@
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
import torch
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
from .ltx2_video_vae import LTX2VideoEncoder
|
||||
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
"""
|
||||
N-dimensional pixel shuffle operation for upsampling tensors.
|
||||
Args:
|
||||
dims (int): Number of dimensions to apply pixel shuffle to.
|
||||
- 1: Temporal (e.g., frames)
|
||||
- 2: Spatial (e.g., height and width)
|
||||
- 3: Spatiotemporal (e.g., depth, height, width)
|
||||
upscale_factors (tuple[int, int, int], optional): Upscaling factors for each dimension.
|
||||
For dims=1, only the first value is used.
|
||||
For dims=2, the first two values are used.
|
||||
For dims=3, all three values are used.
|
||||
The input tensor is rearranged so that the channel dimension is split into
|
||||
smaller channels and upscaling factors, and the upscaling factors are moved
|
||||
into the corresponding spatial/temporal dimensions.
|
||||
Note:
|
||||
This operation is equivalent to the patchifier operation in for the models. Consider
|
||||
using this class instead.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, upscale_factors: tuple[int, int, int] = (2, 2, 2)):
|
||||
super().__init__()
|
||||
assert dims in [1, 2, 3], "dims must be 1, 2, or 3"
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.dims == 3:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
p3=self.upscale_factors[2],
|
||||
)
|
||||
elif self.dims == 2:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1 p2) h w -> b c (h p1) (w p2)",
|
||||
p1=self.upscale_factors[0],
|
||||
p2=self.upscale_factors[1],
|
||||
)
|
||||
elif self.dims == 1:
|
||||
return rearrange(
|
||||
x,
|
||||
"b (c p1) f h w -> b c (f p1) h w",
|
||||
p1=self.upscale_factors[0],
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dims: {self.dims}")
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
"""
|
||||
Residual block with two convolutional layers, group normalization, and SiLU activation.
|
||||
Args:
|
||||
channels (int): Number of input and output channels.
|
||||
mid_channels (Optional[int]): Number of channels in the intermediate convolution layer. Defaults to `channels`
|
||||
if not specified.
|
||||
dims (int): Dimensionality of the convolution (2 for Conv2d, 3 for Conv3d). Defaults to 3.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.activation(x)
|
||||
x = self.conv2(x)
|
||||
x = self.norm2(x)
|
||||
x = self.activation(x + residual)
|
||||
return x
|
||||
|
||||
|
||||
class BlurDownsample(torch.nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel.
|
||||
Applies only on H,W. Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||
super().__init__()
|
||||
assert dims in (2, 3)
|
||||
assert isinstance(stride, int)
|
||||
assert stride >= 1
|
||||
assert kernel_size >= 3
|
||||
assert kernel_size % 2 == 1
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||
# The 2D kernel is constructed as the outer product and normalized.
|
||||
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
if self.dims == 2:
|
||||
return self._apply_2d(x)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, _, f, _, _ = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self._apply_2d(x)
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f, h=h2, w=w2)
|
||||
return x
|
||||
|
||||
def _apply_2d(self, x2d: torch.Tensor) -> torch.Tensor:
|
||||
c = x2d.shape[1]
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x2d = F.conv2d(x2d, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
return x2d
|
||||
|
||||
|
||||
def _rational_for_scale(scale: float) -> Tuple[int, int]:
|
||||
mapping = {0.75: (3, 4), 1.5: (3, 2), 2.0: (2, 1), 4.0: (4, 1)}
|
||||
if float(scale) not in mapping:
|
||||
raise ValueError(f"Unsupported scale {scale}. Choose from {list(mapping.keys())}")
|
||||
return mapping[float(scale)]
|
||||
|
||||
|
||||
class SpatialRationalResampler(torch.nn.Module):
|
||||
"""
|
||||
Fully-learned rational spatial scaling: up by 'num' via PixelShuffle, then anti-aliased
|
||||
downsample by 'den' using fixed blur + stride. Operates on H,W only.
|
||||
For dims==3, work per-frame for spatial scaling (temporal axis untouched).
|
||||
Args:
|
||||
mid_channels (`int`): Number of intermediate channels for the convolution layer
|
||||
scale (`float`): Spatial scaling factor. Supported values are:
|
||||
- 0.75: Downsample by 3/4 (reduce spatial size)
|
||||
- 1.5: Upsample by 3/2 (increase spatial size)
|
||||
- 2.0: Upsample by 2x (double spatial size)
|
||||
- 4.0: Upsample by 4x (quadruple spatial size)
|
||||
Any other value will raise a ValueError.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int, scale: float):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
self.num, self.den = _rational_for_scale(self.scale)
|
||||
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, _, f, _, _ = x.shape
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2LatentUpsampler(torch.nn.Module):
|
||||
"""
|
||||
Model to upsample VAE latents spatially and/or temporally.
|
||||
Args:
|
||||
in_channels (`int`): Number of channels in the input latent
|
||||
mid_channels (`int`): Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`): Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`): Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`): Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`): Whether to temporally upsample the latent
|
||||
spatial_scale (`float`): Scale factor for spatial upsampling
|
||||
rational_resampler (`bool`): Whether to use a rational resampler for spatial upsampling
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
spatial_scale: float = 2.0,
|
||||
rational_resampler: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
self.spatial_scale = float(spatial_scale)
|
||||
self.rational_resampler = rational_resampler
|
||||
|
||||
conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = conv(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=self.spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = conv(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, latent: torch.Tensor) -> torch.Tensor:
|
||||
b, _, f, _, _ = latent.shape
|
||||
|
||||
if self.dims == 2:
|
||||
x = rearrange(latent, "b c f h w -> (b f) c h w")
|
||||
x = self.initial_conv(x)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.upsampler(x)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
else:
|
||||
x = self.initial_conv(latent)
|
||||
x = self.initial_norm(x)
|
||||
x = self.initial_activation(x)
|
||||
|
||||
for block in self.res_blocks:
|
||||
x = block(x)
|
||||
|
||||
if self.temporal_upsample:
|
||||
x = self.upsampler(x)
|
||||
# remove the first frame after upsampling.
|
||||
# This is done because the first frame encodes one pixel frame.
|
||||
x = x[:, :, 1:, :, :]
|
||||
elif isinstance(self.upsampler, SpatialRationalResampler):
|
||||
x = self.upsampler(x)
|
||||
else:
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = self.upsampler(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", b=b, f=f)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
x = block(x)
|
||||
|
||||
x = self.final_conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def upsample_video(latent: torch.Tensor, video_encoder: LTX2VideoEncoder, upsampler: "LTX2LatentUpsampler") -> torch.Tensor:
|
||||
"""
|
||||
Apply upsampling to the latent representation using the provided upsampler,
|
||||
with normalization and un-normalization based on the video encoder's per-channel statistics.
|
||||
Args:
|
||||
latent: Input latent tensor of shape [B, C, F, H, W].
|
||||
video_encoder: VideoEncoder with per_channel_statistics for normalization.
|
||||
upsampler: LTX2LatentUpsampler module to perform upsampling.
|
||||
Returns:
|
||||
torch.Tensor: Upsampled and re-normalized latent tensor.
|
||||
"""
|
||||
latent = video_encoder.per_channel_statistics.un_normalize(latent)
|
||||
latent = upsampler(latent)
|
||||
latent = video_encoder.per_channel_statistics.normalize(latent)
|
||||
return latent
|
||||
File diff suppressed because it is too large
Load Diff
@@ -29,7 +29,7 @@ class ModelPool:
|
||||
module_map = None
|
||||
return module_map
|
||||
|
||||
def load_model_file(self, config, path, vram_config, vram_limit=None, state_dict=None):
|
||||
def load_model_file(self, config, path, vram_config, vram_limit=None):
|
||||
model_class = self.import_model_class(config["model_class"])
|
||||
model_config = config.get("extra_kwargs", {})
|
||||
if "state_dict_converter" in config:
|
||||
@@ -43,7 +43,6 @@ class ModelPool:
|
||||
state_dict_converter,
|
||||
use_disk_map=True,
|
||||
vram_config=vram_config, module_map=module_map, vram_limit=vram_limit,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
return model
|
||||
|
||||
@@ -60,7 +59,7 @@ class ModelPool:
|
||||
}
|
||||
return vram_config
|
||||
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False, state_dict=None):
|
||||
def auto_load_model(self, path, vram_config=None, vram_limit=None, clear_parameters=False):
|
||||
print(f"Loading models from: {json.dumps(path, indent=4)}")
|
||||
if vram_config is None:
|
||||
vram_config = self.default_vram_config()
|
||||
@@ -68,7 +67,7 @@ class ModelPool:
|
||||
loaded = False
|
||||
for config in MODEL_CONFIGS:
|
||||
if config["model_hash"] == model_hash:
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit, state_dict=state_dict)
|
||||
model = self.load_model_file(config, path, vram_config, vram_limit=vram_limit)
|
||||
if clear_parameters: self.clear_parameters(model)
|
||||
self.model.append(model)
|
||||
model_name = config["model_name"]
|
||||
|
||||
@@ -583,7 +583,7 @@ class Qwen2_5_VLForConditionalGeneration(Qwen2_5_VLPreTrainedModel, GenerationMi
|
||||
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
|
||||
is_compileable = is_compileable and not self.generation_config.disable_compile
|
||||
if is_compileable and (
|
||||
self.device.type in ["cuda", "npu"] or generation_config.compile_config._compile_all_devices
|
||||
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
|
||||
):
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "0"
|
||||
model_forward = self.get_compiled_call(generation_config.compile_config)
|
||||
|
||||
@@ -2,8 +2,6 @@ from transformers.models.siglip.modeling_siglip import SiglipVisionTransformer,
|
||||
from transformers import SiglipImageProcessor, Siglip2VisionModel, Siglip2VisionConfig, Siglip2ImageProcessorFast
|
||||
import torch
|
||||
|
||||
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||
|
||||
|
||||
class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||
def __init__(self):
|
||||
@@ -49,7 +47,7 @@ class Siglip2ImageEncoder(SiglipVisionTransformer):
|
||||
}
|
||||
)
|
||||
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device=get_device_type()):
|
||||
def forward(self, image, torch_dtype=torch.bfloat16, device="cuda"):
|
||||
pixel_values = self.processor(images=[image], return_tensors="pt")["pixel_values"]
|
||||
pixel_values = pixel_values.to(device=device, dtype=torch_dtype)
|
||||
output_attentions = False
|
||||
@@ -92,10 +90,12 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
||||
super().__init__(config)
|
||||
self.processor = Siglip2ImageProcessorFast(
|
||||
**{
|
||||
"crop_size": None,
|
||||
"data_format": "channels_first",
|
||||
"default_to_square": True,
|
||||
"device": None,
|
||||
"disable_grouping": None,
|
||||
"do_center_crop": None,
|
||||
"do_convert_rgb": None,
|
||||
"do_normalize": True,
|
||||
"do_pad": None,
|
||||
@@ -120,6 +120,7 @@ class Siglip2ImageEncoder428M(Siglip2VisionModel):
|
||||
"resample": 2,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"return_tensors": None,
|
||||
"size": None
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import torch
|
||||
from typing import Optional, Union
|
||||
from .qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from ..core.device.npu_compatible_device import get_device_type, get_torch_device
|
||||
|
||||
|
||||
class Step1xEditEmbedder(torch.nn.Module):
|
||||
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device=get_device_type()):
|
||||
def __init__(self, model: QwenImageTextEncoder, processor, max_length=640, dtype=torch.bfloat16, device="cuda"):
|
||||
super().__init__()
|
||||
self.max_length = max_length
|
||||
self.dtype = dtype
|
||||
@@ -78,13 +77,13 @@ User Prompt:'''
|
||||
self.max_length,
|
||||
self.model.config.hidden_size,
|
||||
dtype=torch.bfloat16,
|
||||
device=get_torch_device().current_device(),
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
masks = torch.zeros(
|
||||
len(text_list),
|
||||
self.max_length,
|
||||
dtype=torch.long,
|
||||
device=get_torch_device().current_device(),
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
def split_string(s):
|
||||
@@ -159,7 +158,7 @@ User Prompt:'''
|
||||
else:
|
||||
token_list.append(token_each)
|
||||
|
||||
new_txt_ids = torch.cat(token_list, dim=1).to(get_device_type())
|
||||
new_txt_ids = torch.cat(token_list, dim=1).to("cuda")
|
||||
|
||||
new_txt_ids = new_txt_ids.to(old_inputs_ids.device)
|
||||
|
||||
@@ -168,15 +167,15 @@ User Prompt:'''
|
||||
inputs.input_ids = (
|
||||
torch.cat([old_inputs_ids[0, :idx1], new_txt_ids[0, idx2:]], dim=0)
|
||||
.unsqueeze(0)
|
||||
.to(get_device_type())
|
||||
.to("cuda")
|
||||
)
|
||||
inputs.attention_mask = (inputs.input_ids > 0).long().to(get_device_type())
|
||||
inputs.attention_mask = (inputs.input_ids > 0).long().to("cuda")
|
||||
outputs = self.model_forward(
|
||||
self.model,
|
||||
input_ids=inputs.input_ids,
|
||||
attention_mask=inputs.attention_mask,
|
||||
pixel_values=inputs.pixel_values.to(get_device_type()),
|
||||
image_grid_thw=inputs.image_grid_thw.to(get_device_type()),
|
||||
pixel_values=inputs.pixel_values.to("cuda"),
|
||||
image_grid_thw=inputs.image_grid_thw.to("cuda"),
|
||||
output_hidden_states=True,
|
||||
)
|
||||
|
||||
@@ -189,7 +188,7 @@ User Prompt:'''
|
||||
masks[idx, : min(self.max_length, emb.shape[1] - 217)] = torch.ones(
|
||||
(min(self.max_length, emb.shape[1] - 217)),
|
||||
dtype=torch.long,
|
||||
device=get_torch_device().current_device(),
|
||||
device=torch.cuda.current_device(),
|
||||
)
|
||||
|
||||
return embs, masks
|
||||
|
||||
@@ -5,7 +5,6 @@ import math
|
||||
from typing import Tuple, Optional
|
||||
from einops import rearrange
|
||||
from .wan_video_camera_controller import SimpleAdapter
|
||||
|
||||
try:
|
||||
import flash_attn_interface
|
||||
FLASH_ATTN_3_AVAILABLE = True
|
||||
@@ -93,7 +92,6 @@ def rope_apply(x, freqs, num_heads):
|
||||
x = rearrange(x, "b s (n d) -> b s n d", n=num_heads)
|
||||
x_out = torch.view_as_complex(x.to(torch.float64).reshape(
|
||||
x.shape[0], x.shape[1], x.shape[2], -1, 2))
|
||||
freqs = freqs.to(torch.complex64) if freqs.device.type == "npu" else freqs
|
||||
x_out = torch.view_as_real(x_out * freqs).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
from .z_image_dit import ZImageTransformerBlock
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ZImageControlTransformerBlock(ZImageTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int = 1000,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
modulation = True,
|
||||
block_id = 0
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = nn.Linear(self.dim, self.dim)
|
||||
self.after_proj = nn.Linear(self.dim, self.dim)
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class ZImageControlNet(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
control_layers_places=(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28),
|
||||
control_in_dim=33,
|
||||
dim=3840,
|
||||
n_refiner_layers=2,
|
||||
):
|
||||
super().__init__()
|
||||
self.control_layers = nn.ModuleList([ZImageControlTransformerBlock(layer_id=i, block_id=i) for i in control_layers_places])
|
||||
self.control_all_x_embedder = nn.ModuleDict({"2-1": nn.Linear(1 * 2 * 2 * control_in_dim, dim, bias=True)})
|
||||
self.control_noise_refiner = nn.ModuleList([ZImageControlTransformerBlock(block_id=layer_id) for layer_id in range(n_refiner_layers)])
|
||||
self.control_layers_mapping = {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
|
||||
|
||||
def forward_layers(
|
||||
self,
|
||||
x,
|
||||
cap_feats,
|
||||
control_context,
|
||||
control_context_item_seqlens,
|
||||
kwargs,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
bsz = len(control_context)
|
||||
# unified
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
control_context_unified = []
|
||||
for i in range(bsz):
|
||||
control_context_len = control_context_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
control_context_unified.append(torch.cat([control_context[i][:control_context_len], cap_feats[i][:cap_len]]))
|
||||
c = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(x=x)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for layer in self.control_layers:
|
||||
c = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
c=c, **new_kwargs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
return hints
|
||||
|
||||
def forward_refiner(
|
||||
self,
|
||||
dit,
|
||||
x,
|
||||
cap_feats,
|
||||
control_context,
|
||||
kwargs,
|
||||
t=None,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
# embeddings
|
||||
bsz = len(control_context)
|
||||
device = control_context[0].device
|
||||
(
|
||||
control_context,
|
||||
control_context_size,
|
||||
control_context_pos_ids,
|
||||
control_context_inner_pad_mask,
|
||||
) = dit.patchify_controlnet(control_context, patch_size, f_patch_size, cap_feats[0].size(0))
|
||||
|
||||
# control_context embed & refine
|
||||
control_context_item_seqlens = [len(_) for _ in control_context]
|
||||
assert all(_ % 2 == 0 for _ in control_context_item_seqlens)
|
||||
control_context_max_item_seqlen = max(control_context_item_seqlens)
|
||||
|
||||
control_context = torch.cat(control_context, dim=0)
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
||||
|
||||
# Match t_embedder output dtype to control_context for layerwise casting compatibility
|
||||
adaln_input = t.type_as(control_context)
|
||||
control_context[torch.cat(control_context_inner_pad_mask)] = dit.x_pad_token.to(dtype=control_context.dtype, device=control_context.device)
|
||||
control_context = list(control_context.split(control_context_item_seqlens, dim=0))
|
||||
control_context_freqs_cis = list(dit.rope_embedder(torch.cat(control_context_pos_ids, dim=0)).split(control_context_item_seqlens, dim=0))
|
||||
|
||||
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
||||
control_context_freqs_cis = pad_sequence(control_context_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
control_context_attn_mask = torch.zeros((bsz, control_context_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(control_context_item_seqlens):
|
||||
control_context_attn_mask[i, :seq_len] = 1
|
||||
c = control_context
|
||||
|
||||
# arguments
|
||||
new_kwargs = dict(
|
||||
x=x,
|
||||
attn_mask=control_context_attn_mask,
|
||||
freqs_cis=control_context_freqs_cis,
|
||||
adaln_input=adaln_input,
|
||||
)
|
||||
new_kwargs.update(kwargs)
|
||||
|
||||
for layer in self.control_noise_refiner:
|
||||
c = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
c=c, **new_kwargs
|
||||
)
|
||||
|
||||
hints = torch.unbind(c)[:-1]
|
||||
control_context = torch.unbind(c)[-1]
|
||||
|
||||
return hints, control_context, control_context_item_seqlens
|
||||
@@ -6,9 +6,8 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from .general_modules import RMSNorm
|
||||
from torch.nn import RMSNorm
|
||||
from ..core.attention import attention_forward
|
||||
from ..core.device.npu_compatible_device import IS_NPU_AVAILABLE, get_device_type
|
||||
from ..core.gradient import gradient_checkpoint_forward
|
||||
|
||||
|
||||
@@ -40,7 +39,7 @@ class TimestepEmbedder(nn.Module):
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
@@ -105,7 +104,7 @@ class Attention(torch.nn.Module):
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast(get_device_type(), enabled=False):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
@@ -316,10 +315,7 @@ class RopeEmbedder:
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
if IS_NPU_AVAILABLE:
|
||||
result.append(torch.index_select(self.freqs_cis[i], 0, index))
|
||||
else:
|
||||
result.append(self.freqs_cis[i][index])
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
@@ -613,72 +609,6 @@ class ZImageDiT(nn.Module):
|
||||
# all_img_pad_mask,
|
||||
# all_cap_pad_mask,
|
||||
# )
|
||||
|
||||
def patchify_controlnet(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
patch_size: int = 2,
|
||||
f_patch_size: int = 1,
|
||||
cap_padding_len: int = None,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
device = all_image[0].device
|
||||
|
||||
all_image_out = []
|
||||
all_image_size = []
|
||||
all_image_pos_ids = []
|
||||
all_image_pad_mask = []
|
||||
|
||||
for i, image in enumerate(all_image):
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
all_image_size.append((F, H, W))
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
image_ori_pos_ids = self.create_coordinate_grid(
|
||||
size=(F_tokens, H_tokens, W_tokens),
|
||||
start=(cap_padding_len + 1, 0, 0),
|
||||
device=device,
|
||||
).flatten(0, 2)
|
||||
image_padding_pos_ids = (
|
||||
self.create_coordinate_grid(
|
||||
size=(1, 1, 1),
|
||||
start=(0, 0, 0),
|
||||
device=device,
|
||||
)
|
||||
.flatten(0, 2)
|
||||
.repeat(image_padding_len, 1)
|
||||
)
|
||||
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
||||
all_image_pos_ids.append(image_padded_pos_ids)
|
||||
# pad mask
|
||||
all_image_pad_mask.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
||||
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
)
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return (
|
||||
all_image_out,
|
||||
all_image_size,
|
||||
all_image_pos_ids,
|
||||
all_image_pad_mask,
|
||||
)
|
||||
|
||||
def _prepare_sequence(
|
||||
self,
|
||||
@@ -696,7 +626,7 @@ class ZImageDiT(nn.Module):
|
||||
|
||||
# Pad token
|
||||
feats_cat = torch.cat(feats, dim=0)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token.to(dtype=feats_cat.dtype, device=feats_cat.device)
|
||||
feats_cat[torch.cat(inner_pad_mask)] = pad_token
|
||||
feats = list(feats_cat.split(item_seqlens, dim=0))
|
||||
|
||||
# RoPE
|
||||
|
||||
@@ -1,189 +0,0 @@
|
||||
import torch
|
||||
from .qwen_image_image2lora import ImageEmbeddingToLoraMatrix, SequencialMLP
|
||||
|
||||
|
||||
class LoRATrainerBlock(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, in_dim=1536+4096, compress_dim=128, rank=4, block_id=0, use_residual=True, residual_length=64+7, residual_dim=3584, residual_mid_dim=1024, prefix="transformer_blocks"):
|
||||
super().__init__()
|
||||
self.prefix = prefix
|
||||
self.lora_patterns = lora_patterns
|
||||
self.block_id = block_id
|
||||
self.layers = []
|
||||
for name, lora_a_dim, lora_b_dim in self.lora_patterns:
|
||||
self.layers.append(ImageEmbeddingToLoraMatrix(in_dim, compress_dim, lora_a_dim, lora_b_dim, rank))
|
||||
self.layers = torch.nn.ModuleList(self.layers)
|
||||
if use_residual:
|
||||
self.proj_residual = SequencialMLP(residual_length, residual_dim, residual_mid_dim, compress_dim)
|
||||
else:
|
||||
self.proj_residual = None
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
if self.proj_residual is not None: residual = self.proj_residual(residual)
|
||||
for lora_pattern, layer in zip(self.lora_patterns, self.layers):
|
||||
name = lora_pattern[0]
|
||||
lora_a, lora_b = layer(x, residual=residual)
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_A.default.weight"] = lora_a
|
||||
lora[f"{self.prefix}.{self.block_id}.{name}.lora_B.default.weight"] = lora_b
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAComponent(torch.nn.Module):
|
||||
def __init__(self, lora_patterns, prefix, num_blocks=60, use_residual=True, compress_dim=128, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
self.lora_patterns = lora_patterns
|
||||
self.num_blocks = num_blocks
|
||||
self.blocks = []
|
||||
for lora_patterns in self.lora_patterns:
|
||||
for block_id in range(self.num_blocks):
|
||||
self.blocks.append(LoRATrainerBlock(lora_patterns, block_id=block_id, use_residual=use_residual, compress_dim=compress_dim, rank=rank, residual_length=residual_length, residual_mid_dim=residual_mid_dim, prefix=prefix))
|
||||
self.blocks = torch.nn.ModuleList(self.blocks)
|
||||
self.residual_scale = 0.05
|
||||
self.use_residual = use_residual
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
if residual is not None:
|
||||
if self.use_residual:
|
||||
residual = residual * self.residual_scale
|
||||
else:
|
||||
residual = None
|
||||
lora = {}
|
||||
for block in self.blocks:
|
||||
lora.update(block(x, residual))
|
||||
return lora
|
||||
|
||||
|
||||
class ZImageImage2LoRAModel(torch.nn.Module):
|
||||
def __init__(self, use_residual=False, compress_dim=64, rank=4, residual_length=64+7, residual_mid_dim=1024):
|
||||
super().__init__()
|
||||
lora_patterns = [
|
||||
[
|
||||
("attention.to_q", 3840, 3840),
|
||||
("attention.to_k", 3840, 3840),
|
||||
("attention.to_v", 3840, 3840),
|
||||
("attention.to_out.0", 3840, 3840),
|
||||
],
|
||||
[
|
||||
("feed_forward.w1", 3840, 10240),
|
||||
("feed_forward.w2", 10240, 3840),
|
||||
("feed_forward.w3", 3840, 10240),
|
||||
],
|
||||
]
|
||||
config = {
|
||||
"lora_patterns": lora_patterns,
|
||||
"use_residual": use_residual,
|
||||
"compress_dim": compress_dim,
|
||||
"rank": rank,
|
||||
"residual_length": residual_length,
|
||||
"residual_mid_dim": residual_mid_dim,
|
||||
}
|
||||
self.layers_lora = ZImageImage2LoRAComponent(
|
||||
prefix="layers",
|
||||
num_blocks=30,
|
||||
**config,
|
||||
)
|
||||
self.context_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="context_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
self.noise_refiner_lora = ZImageImage2LoRAComponent(
|
||||
prefix="noise_refiner",
|
||||
num_blocks=2,
|
||||
**config,
|
||||
)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
lora.update(self.layers_lora(x, residual=residual))
|
||||
lora.update(self.context_refiner_lora(x, residual=residual))
|
||||
lora.update(self.noise_refiner_lora(x, residual=residual))
|
||||
return lora
|
||||
|
||||
def initialize_weights(self):
|
||||
state_dict = self.state_dict()
|
||||
for name in state_dict:
|
||||
if ".proj_a." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
elif ".proj_b.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0
|
||||
elif ".proj_residual.proj_out." in name:
|
||||
state_dict[name] = state_dict[name] * 0.3
|
||||
self.load_state_dict(state_dict)
|
||||
|
||||
|
||||
class ImageEmb2LoRAWeightCompressed(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, emb_dim, rank):
|
||||
super().__init__()
|
||||
self.lora_a = torch.nn.Parameter(torch.randn((rank, in_dim)))
|
||||
self.lora_b = torch.nn.Parameter(torch.randn((out_dim, rank)))
|
||||
self.proj = torch.nn.Linear(emb_dim, rank * rank, bias=True)
|
||||
self.rank = rank
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x).view(self.rank, self.rank)
|
||||
lora_a = x @ self.lora_a
|
||||
lora_b = self.lora_b
|
||||
return lora_a, lora_b
|
||||
|
||||
|
||||
class ZImageImage2LoRAModelCompressed(torch.nn.Module):
|
||||
def __init__(self, emb_dim=1536+4096, rank=32):
|
||||
super().__init__()
|
||||
target_layers = [
|
||||
("attention.to_q", 3840, 3840),
|
||||
("attention.to_k", 3840, 3840),
|
||||
("attention.to_v", 3840, 3840),
|
||||
("attention.to_out.0", 3840, 3840),
|
||||
("feed_forward.w1", 3840, 10240),
|
||||
("feed_forward.w2", 10240, 3840),
|
||||
("feed_forward.w3", 3840, 10240),
|
||||
]
|
||||
self.lora_patterns = [
|
||||
{
|
||||
"prefix": "layers",
|
||||
"num_layers": 30,
|
||||
"target_layers": target_layers,
|
||||
},
|
||||
{
|
||||
"prefix": "context_refiner",
|
||||
"num_layers": 2,
|
||||
"target_layers": target_layers,
|
||||
},
|
||||
{
|
||||
"prefix": "noise_refiner",
|
||||
"num_layers": 2,
|
||||
"target_layers": target_layers,
|
||||
},
|
||||
]
|
||||
module_dict = {}
|
||||
for lora_pattern in self.lora_patterns:
|
||||
prefix, num_layers, target_layers = lora_pattern["prefix"], lora_pattern["num_layers"], lora_pattern["target_layers"]
|
||||
for layer_id in range(num_layers):
|
||||
for layer_name, in_dim, out_dim in target_layers:
|
||||
name = f"{prefix}.{layer_id}.{layer_name}".replace(".", "___")
|
||||
model = ImageEmb2LoRAWeightCompressed(in_dim, out_dim, emb_dim, rank)
|
||||
module_dict[name] = model
|
||||
self.module_dict = torch.nn.ModuleDict(module_dict)
|
||||
|
||||
def forward(self, x, residual=None):
|
||||
lora = {}
|
||||
for name, module in self.module_dict.items():
|
||||
name = name.replace("___", ".")
|
||||
name_a, name_b = f"{name}.lora_A.default.weight", f"{name}.lora_B.default.weight"
|
||||
lora_a, lora_b = module(x)
|
||||
lora[name_a] = lora_a
|
||||
lora[name_b] = lora_b
|
||||
return lora
|
||||
|
||||
def initialize_weights(self):
|
||||
state_dict = self.state_dict()
|
||||
for name in state_dict:
|
||||
if "lora_b" in name:
|
||||
state_dict[name] = state_dict[name] * 0
|
||||
elif "lora_a" in name:
|
||||
state_dict[name] = state_dict[name] * 0.2
|
||||
elif "proj.weight" in name:
|
||||
print(name)
|
||||
state_dict[name] = state_dict[name] * 0.2
|
||||
self.load_state_dict(state_dict)
|
||||
@@ -3,101 +3,38 @@ import torch
|
||||
|
||||
|
||||
class ZImageTextEncoder(torch.nn.Module):
|
||||
def __init__(self, model_size="4B"):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
config_dict = {
|
||||
"0.6B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 1024,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 28,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 16,
|
||||
"num_hidden_layers": 28,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
}),
|
||||
"4B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
}),
|
||||
"8B": Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"dtype": "bfloat16",
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 4096,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 12288,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": False,
|
||||
"transformers_version": "4.56.1",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
}
|
||||
config = config_dict[model_size]
|
||||
config = Qwen3Config(**{
|
||||
"architectures": [
|
||||
"Qwen3ForCausalLM"
|
||||
],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 151643,
|
||||
"eos_token_id": 151645,
|
||||
"head_dim": 128,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 2560,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 9728,
|
||||
"max_position_embeddings": 40960,
|
||||
"max_window_layers": 36,
|
||||
"model_type": "qwen3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 36,
|
||||
"num_key_value_heads": 8,
|
||||
"rms_norm_eps": 1e-06,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 1000000,
|
||||
"sliding_window": None,
|
||||
"tie_word_embeddings": True,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.51.0",
|
||||
"use_cache": True,
|
||||
"use_sliding_window": False,
|
||||
"vocab_size": 151936
|
||||
})
|
||||
self.model = Qwen3Model(config)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, math, torchvision
|
||||
import torch, math
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
@@ -6,28 +6,25 @@ from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from transformers import AutoProcessor
|
||||
from ..models.flux2_text_encoder import Flux2TextEncoder
|
||||
from ..models.flux2_dit import Flux2DiT
|
||||
from ..models.flux2_vae import Flux2VAE
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
|
||||
|
||||
class Flux2ImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: Flux2TextEncoder = None
|
||||
self.text_encoder_qwen3: ZImageTextEncoder = None
|
||||
self.dit: Flux2DiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
@@ -35,10 +32,8 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
self.units = [
|
||||
Flux2Unit_ShapeChecker(),
|
||||
Flux2Unit_PromptEmbedder(),
|
||||
Flux2Unit_Qwen3PromptEmbedder(),
|
||||
Flux2Unit_NoiseInitializer(),
|
||||
Flux2Unit_InputImageEmbedder(),
|
||||
Flux2Unit_EditImageEmbedder(),
|
||||
Flux2Unit_ImageIDs(),
|
||||
]
|
||||
self.model_fn = model_fn_flux2
|
||||
@@ -47,7 +42,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||
vram_limit: float = None,
|
||||
@@ -58,12 +53,11 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("flux2_text_encoder")
|
||||
pipe.text_encoder_qwen3 = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("flux2_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
pipe.tokenizer = AutoProcessor.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
@@ -81,9 +75,6 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Edit
|
||||
edit_image: Union[Image.Image, List[Image.Image]] = None,
|
||||
edit_image_auto_resize: bool = True,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -107,7 +98,6 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale, "embedded_guidance": embedded_guidance,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
@@ -285,10 +275,6 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, prompt):
|
||||
# Skip if Qwen3 text encoder is available (handled by Qwen3PromptEmbedder)
|
||||
if pipe.text_encoder_qwen3 is not None:
|
||||
return {}
|
||||
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
pipe.text_encoder, pipe.tokenizer, prompt,
|
||||
@@ -297,136 +283,6 @@ class Flux2Unit_PromptEmbedder(PipelineUnit):
|
||||
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
|
||||
|
||||
|
||||
class Flux2Unit_Qwen3PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder_qwen3",)
|
||||
)
|
||||
self.hidden_states_layers = (9, 18, 27) # Qwen3 layers
|
||||
|
||||
def get_qwen3_prompt_embeds(
|
||||
self,
|
||||
text_encoder: ZImageTextEncoder,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
dtype = text_encoder.dtype if dtype is None else dtype
|
||||
device = text_encoder.device if device is None else device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
all_input_ids = []
|
||||
all_attention_masks = []
|
||||
|
||||
for single_prompt in prompt:
|
||||
messages = [{"role": "user", "content": single_prompt}]
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=max_sequence_length,
|
||||
)
|
||||
|
||||
all_input_ids.append(inputs["input_ids"])
|
||||
all_attention_masks.append(inputs["attention_mask"])
|
||||
|
||||
input_ids = torch.cat(all_input_ids, dim=0).to(device)
|
||||
attention_mask = torch.cat(all_attention_masks, dim=0).to(device)
|
||||
|
||||
# Forward pass through the model
|
||||
with torch.inference_mode():
|
||||
output = text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
use_cache=False,
|
||||
)
|
||||
|
||||
# Only use outputs from intermediate layers and stack them
|
||||
out = torch.stack([output.hidden_states[k] for k in self.hidden_states_layers], dim=1)
|
||||
out = out.to(dtype=dtype, device=device)
|
||||
|
||||
batch_size, num_channels, seq_len, hidden_dim = out.shape
|
||||
prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim)
|
||||
return prompt_embeds
|
||||
|
||||
def prepare_text_ids(
|
||||
self,
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: Optional[torch.Tensor] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: ZImageTextEncoder,
|
||||
tokenizer: AutoTokenizer,
|
||||
prompt: Union[str, List[str]],
|
||||
dtype = None,
|
||||
device: Optional[torch.device] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self.get_qwen3_prompt_embeds(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
prompt=prompt,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self.prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, prompt):
|
||||
# Check if Qwen3 text encoder is available
|
||||
if pipe.text_encoder_qwen3 is None:
|
||||
return {}
|
||||
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, text_ids = self.encode_prompt(
|
||||
pipe.text_encoder_qwen3, pipe.tokenizer, prompt,
|
||||
dtype=pipe.torch_dtype, device=pipe.device,
|
||||
)
|
||||
return {"prompt_embeds": prompt_embeds, "text_ids": text_ids}
|
||||
|
||||
|
||||
class Flux2Unit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -462,75 +318,6 @@ class Flux2Unit_InputImageEmbedder(PipelineUnit):
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
class Flux2Unit_EditImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("edit_image", "edit_image_auto_resize"),
|
||||
output_params=("edit_latents", "edit_image_ids"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def calculate_dimensions(self, target_area, ratio):
|
||||
import math
|
||||
width = math.sqrt(target_area * ratio)
|
||||
height = width / ratio
|
||||
width = round(width / 32) * 32
|
||||
height = round(height / 32) * 32
|
||||
return width, height
|
||||
|
||||
def crop_and_resize(self, image, target_height, target_width):
|
||||
width, height = image.size
|
||||
scale = max(target_width / width, target_height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
image,
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
image = torchvision.transforms.functional.center_crop(image, (target_height, target_width))
|
||||
return image
|
||||
|
||||
def edit_image_auto_resize(self, edit_image):
|
||||
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
|
||||
return self.crop_and_resize(edit_image, calculated_height, calculated_width)
|
||||
|
||||
def process_image_ids(self, image_latents, scale=10):
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
def process(self, pipe: Flux2ImagePipeline, edit_image, edit_image_auto_resize):
|
||||
if edit_image is None:
|
||||
return {}
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
if isinstance(edit_image, Image.Image):
|
||||
edit_image = [edit_image]
|
||||
resized_edit_image, edit_latents = [], []
|
||||
for image in edit_image:
|
||||
# Preprocess
|
||||
if edit_image_auto_resize is None or edit_image_auto_resize:
|
||||
image = self.edit_image_auto_resize(image)
|
||||
resized_edit_image.append(image)
|
||||
# Encode
|
||||
image = pipe.preprocess_image(image)
|
||||
latents = pipe.vae.encode(image)
|
||||
edit_latents.append(latents)
|
||||
edit_image_ids = self.process_image_ids(edit_latents).to(pipe.device)
|
||||
edit_latents = torch.concat([rearrange(latents, "B C H W -> B (H W) C") for latents in edit_latents], dim=1)
|
||||
return {"edit_latents": edit_latents, "edit_image_ids": edit_image_ids}
|
||||
|
||||
|
||||
class Flux2Unit_ImageIDs(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
@@ -565,17 +352,10 @@ def model_fn_flux2(
|
||||
prompt_embeds=None,
|
||||
text_ids=None,
|
||||
image_ids=None,
|
||||
edit_latents=None,
|
||||
edit_image_ids=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
image_seq_len = latents.shape[1]
|
||||
if edit_latents is not None:
|
||||
image_seq_len = latents.shape[1]
|
||||
latents = torch.concat([latents, edit_latents], dim=1)
|
||||
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
|
||||
embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
|
||||
model_output = dit(
|
||||
hidden_states=latents,
|
||||
@@ -587,5 +367,4 @@ def model_fn_flux2(
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
model_output = model_output[:, :image_seq_len]
|
||||
return model_output
|
||||
|
||||
@@ -6,7 +6,6 @@ from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
from transformers import CLIPTokenizer, T5TokenizerFast
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward, load_state_dict
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
@@ -56,7 +55,7 @@ class MultiControlNet(torch.nn.Module):
|
||||
|
||||
class FluxImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
@@ -118,7 +117,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_1_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer/"),
|
||||
tokenizer_2_config: ModelConfig = ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="tokenizer_2/"),
|
||||
@@ -378,7 +377,7 @@ class FluxImageUnit_PromptEmbedder(PipelineUnit):
|
||||
text_encoder_2,
|
||||
prompt,
|
||||
positive=True,
|
||||
device=get_device_type(),
|
||||
device="cuda",
|
||||
t5_sequence_length=512,
|
||||
):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
||||
@@ -559,7 +558,7 @@ class FluxImageUnit_EntityControl(PipelineUnit):
|
||||
text_encoder_2,
|
||||
prompt,
|
||||
positive=True,
|
||||
device=get_device_type(),
|
||||
device="cuda",
|
||||
t5_sequence_length=512,
|
||||
):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, text_encoder_1, tokenizer_1, 77, device)
|
||||
@@ -794,7 +793,7 @@ class FluxImageUnit_ValueControl(PipelineUnit):
|
||||
|
||||
|
||||
class InfinitYou(torch.nn.Module):
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__()
|
||||
from facexlib.recognition import init_recognition_model
|
||||
from insightface.app import FaceAnalysis
|
||||
|
||||
@@ -1,550 +0,0 @@
|
||||
import torch, types
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import repeat
|
||||
from typing import Optional, Union
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from typing import Optional
|
||||
from transformers import AutoImageProcessor, Gemma3Processor
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
|
||||
from ..models.ltx2_text_encoder import LTX2TextEncoder, LTX2TextEncoderPostModules, LTXVGemmaTokenizer
|
||||
from ..models.ltx2_dit import LTXModel
|
||||
from ..models.ltx2_video_vae import LTX2VideoEncoder, LTX2VideoDecoder, VideoLatentPatchifier
|
||||
from ..models.ltx2_audio_vae import LTX2AudioEncoder, LTX2AudioDecoder, LTX2Vocoder, AudioPatchifier
|
||||
from ..models.ltx2_upsampler import LTX2LatentUpsampler
|
||||
from ..models.ltx2_common import VideoLatentShape, AudioLatentShape, VideoPixelShape, get_pixel_coords, VIDEO_SCALE_FACTORS
|
||||
from ..utils.data.media_io_ltx2 import ltx2_preprocess
|
||||
|
||||
|
||||
class LTX2AudioVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device,
|
||||
torch_dtype=torch_dtype,
|
||||
height_division_factor=32,
|
||||
width_division_factor=32,
|
||||
time_division_factor=8,
|
||||
time_division_remainder=1,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("LTX-2")
|
||||
self.text_encoder: LTX2TextEncoder = None
|
||||
self.tokenizer: LTXVGemmaTokenizer = None
|
||||
self.processor: Gemma3Processor = None
|
||||
self.text_encoder_post_modules: LTX2TextEncoderPostModules = None
|
||||
self.dit: LTXModel = None
|
||||
self.video_vae_encoder: LTX2VideoEncoder = None
|
||||
self.video_vae_decoder: LTX2VideoDecoder = None
|
||||
self.audio_vae_encoder: LTX2AudioEncoder = None
|
||||
self.audio_vae_decoder: LTX2AudioDecoder = None
|
||||
self.audio_vocoder: LTX2Vocoder = None
|
||||
self.upsampler: LTX2LatentUpsampler = None
|
||||
|
||||
self.video_patchifier: VideoLatentPatchifier = VideoLatentPatchifier(patch_size=1)
|
||||
self.audio_patchifier: AudioPatchifier = AudioPatchifier(patch_size=1)
|
||||
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
LTX2AudioVideoUnit_PipelineChecker(),
|
||||
LTX2AudioVideoUnit_ShapeChecker(),
|
||||
LTX2AudioVideoUnit_PromptEmbedder(),
|
||||
LTX2AudioVideoUnit_NoiseInitializer(),
|
||||
LTX2AudioVideoUnit_InputVideoEmbedder(),
|
||||
LTX2AudioVideoUnit_InputImagesEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_ltx2
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
stage2_lora_config: Optional[ModelConfig] = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = LTX2AudioVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("ltx2_text_encoder")
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = LTXVGemmaTokenizer(tokenizer_path=tokenizer_config.path)
|
||||
image_processor = AutoImageProcessor.from_pretrained(tokenizer_config.path, local_files_only=True)
|
||||
pipe.processor = Gemma3Processor(image_processor=image_processor, tokenizer=pipe.tokenizer.tokenizer)
|
||||
|
||||
pipe.text_encoder_post_modules = model_pool.fetch_model("ltx2_text_encoder_post_modules")
|
||||
pipe.dit = model_pool.fetch_model("ltx2_dit")
|
||||
pipe.video_vae_encoder = model_pool.fetch_model("ltx2_video_vae_encoder")
|
||||
pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
|
||||
pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
|
||||
pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
|
||||
pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
|
||||
|
||||
# Stage 2
|
||||
if stage2_lora_config is not None:
|
||||
stage2_lora_config.download_if_necessary()
|
||||
pipe.stage2_lora_path = stage2_lora_config.path
|
||||
# Optional, currently not used
|
||||
# pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
|
||||
if inputs_shared["use_two_stage_pipeline"]:
|
||||
latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
|
||||
self.load_models_to_device('upsampler',)
|
||||
latent = self.upsampler(latent)
|
||||
latent = self.video_vae_encoder.per_channel_statistics.normalize(latent)
|
||||
self.scheduler.set_timesteps(special_case="stage2")
|
||||
inputs_shared.update({k.replace("stage2_", ""): v for k, v in inputs_shared.items() if k.startswith("stage2_")})
|
||||
denoise_mask_video = 1.0
|
||||
if inputs_shared.get("input_images", None) is not None:
|
||||
latent, denoise_mask_video, initial_latents = self.apply_input_images_to_latents(
|
||||
latent, inputs_shared.pop("input_latents"), inputs_shared["input_images_indexes"],
|
||||
inputs_shared["input_images_strength"], latent.clone())
|
||||
inputs_shared.update({"input_latents_video": initial_latents, "denoise_mask_video": denoise_mask_video})
|
||||
inputs_shared["video_latents"] = self.scheduler.sigmas[0] * denoise_mask_video * inputs_shared[
|
||||
"video_noise"] + (1 - self.scheduler.sigmas[0] * denoise_mask_video) * latent
|
||||
inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (
|
||||
1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
|
||||
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
if not inputs_shared["use_distilled_pipeline"]:
|
||||
self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
||||
self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_video, inpaint_mask=inputs_shared.get("denoise_mask_video", None),
|
||||
input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_audio, **inputs_shared)
|
||||
return inputs_shared
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = "",
|
||||
# Image-to-video
|
||||
denoising_strength: float = 1.0,
|
||||
input_images: Optional[list[Image.Image]] = None,
|
||||
input_images_indexes: Optional[list[int]] = None,
|
||||
input_images_strength: Optional[float] = 1.0,
|
||||
# Randomness
|
||||
seed: Optional[int] = None,
|
||||
rand_device: Optional[str] = "cpu",
|
||||
# Shape
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 768,
|
||||
num_frames=121,
|
||||
# Classifier-free guidance
|
||||
cfg_scale: Optional[float] = 3.0,
|
||||
cfg_merge: Optional[bool] = False,
|
||||
# Scheduler
|
||||
num_inference_steps: Optional[int] = 40,
|
||||
# VAE tiling
|
||||
tiled: Optional[bool] = True,
|
||||
tile_size_in_pixels: Optional[int] = 512,
|
||||
tile_overlap_in_pixels: Optional[int] = 128,
|
||||
tile_size_in_frames: Optional[int] = 128,
|
||||
tile_overlap_in_frames: Optional[int] = 24,
|
||||
# Special Pipelines
|
||||
use_two_stage_pipeline: Optional[bool] = False,
|
||||
use_distilled_pipeline: Optional[bool] = False,
|
||||
# progress_bar
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
|
||||
special_case="ditilled_stage1" if use_distilled_pipeline else None)
|
||||
# Inputs
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"input_images": input_images, "input_images_indexes": input_images_indexes, "input_images_strength": input_images_strength,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"height": height, "width": width, "num_frames": num_frames,
|
||||
"cfg_scale": cfg_scale, "cfg_merge": cfg_merge,
|
||||
"tiled": tiled, "tile_size_in_pixels": tile_size_in_pixels, "tile_overlap_in_pixels": tile_overlap_in_pixels,
|
||||
"tile_size_in_frames": tile_size_in_frames, "tile_overlap_in_frames": tile_overlap_in_frames,
|
||||
"use_two_stage_pipeline": use_two_stage_pipeline, "use_distilled_pipeline": use_distilled_pipeline,
|
||||
"video_patchifier": self.video_patchifier, "audio_patchifier": self.audio_patchifier,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise Stage 1
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id, noise_pred=noise_pred_video,
|
||||
inpaint_mask=inputs_shared.get("denoise_mask_video", None), input_latents=inputs_shared.get("input_latents_video", None), **inputs_shared)
|
||||
inputs_shared["audio_latents"] = self.step(self.scheduler, inputs_shared["audio_latents"], progress_id=progress_id,
|
||||
noise_pred=noise_pred_audio, **inputs_shared)
|
||||
|
||||
# Denoise Stage 2
|
||||
inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['video_vae_decoder'])
|
||||
video = self.video_vae_decoder.decode(inputs_shared["video_latents"], tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels, tile_size_in_frames, tile_overlap_in_frames)
|
||||
video = self.vae_output_to_video(video)
|
||||
self.load_models_to_device(['audio_vae_decoder', 'audio_vocoder'])
|
||||
decoded_audio = self.audio_vae_decoder(inputs_shared["audio_latents"])
|
||||
decoded_audio = self.audio_vocoder(decoded_audio).squeeze(0).float()
|
||||
return video, decoded_audio
|
||||
|
||||
def apply_input_images_to_latents(self, latents, input_latents, input_indexes, input_strength, initial_latents=None, num_frames=121):
|
||||
b, _, f, h, w = latents.shape
|
||||
denoise_mask = torch.ones((b, 1, f, h, w), dtype=latents.dtype, device=latents.device)
|
||||
initial_latents = torch.zeros_like(latents) if initial_latents is None else initial_latents
|
||||
for idx, input_latent in zip(input_indexes, input_latents):
|
||||
idx = min(max(1 + (idx-1) // 8, 0), f - 1)
|
||||
input_latent = input_latent.to(dtype=latents.dtype, device=latents.device)
|
||||
initial_latents[:, :, idx:idx + input_latent.shape[2], :, :] = input_latent
|
||||
denoise_mask[:, :, idx:idx + input_latent.shape[2], :, :] = 1.0 - input_strength
|
||||
latents = latents * denoise_mask + initial_latents * (1.0 - denoise_mask)
|
||||
return latents, denoise_mask, initial_latents
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("use_distilled_pipeline", "use_two_stage_pipeline"),
|
||||
output_params=("use_two_stage_pipeline", "cfg_scale")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
if inputs_shared.get("use_distilled_pipeline", False):
|
||||
inputs_shared["use_two_stage_pipeline"] = True
|
||||
inputs_shared["cfg_scale"] = 1.0
|
||||
print(f"Distilled pipeline requested, setting use_two_stage_pipeline to True, disable CFG by setting cfg_scale to 1.0.")
|
||||
if inputs_shared.get("use_two_stage_pipeline", False):
|
||||
# distill pipeline also uses two-stage, but it does not needs lora
|
||||
if not inputs_shared.get("use_distilled_pipeline", False):
|
||||
if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
|
||||
if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
|
||||
raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_ShapeChecker(PipelineUnit):
|
||||
"""
|
||||
For two-stage pipelines, the resolution must be divisible by 64.
|
||||
For one-stage pipelines, the resolution must be divisible by 32.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames"),
|
||||
output_params=("height", "width", "num_frames"),
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, use_two_stage_pipeline=False):
|
||||
if use_two_stage_pipeline:
|
||||
self.width_division_factor = 64
|
||||
self.height_division_factor = 64
|
||||
height, width, num_frames = pipe.check_resize_height_width(height, width, num_frames)
|
||||
if use_two_stage_pipeline:
|
||||
self.width_division_factor = 32
|
||||
self.height_division_factor = 32
|
||||
return {"height": height, "width": width, "num_frames": num_frames}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("video_context", "audio_context"),
|
||||
onload_model_names=("text_encoder", "text_encoder_post_modules"),
|
||||
)
|
||||
|
||||
def _convert_to_additive_mask(self, attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return (attention_mask - 1).to(dtype).reshape(
|
||||
(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(dtype).max
|
||||
|
||||
def _run_connectors(self, pipe, encoded_input: torch.Tensor,
|
||||
attention_mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
connector_attention_mask = self._convert_to_additive_mask(attention_mask, encoded_input.dtype)
|
||||
|
||||
encoded, encoded_connector_attention_mask = pipe.text_encoder_post_modules.embeddings_connector(
|
||||
encoded_input,
|
||||
connector_attention_mask,
|
||||
)
|
||||
|
||||
# restore the mask values to int64
|
||||
attention_mask = (encoded_connector_attention_mask < 0.000001).to(torch.int64)
|
||||
attention_mask = attention_mask.reshape([encoded.shape[0], encoded.shape[1], 1])
|
||||
encoded = encoded * attention_mask
|
||||
|
||||
encoded_for_audio, _ = pipe.text_encoder_post_modules.audio_embeddings_connector(
|
||||
encoded_input, connector_attention_mask)
|
||||
|
||||
return encoded, encoded_for_audio, attention_mask.squeeze(-1)
|
||||
|
||||
def _norm_and_concat_padded_batch(
|
||||
self,
|
||||
encoded_text: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
padding_side: str = "right",
|
||||
) -> torch.Tensor:
|
||||
"""Normalize and flatten multi-layer hidden states, respecting padding.
|
||||
Performs per-batch, per-layer normalization using masked mean and range,
|
||||
then concatenates across the layer dimension.
|
||||
Args:
|
||||
encoded_text: Hidden states of shape [batch, seq_len, hidden_dim, num_layers].
|
||||
sequence_lengths: Number of valid (non-padded) tokens per batch item.
|
||||
padding_side: Whether padding is on "left" or "right".
|
||||
Returns:
|
||||
Normalized tensor of shape [batch, seq_len, hidden_dim * num_layers],
|
||||
with padded positions zeroed out.
|
||||
"""
|
||||
b, t, d, l = encoded_text.shape # noqa: E741
|
||||
device = encoded_text.device
|
||||
# Build mask: [B, T, 1, 1]
|
||||
token_indices = torch.arange(t, device=device)[None, :] # [1, T]
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [B, T]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = t - sequence_lengths[:, None] # [B, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = rearrange(mask, "b t -> b t 1 1")
|
||||
eps = 1e-6
|
||||
# Compute masked mean: [B, 1, 1, L]
|
||||
masked = encoded_text.masked_fill(~mask, 0.0)
|
||||
denom = (sequence_lengths * d).view(b, 1, 1, 1)
|
||||
mean = masked.sum(dim=(1, 2), keepdim=True) / (denom + eps)
|
||||
# Compute masked min/max: [B, 1, 1, L]
|
||||
x_min = encoded_text.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = encoded_text.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
range_ = x_max - x_min
|
||||
# Normalize only the valid tokens
|
||||
normed = 8 * (encoded_text - mean) / (range_ + eps)
|
||||
# concat to be [Batch, T, D * L] - this preserves the original structure
|
||||
normed = normed.reshape(b, t, -1) # [B, T, D * L]
|
||||
# Apply mask to preserve original padding (set padded positions to 0)
|
||||
mask_flattened = rearrange(mask, "b t 1 1 -> b t 1").expand(-1, -1, d * l)
|
||||
normed = normed.masked_fill(~mask_flattened, 0.0)
|
||||
|
||||
return normed
|
||||
|
||||
def _run_feature_extractor(self,
|
||||
pipe,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "right") -> torch.Tensor:
|
||||
encoded_text_features = torch.stack(hidden_states, dim=-1)
|
||||
encoded_text_features_dtype = encoded_text_features.dtype
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
normed_concated_encoded_text_features = self._norm_and_concat_padded_batch(encoded_text_features,
|
||||
sequence_lengths,
|
||||
padding_side=padding_side)
|
||||
|
||||
return pipe.text_encoder_post_modules.feature_extractor_linear(
|
||||
normed_concated_encoded_text_features.to(encoded_text_features_dtype))
|
||||
|
||||
def _preprocess_text(
|
||||
self,
|
||||
pipe,
|
||||
text: str,
|
||||
padding_side: str = "left",
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Encode a given string into feature tensors suitable for downstream tasks.
|
||||
Args:
|
||||
text (str): Input string to encode.
|
||||
Returns:
|
||||
tuple[torch.Tensor, dict[str, torch.Tensor]]: Encoded features and a dictionary with attention mask.
|
||||
"""
|
||||
token_pairs = pipe.tokenizer.tokenize_with_weights(text)["gemma"]
|
||||
input_ids = torch.tensor([[t[0] for t in token_pairs]], device=pipe.device)
|
||||
attention_mask = torch.tensor([[w[1] for w in token_pairs]], device=pipe.device)
|
||||
outputs = pipe.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
|
||||
projected = self._run_feature_extractor(pipe,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
padding_side=padding_side)
|
||||
return projected, attention_mask
|
||||
|
||||
def encode_prompt(self, pipe, text, padding_side="left"):
|
||||
encoded_inputs, attention_mask = self._preprocess_text(pipe, text, padding_side)
|
||||
video_encoding, audio_encoding, attention_mask = self._run_connectors(pipe, encoded_inputs, attention_mask)
|
||||
return video_encoding, audio_encoding, attention_mask
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, prompt: str):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
video_context, audio_context, _ = self.encode_prompt(pipe, prompt)
|
||||
return {"video_context": video_context, "audio_context": audio_context}
|
||||
|
||||
|
||||
class LTX2AudioVideoUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "num_frames", "seed", "rand_device", "use_two_stage_pipeline"),
|
||||
output_params=("video_noise", "audio_noise",),
|
||||
)
|
||||
|
||||
def process_stage(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0):
|
||||
video_pixel_shape = VideoPixelShape(batch=1, frames=num_frames, width=width, height=height, fps=frame_rate)
|
||||
video_latent_shape = VideoLatentShape.from_pixel_shape(shape=video_pixel_shape, latent_channels=pipe.video_vae_encoder.latent_channels)
|
||||
video_noise = pipe.generate_noise(video_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
|
||||
latent_coords = pipe.video_patchifier.get_patch_grid_bounds(output_shape=video_latent_shape, device=pipe.device)
|
||||
video_positions = get_pixel_coords(latent_coords, VIDEO_SCALE_FACTORS, True).float()
|
||||
video_positions[:, 0, ...] = video_positions[:, 0, ...] / frame_rate
|
||||
video_positions = video_positions.to(pipe.torch_dtype)
|
||||
|
||||
audio_latent_shape = AudioLatentShape.from_video_pixel_shape(video_pixel_shape)
|
||||
audio_noise = pipe.generate_noise(audio_latent_shape.to_torch_shape(), seed=seed, rand_device=rand_device)
|
||||
audio_positions = pipe.audio_patchifier.get_patch_grid_bounds(audio_latent_shape, device=pipe.device)
|
||||
return {
|
||||
"video_noise": video_noise,
|
||||
"audio_noise": audio_noise,
|
||||
"video_positions": video_positions,
|
||||
"audio_positions": audio_positions,
|
||||
"video_latent_shape": video_latent_shape,
|
||||
"audio_latent_shape": audio_latent_shape
|
||||
}
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, height, width, num_frames, seed, rand_device, frame_rate=24.0, use_two_stage_pipeline=False):
|
||||
if use_two_stage_pipeline:
|
||||
stage1_dict = self.process_stage(pipe, height // 2, width // 2, num_frames, seed, rand_device, frame_rate)
|
||||
stage2_dict = self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
||||
initial_dict = stage1_dict
|
||||
initial_dict.update({"stage2_" + k: v for k, v in stage2_dict.items()})
|
||||
return initial_dict
|
||||
else:
|
||||
return self.process_stage(pipe, height, width, num_frames, seed, rand_device, frame_rate)
|
||||
|
||||
class LTX2AudioVideoUnit_InputVideoEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_video", "video_noise", "audio_noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("video_latents", "audio_latents"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_video, video_noise, audio_noise, tiled, tile_size, tile_stride):
|
||||
if input_video is None:
|
||||
return {"video_latents": video_noise, "audio_latents": audio_noise}
|
||||
else:
|
||||
# TODO: implement video-to-video
|
||||
raise NotImplementedError("Video-to-video not implemented yet.")
|
||||
|
||||
class LTX2AudioVideoUnit_InputImagesEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_images", "input_images_indexes", "input_images_strength", "video_latents", "height", "width", "num_frames", "tiled", "tile_size_in_pixels", "tile_overlap_in_pixels", "use_two_stage_pipeline"),
|
||||
output_params=("video_latents"),
|
||||
onload_model_names=("video_vae_encoder")
|
||||
)
|
||||
|
||||
def get_image_latent(self, pipe, input_image, height, width, tiled, tile_size_in_pixels, tile_overlap_in_pixels):
|
||||
image = ltx2_preprocess(np.array(input_image.resize((width, height))))
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32)).to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
image = image / 127.5 - 1.0
|
||||
image = repeat(image, f"H W C -> B C F H W", B=1, F=1)
|
||||
latent = pipe.video_vae_encoder.encode(image, tiled, tile_size_in_pixels, tile_overlap_in_pixels).to(pipe.device)
|
||||
return latent
|
||||
|
||||
def process(self, pipe: LTX2AudioVideoPipeline, input_images, input_images_indexes, input_images_strength, video_latents, height, width, num_frames, tiled, tile_size_in_pixels, tile_overlap_in_pixels, use_two_stage_pipeline=False):
|
||||
if input_images is None or len(input_images) == 0:
|
||||
return {"video_latents": video_latents}
|
||||
else:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
output_dicts = {}
|
||||
stage1_height = height // 2 if use_two_stage_pipeline else height
|
||||
stage1_width = width // 2 if use_two_stage_pipeline else width
|
||||
stage1_latents = [
|
||||
self.get_image_latent(pipe, img, stage1_height, stage1_width, tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels) for img in input_images
|
||||
]
|
||||
video_latents, denoise_mask_video, initial_latents = pipe.apply_input_images_to_latents(video_latents, stage1_latents, input_images_indexes, input_images_strength, num_frames=num_frames)
|
||||
output_dicts.update({"video_latents": video_latents, "denoise_mask_video": denoise_mask_video, "input_latents_video": initial_latents})
|
||||
if use_two_stage_pipeline:
|
||||
stage2_latents = [
|
||||
self.get_image_latent(pipe, img, height, width, tiled, tile_size_in_pixels,
|
||||
tile_overlap_in_pixels) for img in input_images
|
||||
]
|
||||
output_dicts.update({"stage2_input_latents": stage2_latents})
|
||||
return output_dicts
|
||||
|
||||
|
||||
def model_fn_ltx2(
|
||||
dit: LTXModel,
|
||||
video_latents=None,
|
||||
video_context=None,
|
||||
video_positions=None,
|
||||
video_patchifier=None,
|
||||
audio_latents=None,
|
||||
audio_context=None,
|
||||
audio_positions=None,
|
||||
audio_patchifier=None,
|
||||
timestep=None,
|
||||
denoise_mask_video=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
timestep = timestep.float() / 1000.
|
||||
|
||||
# patchify
|
||||
b, c_v, f, h, w = video_latents.shape
|
||||
video_latents = video_patchifier.patchify(video_latents)
|
||||
video_timesteps = timestep.repeat(1, video_latents.shape[1], 1)
|
||||
if denoise_mask_video is not None:
|
||||
video_timesteps = video_patchifier.patchify(denoise_mask_video) * video_timesteps
|
||||
_, c_a, _, mel_bins = audio_latents.shape
|
||||
audio_latents = audio_patchifier.patchify(audio_latents)
|
||||
audio_timesteps = timestep.repeat(1, audio_latents.shape[1], 1)
|
||||
#TODO: support gradient checkpointing in training
|
||||
vx, ax = dit(
|
||||
video_latents=video_latents,
|
||||
video_positions=video_positions,
|
||||
video_context=video_context,
|
||||
video_timesteps=video_timesteps,
|
||||
audio_latents=audio_latents,
|
||||
audio_positions=audio_positions,
|
||||
audio_context=audio_context,
|
||||
audio_timesteps=audio_timesteps,
|
||||
)
|
||||
# unpatchify
|
||||
vx = video_patchifier.unpatchify_video(vx, f, h, w)
|
||||
ax = audio_patchifier.unpatchify_audio(ax, c_a, mel_bins)
|
||||
return vx, ax
|
||||
@@ -6,7 +6,6 @@ from einops import rearrange
|
||||
import numpy as np
|
||||
from math import prod
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
@@ -23,7 +22,7 @@ from ..models.qwen_image_image2lora import QwenImageImage2LoRAModel
|
||||
|
||||
class QwenImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
@@ -61,7 +60,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
processor_config: ModelConfig = None,
|
||||
|
||||
@@ -11,7 +11,6 @@ from typing import Optional
|
||||
from typing_extensions import Literal
|
||||
from transformers import Wav2Vec2Processor
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
@@ -31,7 +30,7 @@ from ..models.longcat_video_dit import LongCatVideoTransformer3DModel
|
||||
|
||||
class WanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
||||
@@ -99,7 +98,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
audio_processor_config: ModelConfig = None,
|
||||
@@ -123,15 +122,11 @@ class WanVideoPipeline(BasePipeline):
|
||||
model_config.model_id = redirect_dict[model_config.origin_file_pattern][0]
|
||||
model_config.origin_file_pattern = redirect_dict[model_config.origin_file_pattern][1]
|
||||
|
||||
# Initialize pipeline
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
if use_usp:
|
||||
from ..utils.xfuser import initialize_usp
|
||||
initialize_usp(device)
|
||||
import torch.distributed as dist
|
||||
from ..core.device.npu_compatible_device import get_device_name
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
device = get_device_name()
|
||||
# Initialize pipeline
|
||||
pipe = WanVideoPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
@@ -965,7 +960,7 @@ class WanVideoUnit_AnimateInpaint(PipelineUnit):
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device=get_device_type()):
|
||||
def get_i2v_mask(self, lat_t, lat_h, lat_w, mask_len=1, mask_pixel_values=None, device="cuda"):
|
||||
if mask_pixel_values is None:
|
||||
msk = torch.zeros(1, (lat_t-1) * 4 + 1, lat_h, lat_w, device=device)
|
||||
else:
|
||||
|
||||
@@ -4,29 +4,23 @@ from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from typing import Union, List, Optional, Tuple, Iterable, Dict
|
||||
from typing import Union, List, Optional, Tuple, Iterable
|
||||
|
||||
from ..core.device.npu_compatible_device import get_device_type
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig, gradient_checkpoint_forward
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit, ControlNetInput
|
||||
from ..utils.lora import merge_lora
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from ..models.z_image_text_encoder import ZImageTextEncoder
|
||||
from ..models.z_image_dit import ZImageDiT
|
||||
from ..models.flux_vae import FluxVAEEncoder, FluxVAEDecoder
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder428M
|
||||
from ..models.z_image_controlnet import ZImageControlNet
|
||||
from ..models.siglip2_image_encoder import Siglip2ImageEncoder
|
||||
from ..models.dinov3_image_encoder import DINOv3ImageEncoder
|
||||
from ..models.z_image_image2lora import ZImageImage2LoRAModel
|
||||
|
||||
|
||||
class ZImagePipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
@@ -37,12 +31,8 @@ class ZImagePipeline(BasePipeline):
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
self.vae_decoder: FluxVAEDecoder = None
|
||||
self.image_encoder: Siglip2ImageEncoder428M = None
|
||||
self.controlnet: ZImageControlNet = None
|
||||
self.siglip2_image_encoder: Siglip2ImageEncoder = None
|
||||
self.dinov3_image_encoder: DINOv3ImageEncoder = None
|
||||
self.image2lora_style: ZImageImage2LoRAModel = None
|
||||
self.tokenizer: AutoTokenizer = None
|
||||
self.in_iteration_models = ("dit", "controlnet")
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
ZImageUnit_ShapeChecker(),
|
||||
ZImageUnit_PromptEmbedder(),
|
||||
@@ -51,7 +41,6 @@ class ZImagePipeline(BasePipeline):
|
||||
ZImageUnit_EditImageAutoResize(),
|
||||
ZImageUnit_EditImageEmbedderVAE(),
|
||||
ZImageUnit_EditImageEmbedderSiglip(),
|
||||
ZImageUnit_PAIControlNet(),
|
||||
]
|
||||
self.model_fn = model_fn_z_image
|
||||
|
||||
@@ -59,7 +48,7 @@ class ZImagePipeline(BasePipeline):
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = get_device_type(),
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
vram_limit: float = None,
|
||||
@@ -74,10 +63,6 @@ class ZImagePipeline(BasePipeline):
|
||||
pipe.vae_encoder = model_pool.fetch_model("flux_vae_encoder")
|
||||
pipe.vae_decoder = model_pool.fetch_model("flux_vae_decoder")
|
||||
pipe.image_encoder = model_pool.fetch_model("siglip_vision_model_428m")
|
||||
pipe.controlnet = model_pool.fetch_model("z_image_controlnet")
|
||||
pipe.siglip2_image_encoder = model_pool.fetch_model("siglip2_image_encoder")
|
||||
pipe.dinov3_image_encoder = model_pool.fetch_model("dinov3_image_encoder")
|
||||
pipe.image2lora_style = model_pool.fetch_model("z_image_image2lora_style")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
@@ -109,11 +94,6 @@ class ZImagePipeline(BasePipeline):
|
||||
# Steps
|
||||
num_inference_steps: int = 8,
|
||||
sigma_shift: float = None,
|
||||
# ControlNet
|
||||
controlnet_inputs: List[ControlNetInput] = None,
|
||||
# Image to LoRA
|
||||
image2lora_images: List[Image.Image] = None,
|
||||
positive_only_lora: Dict[str, torch.Tensor] = None,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
@@ -134,8 +114,6 @@ class ZImagePipeline(BasePipeline):
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
"edit_image": edit_image, "edit_image_auto_resize": edit_image_auto_resize,
|
||||
"controlnet_inputs": controlnet_inputs,
|
||||
"image2lora_images": image2lora_images, "positive_only_lora": positive_only_lora,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
@@ -353,9 +331,7 @@ class ZImageUnit_EditImageAutoResize(PipelineUnit):
|
||||
if edit_image_auto_resize is None or not edit_image_auto_resize:
|
||||
return {}
|
||||
operator = ImageCropAndResize(max_pixels=1024*1024, height_division_factor=16, width_division_factor=16)
|
||||
if not isinstance(edit_image, list):
|
||||
edit_image = [edit_image]
|
||||
edit_image = [operator(i) for i in edit_image]
|
||||
edit_image = operator(edit_image)
|
||||
return {"edit_image": edit_image}
|
||||
|
||||
|
||||
@@ -400,49 +376,8 @@ class ZImageUnit_EditImageEmbedderVAE(PipelineUnit):
|
||||
return {"image_latents": image_latents}
|
||||
|
||||
|
||||
class ZImageUnit_PAIControlNet(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("controlnet_inputs", "height", "width"),
|
||||
output_params=("control_context", "control_scale"),
|
||||
onload_model_names=("vae_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, controlnet_inputs: List[ControlNetInput], height, width):
|
||||
if controlnet_inputs is None:
|
||||
return {}
|
||||
if len(controlnet_inputs) != 1:
|
||||
print("Z-Image ControlNet doesn't support multi-ControlNet. Only one image will be used.")
|
||||
controlnet_input = controlnet_inputs[0]
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
|
||||
control_image = controlnet_input.image
|
||||
if control_image is not None:
|
||||
control_image = pipe.preprocess_image(control_image)
|
||||
control_latents = pipe.vae_encoder(control_image)
|
||||
else:
|
||||
control_latents = torch.ones((1, 16, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device) * -1
|
||||
|
||||
inpaint_mask = controlnet_input.inpaint_mask
|
||||
if inpaint_mask is not None:
|
||||
inpaint_mask = pipe.preprocess_image(inpaint_mask, min_value=0, max_value=1)
|
||||
inpaint_image = controlnet_input.inpaint_image
|
||||
inpaint_image = pipe.preprocess_image(inpaint_image)
|
||||
inpaint_image = inpaint_image * (inpaint_mask < 0.5)
|
||||
inpaint_mask = torch.nn.functional.interpolate(1 - inpaint_mask, (height // 8, width // 8), mode='nearest')[:, :1]
|
||||
else:
|
||||
inpaint_mask = torch.zeros((1, 1, height // 8, width // 8), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_image = torch.zeros((1, 3, height, width), dtype=pipe.torch_dtype, device=pipe.device)
|
||||
inpaint_latent = pipe.vae_encoder(inpaint_image)
|
||||
|
||||
control_context = torch.concat([control_latents, inpaint_mask, inpaint_latent], dim=1)
|
||||
control_context = rearrange(control_context, "B C H W -> B C 1 H W")
|
||||
return {"control_context": control_context, "control_scale": controlnet_input.scale}
|
||||
|
||||
|
||||
def model_fn_z_image(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
@@ -458,14 +393,13 @@ def model_fn_z_image(
|
||||
if dit.siglip_embedder is None:
|
||||
return model_fn_z_image_turbo(
|
||||
dit,
|
||||
controlnet=controlnet,
|
||||
latents=latents,
|
||||
timestep=timestep,
|
||||
prompt_embeds=prompt_embeds,
|
||||
image_embeds=image_embeds,
|
||||
image_latents=image_latents,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
latents,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
image_embeds,
|
||||
image_latents,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
**kwargs,
|
||||
)
|
||||
latents = [rearrange(latents, "B C H W -> C B H W")]
|
||||
@@ -495,81 +429,13 @@ def model_fn_z_image(
|
||||
return model_output
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRAEncode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("image2lora_images",),
|
||||
output_params=("image2lora_x",),
|
||||
onload_model_names=("siglip2_image_encoder", "dinov3_image_encoder",),
|
||||
)
|
||||
from ..core.data.operators import ImageCropAndResize
|
||||
self.processor_highres = ImageCropAndResize(height=1024, width=1024)
|
||||
|
||||
def encode_images_using_siglip2(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
pipe.load_models_to_device(["siglip2_image_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image)
|
||||
embs.append(pipe.siglip2_image_encoder(image).to(pipe.torch_dtype))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images_using_dinov3(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
pipe.load_models_to_device(["dinov3_image_encoder"])
|
||||
embs = []
|
||||
for image in images:
|
||||
image = self.processor_highres(image)
|
||||
embs.append(pipe.dinov3_image_encoder(image).to(pipe.torch_dtype))
|
||||
embs = torch.stack(embs)
|
||||
return embs
|
||||
|
||||
def encode_images(self, pipe: ZImagePipeline, images: list[Image.Image]):
|
||||
if images is None:
|
||||
return {}
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
embs_siglip2 = self.encode_images_using_siglip2(pipe, images)
|
||||
embs_dinov3 = self.encode_images_using_dinov3(pipe, images)
|
||||
x = torch.concat([embs_siglip2, embs_dinov3], dim=-1)
|
||||
return x
|
||||
|
||||
def process(self, pipe: ZImagePipeline, image2lora_images):
|
||||
if image2lora_images is None:
|
||||
return {}
|
||||
x = self.encode_images(pipe, image2lora_images)
|
||||
return {"image2lora_x": x}
|
||||
|
||||
|
||||
class ZImageUnit_Image2LoRADecode(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("image2lora_x",),
|
||||
output_params=("lora",),
|
||||
onload_model_names=("image2lora_style",),
|
||||
)
|
||||
|
||||
def process(self, pipe: ZImagePipeline, image2lora_x):
|
||||
if image2lora_x is None:
|
||||
return {}
|
||||
loras = []
|
||||
if pipe.image2lora_style is not None:
|
||||
pipe.load_models_to_device(["image2lora_style"])
|
||||
for x in image2lora_x:
|
||||
loras.append(pipe.image2lora_style(x=x, residual=None))
|
||||
lora = merge_lora(loras, alpha=1 / len(image2lora_x))
|
||||
return {"lora": lora}
|
||||
|
||||
|
||||
def model_fn_z_image_turbo(
|
||||
dit: ZImageDiT,
|
||||
controlnet: ZImageControlNet = None,
|
||||
latents=None,
|
||||
timestep=None,
|
||||
prompt_embeds=None,
|
||||
image_embeds=None,
|
||||
image_latents=None,
|
||||
control_context=None,
|
||||
control_scale=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
@@ -594,19 +460,11 @@ def model_fn_z_image_turbo(
|
||||
|
||||
# Noise refine
|
||||
x = dit.all_x_embedder["2-1"](x)
|
||||
x[torch.cat(patch_metadata.get("x_pad_mask"))] = dit.x_pad_token.to(dtype=x.dtype, device=x.device)
|
||||
x_freqs_cis = dit.rope_embedder(torch.cat(patch_metadata.get("x_pos_ids"), dim=0))
|
||||
x = rearrange(x, "L C -> 1 L C")
|
||||
x_freqs_cis = rearrange(x_freqs_cis, "L C -> 1 L C")
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=x_freqs_cis, adaln_input=t_noisy)
|
||||
refiner_hints, control_context, control_context_item_seqlens = controlnet.forward_refiner(
|
||||
dit, x, [cap_feats], control_context, kwargs, t=t_noisy, patch_size=2, f_patch_size=1,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
for layer_id, layer in enumerate(dit.noise_refiner):
|
||||
for layer in dit.noise_refiner:
|
||||
x = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
@@ -616,8 +474,6 @@ def model_fn_z_image_turbo(
|
||||
freqs_cis=x_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
x = x + refiner_hints[layer_id] * control_scale
|
||||
|
||||
# Prompt refine
|
||||
cap_feats = dit.cap_embedder(cap_feats)
|
||||
@@ -639,15 +495,7 @@ def model_fn_z_image_turbo(
|
||||
# Unified
|
||||
unified = torch.cat([x, cap_feats], dim=1)
|
||||
unified_freqs_cis = torch.cat([x_freqs_cis, cap_freqs_cis], dim=1)
|
||||
|
||||
if control_context is not None:
|
||||
kwargs = dict(attn_mask=None, freqs_cis=unified_freqs_cis, adaln_input=t_noisy)
|
||||
hints = controlnet.forward_layers(
|
||||
unified, cap_feats, control_context, control_context_item_seqlens, kwargs,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing, use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
|
||||
for layer_id, layer in enumerate(dit.layers):
|
||||
for layer in dit.layers:
|
||||
unified = gradient_checkpoint_forward(
|
||||
layer,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
@@ -657,9 +505,6 @@ def model_fn_z_image_turbo(
|
||||
freqs_cis=unified_freqs_cis,
|
||||
adaln_input=t_noisy,
|
||||
)
|
||||
if control_context is not None:
|
||||
if layer_id in controlnet.control_layers_mapping:
|
||||
unified = unified + hints[controlnet.control_layers_mapping[layer_id]] * control_scale
|
||||
|
||||
# Output
|
||||
unified = dit.all_final_layer["2-1"](unified, t_noisy)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
|
||||
from diffsynth.core.device.npu_compatible_device import get_device_type
|
||||
|
||||
Processor_id: TypeAlias = Literal[
|
||||
"canny", "depth", "softedge", "lineart", "lineart_anime", "openpose", "normal", "tile", "none", "inpaint"
|
||||
]
|
||||
|
||||
class Annotator:
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device=get_device_type(), skip_processor=False):
|
||||
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda', skip_processor=False):
|
||||
if not skip_processor:
|
||||
if processor_id == "canny":
|
||||
from controlnet_aux.processor import CannyDetector
|
||||
|
||||
@@ -9,6 +9,5 @@ class ControlNetInput:
|
||||
start: float = 1.0
|
||||
end: float = 0.0
|
||||
image: Image.Image = None
|
||||
inpaint_image: Image.Image = None
|
||||
inpaint_mask: Image.Image = None
|
||||
processor_id: str = None
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
|
||||
from fractions import Fraction
|
||||
import torch
|
||||
import av
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from io import BytesIO
|
||||
from collections.abc import Generator, Iterator
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||
) -> None:
|
||||
cc = audio_stream.codec_context
|
||||
|
||||
# Use the encoder's format/layout/rate as the *target*
|
||||
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||
target_layout = cc.layout or "stereo"
|
||||
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=target_format,
|
||||
layout=target_layout,
|
||||
rate=target_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = frame_in.sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
samples = (samples * 32767.0).to(torch.int16)
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(
|
||||
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
)
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
_resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
return audio_stream
|
||||
|
||||
def write_video_audio_ltx2(
|
||||
video: list[Image.Image],
|
||||
audio: torch.Tensor | None,
|
||||
output_path: str,
|
||||
fps: int = 24,
|
||||
audio_sample_rate: int | None = 24000,
|
||||
) -> None:
|
||||
|
||||
width, height = video[0].size
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame in tqdm(video, total=len(video)):
|
||||
frame = av.VideoFrame.from_image(frame)
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
|
||||
|
||||
|
||||
def encode_single_frame(output_file: str, image_array: np.ndarray, crf: float) -> None:
|
||||
container = av.open(output_file, "w", format="mp4")
|
||||
try:
|
||||
stream = container.add_stream("libx264", rate=1, options={"crf": str(crf), "preset": "veryfast"})
|
||||
# Round to nearest multiple of 2 for compatibility with video codecs
|
||||
height = image_array.shape[0] // 2 * 2
|
||||
width = image_array.shape[1] // 2 * 2
|
||||
image_array = image_array[:height, :width]
|
||||
stream.height = height
|
||||
stream.width = width
|
||||
av_frame = av.VideoFrame.from_ndarray(image_array, format="rgb24").reformat(format="yuv420p")
|
||||
container.mux(stream.encode(av_frame))
|
||||
container.mux(stream.encode())
|
||||
finally:
|
||||
container.close()
|
||||
|
||||
|
||||
def decode_single_frame(video_file: str) -> np.array:
|
||||
container = av.open(video_file)
|
||||
try:
|
||||
stream = next(s for s in container.streams if s.type == "video")
|
||||
frame = next(container.decode(stream))
|
||||
finally:
|
||||
container.close()
|
||||
return frame.to_ndarray(format="rgb24")
|
||||
|
||||
|
||||
def ltx2_preprocess(image: np.array, crf: float = 33) -> np.array:
|
||||
if crf == 0:
|
||||
return image
|
||||
|
||||
with BytesIO() as output_file:
|
||||
encode_single_frame(output_file, image, crf)
|
||||
video_bytes = output_file.getvalue()
|
||||
with BytesIO(video_bytes) as video_file:
|
||||
image_array = decode_single_frame(video_file)
|
||||
return image_array
|
||||
@@ -149,8 +149,6 @@ class FluxLoRALoader(GeneralLoRALoader):
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
|
||||
mlp = mlp.to(device=state_dict_[name].device)
|
||||
if 'lora_A' in name:
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
|
||||
@@ -89,109 +89,4 @@ def FluxDiTStateDictConverter(state_dict):
|
||||
state_dict_[rename] = state_dict[original_name]
|
||||
else:
|
||||
pass
|
||||
return state_dict_
|
||||
|
||||
|
||||
def FluxDiTStateDictConverterFromDiffusers(state_dict):
|
||||
global_rename_dict = {
|
||||
"context_embedder": "context_embedder",
|
||||
"x_embedder": "x_embedder",
|
||||
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
|
||||
"time_text_embed.timestep_embedder.linear_2": "time_embedder.timestep_embedder.2",
|
||||
"time_text_embed.guidance_embedder.linear_1": "guidance_embedder.timestep_embedder.0",
|
||||
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
|
||||
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
|
||||
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
|
||||
"norm_out.linear": "final_norm_out.linear",
|
||||
"proj_out": "final_proj_out",
|
||||
}
|
||||
rename_dict = {
|
||||
"proj_out": "proj_out",
|
||||
"norm1.linear": "norm1_a.linear",
|
||||
"norm1_context.linear": "norm1_b.linear",
|
||||
"attn.to_q": "attn.a_to_q",
|
||||
"attn.to_k": "attn.a_to_k",
|
||||
"attn.to_v": "attn.a_to_v",
|
||||
"attn.to_out.0": "attn.a_to_out",
|
||||
"attn.add_q_proj": "attn.b_to_q",
|
||||
"attn.add_k_proj": "attn.b_to_k",
|
||||
"attn.add_v_proj": "attn.b_to_v",
|
||||
"attn.to_add_out": "attn.b_to_out",
|
||||
"ff.net.0.proj": "ff_a.0",
|
||||
"ff.net.2": "ff_a.2",
|
||||
"ff_context.net.0.proj": "ff_b.0",
|
||||
"ff_context.net.2": "ff_b.2",
|
||||
"attn.norm_q": "attn.norm_q_a",
|
||||
"attn.norm_k": "attn.norm_k_a",
|
||||
"attn.norm_added_q": "attn.norm_q_b",
|
||||
"attn.norm_added_k": "attn.norm_k_b",
|
||||
}
|
||||
rename_dict_single = {
|
||||
"attn.to_q": "a_to_q",
|
||||
"attn.to_k": "a_to_k",
|
||||
"attn.to_v": "a_to_v",
|
||||
"attn.norm_q": "norm_q_a",
|
||||
"attn.norm_k": "norm_k_a",
|
||||
"norm.linear": "norm.linear",
|
||||
"proj_mlp": "proj_in_besides_attn",
|
||||
"proj_out": "proj_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
param = state_dict[name]
|
||||
if name.endswith(".weight") or name.endswith(".bias"):
|
||||
suffix = ".weight" if name.endswith(".weight") else ".bias"
|
||||
prefix = name[:-len(suffix)]
|
||||
if prefix in global_rename_dict:
|
||||
if global_rename_dict[prefix] == "final_norm_out.linear":
|
||||
param = torch.concat([param[3072:], param[:3072]], dim=0)
|
||||
state_dict_[global_rename_dict[prefix] + suffix] = param
|
||||
elif prefix.startswith("transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict:
|
||||
name_ = ".".join(names[:2] + [rename_dict[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
elif prefix.startswith("single_transformer_blocks."):
|
||||
names = prefix.split(".")
|
||||
names[0] = "single_blocks"
|
||||
middle = ".".join(names[2:])
|
||||
if middle in rename_dict_single:
|
||||
name_ = ".".join(names[:2] + [rename_dict_single[middle]] + [suffix[1:]])
|
||||
state_dict_[name_] = param
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
for name in list(state_dict_.keys()):
|
||||
if "single_blocks." in name and ".a_to_q." in name:
|
||||
mlp = state_dict_.get(name.replace(".a_to_q.", ".proj_in_besides_attn."), None)
|
||||
if mlp is None:
|
||||
mlp = torch.zeros(4 * state_dict_[name].shape[0],
|
||||
*state_dict_[name].shape[1:],
|
||||
dtype=state_dict_[name].dtype)
|
||||
else:
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".proj_in_besides_attn."))
|
||||
param = torch.concat([
|
||||
state_dict_.pop(name),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_k.")),
|
||||
state_dict_.pop(name.replace(".a_to_q.", ".a_to_v.")),
|
||||
mlp,
|
||||
], dim=0)
|
||||
name_ = name.replace(".a_to_q.", ".to_qkv_mlp.")
|
||||
state_dict_[name_] = param
|
||||
for name in list(state_dict_.keys()):
|
||||
for component in ["a", "b"]:
|
||||
if f".{component}_to_q." in name:
|
||||
name_ = name.replace(f".{component}_to_q.", f".{component}_to_qkv.")
|
||||
param = torch.concat([
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_q.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_k.")],
|
||||
state_dict_[name.replace(f".{component}_to_q.", f".{component}_to_v.")],
|
||||
], dim=0)
|
||||
state_dict_[name_] = param
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_q."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_k."))
|
||||
state_dict_.pop(name.replace(f".{component}_to_q.", f".{component}_to_v."))
|
||||
return state_dict_
|
||||
@@ -1,32 +0,0 @@
|
||||
def LTX2AudioEncoderStateDictConverter(state_dict):
|
||||
# Not used
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("audio_vae.encoder."):
|
||||
new_name = name.replace("audio_vae.encoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2AudioDecoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("audio_vae.decoder."):
|
||||
new_name = name.replace("audio_vae.decoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("audio_vae.per_channel_statistics."):
|
||||
new_name = name.replace("audio_vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2VocoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vocoder."):
|
||||
new_name = name.replace("vocoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -1,9 +0,0 @@
|
||||
def LTXModelStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("model.diffusion_model."):
|
||||
new_name = name.replace("model.diffusion_model.", "")
|
||||
if new_name.startswith("audio_embeddings_connector.") or new_name.startswith("video_embeddings_connector."):
|
||||
continue
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -1,31 +0,0 @@
|
||||
def LTX2TextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("language_model.model."):
|
||||
new_key = key.replace("language_model.model.", "model.language_model.")
|
||||
elif key.startswith("vision_tower."):
|
||||
new_key = key.replace("vision_tower.", "model.vision_tower.")
|
||||
elif key.startswith("multi_modal_projector."):
|
||||
new_key = key.replace("multi_modal_projector.", "model.multi_modal_projector.")
|
||||
elif key.startswith("language_model.lm_head."):
|
||||
new_key = key.replace("language_model.lm_head.", "lm_head.")
|
||||
else:
|
||||
continue
|
||||
state_dict_[new_key] = state_dict[key]
|
||||
state_dict_["lm_head.weight"] = state_dict_.get("model.language_model.embed_tokens.weight")
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2TextEncoderPostModulesStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if key.startswith("text_embedding_projection."):
|
||||
new_key = key.replace("text_embedding_projection.", "feature_extractor_linear.")
|
||||
elif key.startswith("model.diffusion_model.video_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.video_embeddings_connector.", "embeddings_connector.")
|
||||
elif key.startswith("model.diffusion_model.audio_embeddings_connector."):
|
||||
new_key = key.replace("model.diffusion_model.audio_embeddings_connector.", "audio_embeddings_connector.")
|
||||
else:
|
||||
continue
|
||||
state_dict_[new_key] = state_dict[key]
|
||||
return state_dict_
|
||||
@@ -1,22 +0,0 @@
|
||||
def LTX2VideoEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vae.encoder."):
|
||||
new_name = name.replace("vae.encoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
|
||||
|
||||
def LTX2VideoDecoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name.startswith("vae.decoder."):
|
||||
new_name = name.replace("vae.decoder.", "")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
elif name.startswith("vae.per_channel_statistics."):
|
||||
new_name = name.replace("vae.per_channel_statistics.", "per_channel_statistics.")
|
||||
state_dict_[new_name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -1,6 +0,0 @@
|
||||
def ZImageTextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for name in state_dict:
|
||||
if name != "lm_head.weight":
|
||||
state_dict_[name] = state_dict[name]
|
||||
return state_dict_
|
||||
@@ -1,13 +1,10 @@
|
||||
import torch
|
||||
from typing import Optional
|
||||
from einops import rearrange
|
||||
from yunchang.kernels import AttnType
|
||||
from xfuser.core.distributed import (get_sequence_parallel_rank,
|
||||
get_sequence_parallel_world_size,
|
||||
get_sp_group)
|
||||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
||||
|
||||
from ... import IS_NPU_AVAILABLE
|
||||
from ...core.device import parse_nccl_backend, parse_device_type
|
||||
|
||||
|
||||
@@ -33,16 +30,13 @@ def sinusoidal_embedding_1d(dim, position):
|
||||
def pad_freqs(original_tensor, target_len):
|
||||
seq_len, s1, s2 = original_tensor.shape
|
||||
pad_size = target_len - seq_len
|
||||
original_tensor_device = original_tensor.device
|
||||
if original_tensor.device == "npu":
|
||||
original_tensor = original_tensor.cpu()
|
||||
padding_tensor = torch.ones(
|
||||
pad_size,
|
||||
s1,
|
||||
s2,
|
||||
dtype=original_tensor.dtype,
|
||||
device=original_tensor.device)
|
||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0).to(device=original_tensor_device)
|
||||
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
||||
return padded_tensor
|
||||
|
||||
def rope_apply(x, freqs, num_heads):
|
||||
@@ -56,7 +50,7 @@ def rope_apply(x, freqs, num_heads):
|
||||
sp_rank = get_sequence_parallel_rank()
|
||||
freqs = pad_freqs(freqs, s_per_rank * sp_size)
|
||||
freqs_rank = freqs[(sp_rank * s_per_rank):((sp_rank + 1) * s_per_rank), :, :]
|
||||
freqs_rank = freqs_rank.to(torch.complex64) if freqs_rank.device.type == "npu" else freqs_rank
|
||||
|
||||
x_out = torch.view_as_real(x_out * freqs_rank).flatten(2)
|
||||
return x_out.to(x.dtype)
|
||||
|
||||
@@ -139,12 +133,7 @@ def usp_attn_forward(self, x, freqs):
|
||||
k = rearrange(k, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
v = rearrange(v, "b s (n d) -> b s n d", n=self.num_heads)
|
||||
|
||||
attn_type = AttnType.FA
|
||||
ring_impl_type = "basic"
|
||||
if IS_NPU_AVAILABLE:
|
||||
attn_type = AttnType.NPU
|
||||
ring_impl_type = "basic_npu"
|
||||
x = xFuserLongContextAttention(attn_type=attn_type, ring_impl_type=ring_impl_type)(
|
||||
x = xFuserLongContextAttention()(
|
||||
None,
|
||||
query=q,
|
||||
key=k,
|
||||
|
||||
@@ -2,15 +2,6 @@
|
||||
|
||||
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
|
||||
|
||||
## Model Lineage
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
@@ -59,20 +50,16 @@ image.save("image.jpg")
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - |
|
||||
| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) |
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md)
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
|
||||
## Model Inference
|
||||
|
||||
@@ -148,4 +135,4 @@ We have built a sample image dataset for your testing. You can download this dat
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
@@ -1,109 +0,0 @@
|
||||
# LTX-2
|
||||
|
||||
LTX-2 is a series of audio-video generation models developed by Lightricks.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) model and perform inference. VRAM management has been enabled, and the framework will automatically control model parameter loading based on remaining VRAM. It can run with a minimum of 8GB VRAM.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: \"I enjoy working with Diffsynth-Studio, it's a perfect framework.\""
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|Model ID|Additional Parameters|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded through `LTX2AudioVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models) for details.
|
||||
|
||||
Input parameters for `LTX2AudioVideoPipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the video.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 3.0.
|
||||
* `input_images`: List of input images for image-to-video generation.
|
||||
* `input_images_indexes`: Frame index list of input images in the video.
|
||||
* `input_images_strength`: Strength of input images, default value is 1.0.
|
||||
* `denoising_strength`: Denoising strength, range is 0~1, default value is 1.0.
|
||||
* `seed`: Random seed. Default is `None`, which means completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different results will be generated on different GPUs.
|
||||
* `height`: Video height, must be a multiple of 32 (single-stage) or 64 (two-stage).
|
||||
* `width`: Video width, must be a multiple of 32 (single-stage) or 64 (two-stage).
|
||||
* `num_frames`: Number of video frames, default value is 121, must be a multiple of 8 + 1.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 40.
|
||||
* `tiled`: Whether to enable VAE tiling inference, default is `True`. When set to `True`, it can significantly reduce VRAM usage during VAE encoding/decoding stages, with slight errors and minor inference time extension.
|
||||
* `tile_size_in_pixels`: Pixel tiling size during VAE encoding/decoding stages, default is 512.
|
||||
* `tile_overlap_in_pixels`: Pixel tiling overlap size during VAE encoding/decoding stages, default is 128.
|
||||
* `tile_size_in_frames`: Frame tiling size during VAE encoding/decoding stages, default is 128.
|
||||
* `tile_overlap_in_frames`: Frame tiling overlap size during VAE encoding/decoding stages, default is 24.
|
||||
* `use_two_stage_pipeline`: Whether to use two-stage pipeline, default is `False`.
|
||||
* `use_distilled_pipeline`: Whether to use distilled pipeline, default is `False`.
|
||||
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be set to `lambda x:x` to hide the progress bar.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the previous "Supported Inference Scripts" section.
|
||||
|
||||
## Model Training
|
||||
|
||||
The LTX-2 series models currently do not support training functionality. We will add related support as soon as possible.
|
||||
@@ -85,9 +85,7 @@ graph LR;
|
||||
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
||||
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
|
||||
|
||||
@@ -50,14 +50,9 @@ image.save("image.jpg")
|
||||
|
||||
## Model Overview
|
||||
|
||||
|Model ID|Inference|Low VRAM Inference|Full Training|Validation After Full Training|LoRA Training|Validation After LoRA Training|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) |
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
@@ -80,9 +75,6 @@ Input parameters for `ZImagePipeline` inference include:
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 8.
|
||||
* `controlnet_inputs`: Inputs for ControlNet models.
|
||||
* `edit_image`: Edit images for image editing models, supporting multiple images.
|
||||
* `positive_only_lora`: LoRA weights used only in positive prompts.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ All sample code provided by this project supports NVIDIA GPUs by default, requir
|
||||
AMD provides PyTorch packages based on ROCm, so most models can run without code changes. A small number of models may not be compatible due to their reliance on CUDA-specific instructions.
|
||||
|
||||
## Ascend NPU
|
||||
### Inference
|
||||
|
||||
When using Ascend NPU, you need to replace `"cuda"` with `"npu"` in your code.
|
||||
|
||||
For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for Ascend NPU:
|
||||
@@ -22,7 +22,6 @@ For example, here is the inference code for **Wan2.1-T2V-1.3B**, modified for As
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.core.device.npu_compatible_device import get_device_name
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
@@ -47,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
@@ -57,36 +56,3 @@ video = pipe(
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
#### USP(Unified Sequence Parallel)
|
||||
If you want to use this feature on NPU, please install additional third-party libraries as follows:
|
||||
```shell
|
||||
pip install git+https://github.com/feifeibear/long-context-attention.git
|
||||
pip install git+https://github.com/xdit-project/xDiT.git
|
||||
```
|
||||
|
||||
|
||||
### Training
|
||||
NPU startup script samples have been added for each type of model,the scripts are stored in the `examples/xxx/special/npu_training`, for example `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`.
|
||||
|
||||
In the NPU training scripts, NPU specific environment variables that can optimize performance have been added, and relevant parameters have been enabled for specific models.
|
||||
|
||||
#### Environment variables
|
||||
```shell
|
||||
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||
```
|
||||
`expandable_segments:<value>`: Enable the memory pool expansion segment function, which is the virtual memory feature.
|
||||
|
||||
```shell
|
||||
export CPU_AFFINITY_CONF=1
|
||||
```
|
||||
Set 0 or not set: indicates not enabling the binding function
|
||||
|
||||
1: Indicates enabling coarse-grained kernel binding
|
||||
|
||||
2: Indicates enabling fine-grained kernel binding
|
||||
|
||||
#### Parameters for specific models
|
||||
| Model | Parameter | Note |
|
||||
|----------------|---------------------------|-------------------|
|
||||
| Wan 14B series | --initialize_model_on_cpu | The 14B model needs to be initialized on the CPU |
|
||||
@@ -30,16 +30,11 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
|
||||
|
||||
* **Ascend NPU**
|
||||
|
||||
1. Install [CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit) through official documentation.
|
||||
Ascend NPU support is provided via the `torch-npu` package. Taking version `2.1.0.post17` (as of the article update date: December 15, 2025) as an example, run the following command:
|
||||
|
||||
2. Install from source
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
# aarch64/ARM
|
||||
pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
|
||||
# x86
|
||||
pip install -e .[npu]
|
||||
```shell
|
||||
pip install torch-npu==2.1.0.post17
|
||||
```
|
||||
|
||||
When using Ascend NPU, please replace `"cuda"` with `"npu"` in your Python code. For details, see [NPU Support](/docs/en/Pipeline_Usage/GPU_support.md#ascend-npu).
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ This section introduces the independent core module `diffsynth.core` in `DiffSyn
|
||||
|
||||
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
|
||||
|
||||
* [Training models from scratch](/docs/en/Research_Tutorial/train_from_scratch.md)
|
||||
* Training models from scratch 【coming soon】
|
||||
* Inference improvement techniques 【coming soon】
|
||||
* Designing controllable generation models 【coming soon】
|
||||
* Creating new training paradigms 【coming soon】
|
||||
|
||||
@@ -1,476 +0,0 @@
|
||||
# Training Models from Scratch
|
||||
|
||||
DiffSynth-Studio's training engine supports training foundation models from scratch. This article introduces how to train a small text-to-image model with only 0.1B parameters from scratch.
|
||||
|
||||
## 1. Building Model Architecture
|
||||
|
||||
### 1.1 Diffusion Model
|
||||
|
||||
From UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) to DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206), the mainstream model architectures of Diffusion have undergone multiple evolutions. Typically, a Diffusion model's inputs include:
|
||||
|
||||
* Image tensor (`latents`): The encoding of images, generated by the VAE model, containing partial noise
|
||||
* Text tensor (`prompt_embeds`): The encoding of text, generated by the text encoder
|
||||
* Timestep (`timestep`): A scalar used to mark which stage of the Diffusion process we are currently at
|
||||
|
||||
The model's output is a tensor with the same shape as the image tensor, representing the denoising direction predicted by the model. For details about Diffusion model theory, please refer to [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md). In this article, we build a DiT model with only 0.1B parameters: `AAADiT`.
|
||||
|
||||
<details>
|
||||
<summary>Model Architecture Code</summary>
|
||||
|
||||
```python
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class AAAPositionalEmbedding(torch.nn.Module):
|
||||
def __init__(self, height=16, width=16, dim=1024):
|
||||
super().__init__()
|
||||
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||
|
||||
def forward(self, image, text):
|
||||
height, width = image.shape[-2:]
|
||||
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class AAABlock(torch.nn.Module):
|
||||
def __init__(self, dim=1024, num_heads=32):
|
||||
super().__init__()
|
||||
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.to_q = torch.nn.Linear(dim, dim)
|
||||
self.to_k = torch.nn.Linear(dim, dim)
|
||||
self.to_v = torch.nn.Linear(dim, dim)
|
||||
self.to_out = torch.nn.Linear(dim, dim)
|
||||
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*3),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(dim*3, dim),
|
||||
)
|
||||
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def attention(self, emb, pos_emb):
|
||||
emb = self.norm_attn(emb + pos_emb)
|
||||
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||
emb = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||
dims={"n": self.num_heads},
|
||||
)
|
||||
emb = self.to_out(emb)
|
||||
return emb
|
||||
|
||||
def feed_forward(self, emb, pos_emb):
|
||||
emb = self.norm_mlp(emb + pos_emb)
|
||||
emb = self.ff(emb)
|
||||
return emb
|
||||
|
||||
def forward(self, emb, pos_emb, t_emb):
|
||||
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||
return emb
|
||||
|
||||
|
||||
class AAADiT(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||
self.proj_out = torch.nn.Linear(dim, 128)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||
text = self.text_embedder(prompt_embeds)
|
||||
emb = torch.concat([image, text], dim=1)
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
emb=emb,
|
||||
pos_emb=pos_emb,
|
||||
t_emb=t_emb,
|
||||
)
|
||||
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||
emb = self.proj_out(emb)
|
||||
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||
return emb
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 1.2 Encoder-Decoder Models
|
||||
|
||||
Besides the Diffusion model used for denoising, we also need two other models:
|
||||
|
||||
* Text Encoder: Used to encode text into tensors. We adopt the [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) model.
|
||||
* VAE Encoder-Decoder: The encoder part is used to encode images into tensors, and the decoder part is used to decode image tensors into images. We adopt the VAE model from [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B).
|
||||
|
||||
The architectures of these two models are already integrated in DiffSynth-Studio, located at [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) and [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py), so we don't need to modify any code.
|
||||
|
||||
## 2. Building Pipeline
|
||||
|
||||
We introduced how to build a model Pipeline in the document [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md). For the model in this article, we also need to build a Pipeline to connect the text encoder, Diffusion model, and VAE encoder-decoder.
|
||||
|
||||
<details>
|
||||
<summary>Pipeline Code</summary>
|
||||
|
||||
```python
|
||||
class AAAImagePipeline(BasePipeline):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AAADiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AAAUnit_PromptEmbedder(),
|
||||
AAAUnit_NoiseInitializer(),
|
||||
AAAUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_aaa
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
self.hidden_states_layers = (-1,)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
text = pipe.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_aaa(
|
||||
dit: AAADiT,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
model_output = dit(
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return model_output
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 3. Preparing Dataset
|
||||
|
||||
To quickly verify training effectiveness, we use the dataset [Pokemon-First Generation](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1), which is reproduced from the open-source project [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh), containing 151 first-generation Pokemon from Bulbasaur to Mew. If you want to use other datasets, please refer to the document [Preparing Datasets](/docs/en/Pipeline_Usage/Model_Training.md#preparing-datasets) and [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
|
||||
```
|
||||
|
||||
### 4. Start Training
|
||||
|
||||
The training process can be quickly implemented using Pipeline. We have placed the complete code at [/docs/en/Research_Tutorial/train_from_scratch.py](/docs/en/Research_Tutorial/train_from_scratch.py), which can be directly started with `python docs/en/Research_Tutorial/train_from_scratch.py` for single GPU training.
|
||||
|
||||
To enable multi-GPU parallel training, please run `accelerate config` to set relevant parameters, then use the command `accelerate launch docs/en/Research_Tutorial/train_from_scratch.py` to start training.
|
||||
|
||||
This training script has no stopping condition, please manually close it when needed. The model converges after training approximately 60,000 steps, requiring 10-20 hours for single GPU training.
|
||||
|
||||
<details>
|
||||
<summary>Training Code</summary>
|
||||
|
||||
```python
|
||||
class AAATrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
)
|
||||
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||
self.pipe.freeze_except(["dit"])
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
def forward(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"cfg_scale": 1,
|
||||
"use_gradient_checkpointing": False,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/images",
|
||||
metadata_path="data/metadata_merged.csv",
|
||||
max_data_items=10000000,
|
||||
data_file_keys=("image",),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||
)
|
||||
model = AAATrainingModule(device=accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
"models/AAA/v1",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=2e-4,
|
||||
num_workers=4,
|
||||
save_steps=50000,
|
||||
num_epochs=999999,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 5. Verifying Training Results
|
||||
|
||||
If you don't want to wait for the model training to complete, you can directly download [our pre-trained model](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel).
|
||||
|
||||
```shell
|
||||
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
|
||||
```
|
||||
|
||||
Loading the model
|
||||
|
||||
```python
|
||||
from diffsynth import load_model
|
||||
|
||||
pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
)
|
||||
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
|
||||
```
|
||||
|
||||
Model inference, generating the first-generation Pokemon "starter trio". At this point, the images generated by the model basically match the training data.
|
||||
|
||||
```python
|
||||
for seed, prompt in enumerate([
|
||||
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
|
||||
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
|
||||
"blue, beige, brown, turtle, water type, shell, big eyes, short limbs, curled tail",
|
||||
]):
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=" ",
|
||||
num_inference_steps=30,
|
||||
cfg_scale=10,
|
||||
seed=seed,
|
||||
height=256, width=256,
|
||||
)
|
||||
image.save(f"image_{seed}.jpg")
|
||||
```
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
Model inference, generating Pokemon with "sharp claws". At this point, different random seeds can produce different image results.
|
||||
|
||||
```python
|
||||
for seed, prompt in enumerate([
|
||||
"sharp claws",
|
||||
"sharp claws",
|
||||
"sharp claws",
|
||||
]):
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=" ",
|
||||
num_inference_steps=30,
|
||||
cfg_scale=10,
|
||||
seed=seed+4,
|
||||
height=256, width=256,
|
||||
)
|
||||
image.save(f"image_sharp_claws_{seed}.jpg")
|
||||
```
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
Now, we have obtained a 0.1B small text-to-image model. This model can already generate 151 Pokemon, but cannot generate other image content. If you increase the amount of data, model parameters, and number of GPUs based on this, you can train a more powerful text-to-image model!
|
||||
@@ -1,341 +0,0 @@
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class AAAPositionalEmbedding(torch.nn.Module):
|
||||
def __init__(self, height=16, width=16, dim=1024):
|
||||
super().__init__()
|
||||
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||
|
||||
def forward(self, image, text):
|
||||
height, width = image.shape[-2:]
|
||||
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class AAABlock(torch.nn.Module):
|
||||
def __init__(self, dim=1024, num_heads=32):
|
||||
super().__init__()
|
||||
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.to_q = torch.nn.Linear(dim, dim)
|
||||
self.to_k = torch.nn.Linear(dim, dim)
|
||||
self.to_v = torch.nn.Linear(dim, dim)
|
||||
self.to_out = torch.nn.Linear(dim, dim)
|
||||
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*3),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(dim*3, dim),
|
||||
)
|
||||
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def attention(self, emb, pos_emb):
|
||||
emb = self.norm_attn(emb + pos_emb)
|
||||
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||
emb = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||
dims={"n": self.num_heads},
|
||||
)
|
||||
emb = self.to_out(emb)
|
||||
return emb
|
||||
|
||||
def feed_forward(self, emb, pos_emb):
|
||||
emb = self.norm_mlp(emb + pos_emb)
|
||||
emb = self.ff(emb)
|
||||
return emb
|
||||
|
||||
def forward(self, emb, pos_emb, t_emb):
|
||||
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||
return emb
|
||||
|
||||
|
||||
class AAADiT(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||
self.proj_out = torch.nn.Linear(dim, 128)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||
text = self.text_embedder(prompt_embeds)
|
||||
emb = torch.concat([image, text], dim=1)
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
emb=emb,
|
||||
pos_emb=pos_emb,
|
||||
t_emb=t_emb,
|
||||
)
|
||||
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||
emb = self.proj_out(emb)
|
||||
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||
return emb
|
||||
|
||||
|
||||
class AAAImagePipeline(BasePipeline):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AAADiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AAAUnit_PromptEmbedder(),
|
||||
AAAUnit_NoiseInitializer(),
|
||||
AAAUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_aaa
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
self.hidden_states_layers = (-1,)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
text = pipe.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_aaa(
|
||||
dit: AAADiT,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
model_output = dit(
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return model_output
|
||||
|
||||
|
||||
class AAATrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
)
|
||||
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||
self.pipe.freeze_except(["dit"])
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
def forward(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"cfg_scale": 1,
|
||||
"use_gradient_checkpointing": False,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/images",
|
||||
metadata_path="data/metadata_merged.csv",
|
||||
max_data_items=10000000,
|
||||
data_file_keys=("image",),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||
)
|
||||
model = AAATrainingModule(device=accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
"models/AAA/v1",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=2e-4,
|
||||
num_workers=4,
|
||||
save_steps=50000,
|
||||
num_epochs=999999,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ This document introduces the basic principles of Diffusion models to help you un
|
||||
|
||||
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
|
||||
|
||||

|
||||
(Figure)
|
||||
|
||||
This process is intuitive, but to understand the details, we need to answer several questions:
|
||||
|
||||
@@ -28,7 +28,7 @@ As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma
|
||||
|
||||
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
|
||||
|
||||

|
||||
(Figure)
|
||||
|
||||
## How is the iterative denoising computation performed?
|
||||
|
||||
@@ -40,6 +40,8 @@ Before understanding the iterative denoising computation, we need to clarify wha
|
||||
|
||||
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
|
||||
|
||||
(Figure)
|
||||
|
||||
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
|
||||
|
||||
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
|
||||
@@ -89,6 +91,8 @@ After understanding the iterative denoising process, we next consider how to tra
|
||||
|
||||
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
|
||||
|
||||
(Figure)
|
||||
|
||||
The following is pseudocode for the training process:
|
||||
|
||||
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
|
||||
@@ -109,7 +113,7 @@ The following is pseudocode for the training process:
|
||||
|
||||
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
|
||||
|
||||

|
||||
(Figure)
|
||||
|
||||
### Data Encoder-Decoder
|
||||
|
||||
|
||||
@@ -2,15 +2,6 @@
|
||||
|
||||
FLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。
|
||||
|
||||
## 模型血缘
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-dev;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-4B;
|
||||
FLUX.2-Series-->black-forest-labs/FLUX.2-klein-9B;
|
||||
```
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
@@ -59,20 +50,16 @@ image.save("image.jpg")
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|-|-|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-9B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-4B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-4B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-4B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-4B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-4B.py)|
|
||||
|[black-forest-labs/FLUX.2-klein-base-9B](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-klein-base-9B)|[code](/examples/flux2/model_inference/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/full/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_full/FLUX.2-klein-base-9B.py)|[code](/examples/flux2/model_training/lora/FLUX.2-klein-base-9B.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-klein-base-9B.py)|
|
||||
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
|
||||
## 模型推理
|
||||
|
||||
@@ -148,4 +135,4 @@ FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/exam
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
@@ -1,109 +0,0 @@
|
||||
# LTX-2
|
||||
|
||||
LTX-2 是由 Lightricks 开发的音视频生成模型系列。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [Lightricks/LTX-2](https://www.modelscope.cn/models/Lightricks/LTX-2) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8GB 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.ltx2_audio_video import LTX2AudioVideoPipeline, ModelConfig
|
||||
from diffsynth.utils.data.media_io_ltx2 import write_video_audio_ltx2
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e5m2,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e5m2,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e5m2,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = LTX2AudioVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized", origin_file_pattern="model-*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Lightricks/LTX-2", origin_file_pattern="ltx-2-19b-dev.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="google/gemma-3-12b-it-qat-q4_0-unquantized"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "A girl is very happy, she is speaking: “I enjoy working with Diffsynth-Studio, it's a perfect framework.”"
|
||||
negative_prompt = "blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, off-sync audio, incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
|
||||
height, width, num_frames = 512, 768, 121
|
||||
video, audio = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=43,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
tiled=True,
|
||||
)
|
||||
write_video_audio_ltx2(
|
||||
video=video,
|
||||
audio=audio,
|
||||
output_path='ltx2_onestage.mp4',
|
||||
fps=24,
|
||||
audio_sample_rate=24000,
|
||||
)
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-T2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)||[code](/examples/ltx2/model_inference/LTX-2-T2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-T2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: OneStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-OneStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-OneStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: TwoStagePipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-TwoStage.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-TwoStage.py)|-|-|-|-|
|
||||
|[Lightricks/LTX-2: DistilledPipeline-I2AV](https://www.modelscope.cn/models/Lightricks/LTX-2)|`input_images`|[code](/examples/ltx2/model_inference/LTX-2-I2AV-DistilledPipeline.py)|[code](/examples/ltx2/model_inference_low_vram/LTX-2-I2AV-DistilledPipeline.py)|-|-|-|-|
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `LTX2AudioVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`LTX2AudioVideoPipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述视频中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述视频中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 3.0。
|
||||
* `input_images`: 输入图像列表,用于图生视频。
|
||||
* `input_images_indexes`: 输入图像在视频中的帧索引列表。
|
||||
* `input_images_strength`: 输入图像的强度,默认值为 1.0。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1.0。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `height`: 视频高度,需保证高度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。
|
||||
* `width`: 视频宽度,需保证宽度为 32 的倍数(单阶段)或 64 的倍数(两阶段)。
|
||||
* `num_frames`: 视频帧数,默认值为 121,需保证为 8 的倍数 + 1。
|
||||
* `num_inference_steps`: 推理次数,默认值为 40。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size_in_pixels`: VAE 编解码阶段的像素分块大小,默认为 512。
|
||||
* `tile_overlap_in_pixels`: VAE 编解码阶段的像素分块重叠大小,默认为 128。
|
||||
* `tile_size_in_frames`: VAE 编解码阶段的帧分块大小,默认为 128。
|
||||
* `tile_overlap_in_frames`: VAE 编解码阶段的帧分块重叠大小,默认为 24。
|
||||
* `use_two_stage_pipeline`: 是否使用两阶段管道,默认为 `False`。
|
||||
* `use_distilled_pipeline`: 是否使用蒸馏管道,默认为 `False`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"支持的推理脚本"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
LTX-2 系列模型目前暂不支持训练功能。我们将尽快添加相关支持。
|
||||
@@ -85,9 +85,7 @@ graph LR;
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2511](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2511)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2511.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2511.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2511.py)|
|
||||
|[lightx2v/Qwen-Image-Edit-2511-Lightning](https://modelscope.cn/models/lightx2v/Qwen-Image-Edit-2511-Lightning)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2511-Lightning.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2511-Lightning.py)|-|-|-|-|
|
||||
|[Qwen/Qwen-Image-Layered](https://www.modelscope.cn/models/Qwen/Qwen-Image-Layered)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Layered-Control](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Layered-Control)|[code](/examples/qwen_image/model_inference/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Layered-Control.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Layered-Control.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Layered-Control.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
|
||||
@@ -52,12 +52,7 @@ image.save("image.jpg")
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image)|[code](/examples/z_image/model_inference/Z-Image.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image.py)|[code](/examples/z_image/model_training/full/Z-Image.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image.py)|[code](/examples/z_image/model_training/lora/Z-Image.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image.py)|
|
||||
|[DiffSynth-Studio/Z-Image-i2L](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-i2L)|[code](/examples/z_image/model_inference/Z-Image-i2L.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-i2L.py)|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Union-2.1-8steps.py)|
|
||||
|[PAI/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps](https://www.modelscope.cn/models/PAI/Z-Image-Turbo-Fun-Controlnet-Union-2.1)|[code](/examples/z_image/model_inference/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo-Fun-Controlnet-Tile-2.1-8steps.py)|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
@@ -80,9 +75,6 @@ image.save("image.jpg")
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理次数,默认值为 8。
|
||||
* `controlnet_inputs`: ControlNet 模型的输入。
|
||||
* `edit_image`: 编辑模型的待编辑图像,支持多张图像。
|
||||
* `positive_only_lora`: 仅在正向提示词中使用的 LoRA 权重。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
AMD 提供了基于 ROCm 的 torch 包,所以大多数模型无需修改代码即可运行,少数模型由于依赖特定的 cuda 指令无法运行。
|
||||
|
||||
## Ascend NPU
|
||||
### 推理
|
||||
|
||||
使用 Ascend NPU 时,需把代码中的 `"cuda"` 改为 `"npu"`。
|
||||
|
||||
例如,Wan2.1-T2V-1.3B 的推理代码:
|
||||
@@ -22,7 +22,6 @@ AMD 提供了基于 ROCm 的 torch 包,所以大多数模型无需修改代码
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
from diffsynth.core.device.npu_compatible_device import get_device_name
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
@@ -34,7 +33,7 @@ vram_config = {
|
||||
+ "preparing_device": "npu",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
- "computation_device": "cuda",
|
||||
+ "computation_device": "npu",
|
||||
+ "preparing_device": "npu",
|
||||
}
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
@@ -47,7 +46,7 @@ pipe = WanVideoPipeline.from_pretrained(
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
- vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info(get_device_name())[1] / (1024 ** 3) - 2,
|
||||
+ vram_limit=torch.npu.mem_get_info("npu:0")[1] / (1024 ** 3) - 2,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
@@ -57,35 +56,3 @@ video = pipe(
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
#### USP(Unified Sequence Parallel)
|
||||
如果想要在NPU上使用该特性,请通过如下方式安装额外的第三方库:
|
||||
```shell
|
||||
pip install git+https://github.com/feifeibear/long-context-attention.git
|
||||
pip install git+https://github.com/xdit-project/xDiT.git
|
||||
```
|
||||
|
||||
### 训练
|
||||
当前已为每类模型添加NPU的启动脚本样例,脚本存放在`examples/xxx/special/npu_training`目录下,例如 `examples/wanvideo/model_training/special/npu_training/Wan2.2-T2V-A14B-NPU.sh`。
|
||||
|
||||
在NPU训练脚本中,添加了可以优化性能的NPU特有环境变量,并针对特定模型开启了相关参数。
|
||||
|
||||
#### 环境变量
|
||||
```shell
|
||||
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||
```
|
||||
`expandable_segments:<value>`: 使能内存池扩展段功能,即虚拟内存特征。
|
||||
|
||||
```shell
|
||||
export CPU_AFFINITY_CONF=1
|
||||
```
|
||||
设置0或未设置: 表示不启用绑核功能
|
||||
|
||||
1: 表示开启粗粒度绑核
|
||||
|
||||
2: 表示开启细粒度绑核
|
||||
|
||||
#### 特定模型需要开启的参数
|
||||
| 模型 | 参数 | 备注 |
|
||||
|-----------|------|-------------------|
|
||||
| Wan 14B系列 | --initialize_model_on_cpu | 14B模型需要在cpu上进行初始化 |
|
||||
@@ -30,16 +30,11 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/rocm6
|
||||
|
||||
* Ascend NPU
|
||||
|
||||
1. 通过官方文档安装[CANN](https://www.hiascend.com/document/detail/zh/canncommercial/83RC1/softwareinst/instg/instg_quick.html?Mode=PmIns&InstallType=local&OS=openEuler&Software=cannToolKit)
|
||||
Ascend NPU 通过 `torch-npu` 包提供支持,以 `2.1.0.post17` 版本(本文更新于 2025 年 12 月 15 日)为例,请运行以下命令
|
||||
|
||||
2. 从源码安装
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
# aarch64/ARM
|
||||
pip install -e .[npu_aarch64] --extra-index-url "https://download.pytorch.org/whl/cpu"
|
||||
# x86
|
||||
pip install -e .[npu]
|
||||
```shell
|
||||
pip install torch-npu==2.1.0.post17
|
||||
```
|
||||
|
||||
使用 Ascend NPU 时,请将 Python 代码中的 `"cuda"` 改为 `"npu"`,详见[NPU 支持](/docs/zh/Pipeline_Usage/GPU_support.md#ascend-npu)。
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ graph LR;
|
||||
|
||||
本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。
|
||||
|
||||
* [从零开始训练模型](/docs/zh/Research_Tutorial/train_from_scratch.md)
|
||||
* 从零开始训练模型【coming soon】
|
||||
* 推理改进优化技术【coming soon】
|
||||
* 设计可控生成模型【coming soon】
|
||||
* 创建新的训练范式【coming soon】
|
||||
|
||||
@@ -1,477 +0,0 @@
|
||||
# 从零开始训练模型
|
||||
|
||||
DiffSynth-Studio 的训练引擎支持从零开始训练基础模型,本文介绍如何从零开始训练一个参数量仅为 0.1B 的小型文生图模型。
|
||||
|
||||
## 1. 构建模型结构
|
||||
|
||||
### 1.1 Diffusion 模型
|
||||
|
||||
从 UNet [[1]](https://arxiv.org/abs/1505.04597) [[2]](https://arxiv.org/abs/2112.10752) 到 DiT [[3]](https://arxiv.org/abs/2212.09748) [[4]](https://arxiv.org/abs/2403.03206),Diffusion 的主流模型结构经历了多次演变。通常,一个 Diffusion 模型的输入包括:
|
||||
|
||||
* 图像张量(`latents`):图像的编码,由 VAE 模型产生,含有部分噪声
|
||||
* 文本张量(`prompt_embeds`):文本的编码,由文本编码器产生
|
||||
* 时间步(`timestep`):标量,用于标记当前处于 Diffusion 过程的哪个阶段
|
||||
|
||||
模型的输出是与图像张量形状相同的张量,表示模型预测的去噪方向,关于 Diffusion 模型理论的细节,请参考 [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)。在本文中,我们构建一个仅含 0.1B 参数的 DiT 模型:`AAADiT`。
|
||||
|
||||
<details>
|
||||
<summary>模型结构代码</summary>
|
||||
|
||||
```python
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class AAAPositionalEmbedding(torch.nn.Module):
|
||||
def __init__(self, height=16, width=16, dim=1024):
|
||||
super().__init__()
|
||||
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||
|
||||
def forward(self, image, text):
|
||||
height, width = image.shape[-2:]
|
||||
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class AAABlock(torch.nn.Module):
|
||||
def __init__(self, dim=1024, num_heads=32):
|
||||
super().__init__()
|
||||
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.to_q = torch.nn.Linear(dim, dim)
|
||||
self.to_k = torch.nn.Linear(dim, dim)
|
||||
self.to_v = torch.nn.Linear(dim, dim)
|
||||
self.to_out = torch.nn.Linear(dim, dim)
|
||||
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*3),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(dim*3, dim),
|
||||
)
|
||||
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def attention(self, emb, pos_emb):
|
||||
emb = self.norm_attn(emb + pos_emb)
|
||||
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||
emb = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||
dims={"n": self.num_heads},
|
||||
)
|
||||
emb = self.to_out(emb)
|
||||
return emb
|
||||
|
||||
def feed_forward(self, emb, pos_emb):
|
||||
emb = self.norm_mlp(emb + pos_emb)
|
||||
emb = self.ff(emb)
|
||||
return emb
|
||||
|
||||
def forward(self, emb, pos_emb, t_emb):
|
||||
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||
return emb
|
||||
|
||||
|
||||
class AAADiT(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||
self.proj_out = torch.nn.Linear(dim, 128)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||
text = self.text_embedder(prompt_embeds)
|
||||
emb = torch.concat([image, text], dim=1)
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
emb=emb,
|
||||
pos_emb=pos_emb,
|
||||
t_emb=t_emb,
|
||||
)
|
||||
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||
emb = self.proj_out(emb)
|
||||
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||
return emb
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 1.2 编解码器模型
|
||||
|
||||
除了用于去噪的 Diffusion 模型以外,我们还需要另外两个模型:
|
||||
|
||||
* 文本编码器:用于将文本编码为张量。我们采用 [Qwen/Qwen3-0.6B](https://modelscope.cn/models/Qwen/Qwen3-0.6B) 模型。
|
||||
* VAE 编解码器:编码器部分用于将图像编码为张量,解码器部分用于将图像张量解码为图像。我们采用 [black-forest-labs/FLUX.2-klein-4B](https://modelscope.cn/models/black-forest-labs/FLUX.2-klein-4B) 中的 VAE 模型。
|
||||
|
||||
这两个模型的结构都已集成在 DiffSynth-Studio 中,分别位于 [/diffsynth/models/z_image_text_encoder.py](/diffsynth/models/z_image_text_encoder.py) 和 [/diffsynth/models/flux2_vae.py](/diffsynth/models/flux2_vae.py),因此我们不需要修改任何代码。
|
||||
|
||||
## 2. 构建 Pipeline
|
||||
|
||||
我们在文档 [接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) 中介绍了如何构建一个模型 Pipeline,对于本文中的模型,我们也需要构建一个 Pipeline,连接文本编码器、Diffusion 模型、VAE 编解码器。
|
||||
|
||||
<details>
|
||||
<summary>Pipeline 代码</summary>
|
||||
|
||||
```python
|
||||
class AAAImagePipeline(BasePipeline):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AAADiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AAAUnit_PromptEmbedder(),
|
||||
AAAUnit_NoiseInitializer(),
|
||||
AAAUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_aaa
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
self.hidden_states_layers = (-1,)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
text = pipe.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_aaa(
|
||||
dit: AAADiT,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
model_output = dit(
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return model_output
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 3. 准备数据集
|
||||
|
||||
为了快速验证训练效果,我们使用数据集 [宝可梦-第一世代](https://modelscope.cn/datasets/DiffSynth-Studio/pokemon-gen1),这个数据集转载自开源项目 [pokemon-dataset-zh](https://github.com/42arch/pokemon-dataset-zh),包含从妙蛙种子到梦幻的 151 个第一世代宝可梦。如果你想使用其他数据集,请参考文档 [准备数据集](/docs/zh/Pipeline_Usage/Model_Training.md#准备数据集) 和 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)。
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/pokemon-gen1 --local_dir ./data
|
||||
```
|
||||
|
||||
### 4. 开始训练
|
||||
|
||||
训练过程可使用 Pipeline 快速实现,我们已将完整的代码放在 [/docs/zh/Research_Tutorial/train_from_scratch.py](/docs/zh/Research_Tutorial/train_from_scratch.py),可直接通过 `python docs/zh/Research_Tutorial/train_from_scratch.py` 开始单 GPU 训练。
|
||||
|
||||
如需开启多 GPU 并行训练,请运行 `accelerate config` 设置相关参数,然后使用命令 `accelerate launch docs/zh/Research_Tutorial/train_from_scratch.py` 开始训练。
|
||||
|
||||
这个训练脚本没有设置停止条件,请在需要时手动关闭。模型在训练大约 6 万步后收敛,单 GPU 训练需要 10~20 小时。
|
||||
|
||||
|
||||
<details>
|
||||
<summary>训练代码</summary>
|
||||
|
||||
```python
|
||||
class AAATrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
)
|
||||
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||
self.pipe.freeze_except(["dit"])
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
def forward(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"cfg_scale": 1,
|
||||
"use_gradient_checkpointing": False,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/images",
|
||||
metadata_path="data/metadata_merged.csv",
|
||||
max_data_items=10000000,
|
||||
data_file_keys=("image",),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||
)
|
||||
model = AAATrainingModule(device=accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
"models/AAA/v1",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=2e-4,
|
||||
num_workers=4,
|
||||
save_steps=50000,
|
||||
num_epochs=999999,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## 5. 验证训练效果
|
||||
|
||||
如果你不想等待模型训练完成,可以直接下载[我们预先训练好的模型](https://modelscope.cn/models/DiffSynth-Studio/AAAMyModel)。
|
||||
|
||||
```shell
|
||||
modelscope download --model DiffSynth-Studio/AAAMyModel step-600000.safetensors --local_dir models/DiffSynth-Studio/AAAMyModel
|
||||
```
|
||||
|
||||
加载模型
|
||||
|
||||
```python
|
||||
from diffsynth import load_model
|
||||
|
||||
pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
)
|
||||
pipe.dit = load_model(AAADiT, "models/DiffSynth-Studio/AAAMyModel/step-600000.safetensors", torch_dtype=torch.bfloat16, device="cuda")
|
||||
```
|
||||
|
||||
模型推理,生成第一世代宝可梦“御三家”,此时模型生成的图像内容与训练数据基本一致。
|
||||
|
||||
```python
|
||||
for seed, prompt in enumerate([
|
||||
"green, lizard, plant, Grass, Poison, seed on back, red eyes, smiling expression, short stout limbs, sharp claws",
|
||||
"orange, cream, lizard, Fire, flame on tail tip, large eyes, smiling expression, cream-colored belly patch, sharp claws",
|
||||
"蓝色,米色,棕色,乌龟,水系,龟壳,大眼睛,短四肢,卷曲尾巴",
|
||||
]):
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=" ",
|
||||
num_inference_steps=30,
|
||||
cfg_scale=10,
|
||||
seed=seed,
|
||||
height=256, width=256,
|
||||
)
|
||||
image.save(f"image_{seed}.jpg")
|
||||
```
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
模型推理,生成具有“锐利爪子”的宝可梦,此时不同的随机种子能够产生不同的图像结果。
|
||||
|
||||
```python
|
||||
for seed, prompt in enumerate([
|
||||
"sharp claws",
|
||||
"sharp claws",
|
||||
"sharp claws",
|
||||
]):
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=" ",
|
||||
num_inference_steps=30,
|
||||
cfg_scale=10,
|
||||
seed=seed+4,
|
||||
height=256, width=256,
|
||||
)
|
||||
image.save(f"image_sharp_claws_{seed}.jpg")
|
||||
```
|
||||
|
||||
||||
|
||||
|-|-|-|
|
||||
|
||||
现在,我们获得了一个 0.1B 的小型文生图模型,这个模型已经能够生成 151 个宝可梦,但无法生成其他图像内容。如果在此基础上增加数据量、模型参数量、GPU 数量,你就可以训练出一个更强大的文生图模型!
|
||||
@@ -1,341 +0,0 @@
|
||||
import torch, accelerate
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from transformers import AutoProcessor, AutoTokenizer
|
||||
from diffsynth.core import ModelConfig, gradient_checkpoint_forward, attention_forward, UnifiedDataset, load_model
|
||||
from diffsynth.diffusion import FlowMatchScheduler, DiffusionTrainingModule, FlowMatchSFTLoss, ModelLogger, launch_training_task
|
||||
from diffsynth.diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from diffsynth.models.general_modules import TimestepEmbeddings
|
||||
from diffsynth.models.z_image_text_encoder import ZImageTextEncoder
|
||||
from diffsynth.models.flux2_vae import Flux2VAE
|
||||
|
||||
|
||||
class AAAPositionalEmbedding(torch.nn.Module):
|
||||
def __init__(self, height=16, width=16, dim=1024):
|
||||
super().__init__()
|
||||
self.image_emb = torch.nn.Parameter(torch.randn((1, dim, height, width)))
|
||||
self.text_emb = torch.nn.Parameter(torch.randn((dim,)))
|
||||
|
||||
def forward(self, image, text):
|
||||
height, width = image.shape[-2:]
|
||||
image_emb = self.image_emb.to(device=image.device, dtype=image.dtype)
|
||||
image_emb = torch.nn.functional.interpolate(image_emb, size=(height, width), mode="bilinear")
|
||||
image_emb = rearrange(image_emb, "B C H W -> B (H W) C")
|
||||
text_emb = self.text_emb.to(device=text.device, dtype=text.dtype)
|
||||
text_emb = repeat(text_emb, "C -> B L C", B=text.shape[0], L=text.shape[1])
|
||||
emb = torch.concat([image_emb, text_emb], dim=1)
|
||||
return emb
|
||||
|
||||
|
||||
class AAABlock(torch.nn.Module):
|
||||
def __init__(self, dim=1024, num_heads=32):
|
||||
super().__init__()
|
||||
self.norm_attn = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.to_q = torch.nn.Linear(dim, dim)
|
||||
self.to_k = torch.nn.Linear(dim, dim)
|
||||
self.to_v = torch.nn.Linear(dim, dim)
|
||||
self.to_out = torch.nn.Linear(dim, dim)
|
||||
self.norm_mlp = torch.nn.RMSNorm(dim, elementwise_affine=False)
|
||||
self.ff = torch.nn.Sequential(
|
||||
torch.nn.Linear(dim, dim*3),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(dim*3, dim),
|
||||
)
|
||||
self.to_gate = torch.nn.Linear(dim, dim * 2)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def attention(self, emb, pos_emb):
|
||||
emb = self.norm_attn(emb + pos_emb)
|
||||
q, k, v = self.to_q(emb), self.to_k(emb), self.to_v(emb)
|
||||
emb = attention_forward(
|
||||
q, k, v,
|
||||
q_pattern="b s (n d)", k_pattern="b s (n d)", v_pattern="b s (n d)", out_pattern="b s (n d)",
|
||||
dims={"n": self.num_heads},
|
||||
)
|
||||
emb = self.to_out(emb)
|
||||
return emb
|
||||
|
||||
def feed_forward(self, emb, pos_emb):
|
||||
emb = self.norm_mlp(emb + pos_emb)
|
||||
emb = self.ff(emb)
|
||||
return emb
|
||||
|
||||
def forward(self, emb, pos_emb, t_emb):
|
||||
gate_attn, gate_mlp = self.to_gate(t_emb).chunk(2, dim=-1)
|
||||
emb = emb + self.attention(emb, pos_emb) * (1 + gate_attn)
|
||||
emb = emb + self.feed_forward(emb, pos_emb) * (1 + gate_mlp)
|
||||
return emb
|
||||
|
||||
|
||||
class AAADiT(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.pos_embedder = AAAPositionalEmbedding(dim=dim)
|
||||
self.timestep_embedder = TimestepEmbeddings(256, dim)
|
||||
self.image_embedder = torch.nn.Sequential(torch.nn.Linear(128, dim), torch.nn.LayerNorm(dim))
|
||||
self.text_embedder = torch.nn.Sequential(torch.nn.Linear(1024, dim), torch.nn.LayerNorm(dim))
|
||||
self.blocks = torch.nn.ModuleList([AAABlock(dim) for _ in range(10)])
|
||||
self.proj_out = torch.nn.Linear(dim, 128)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
):
|
||||
pos_emb = self.pos_embedder(latents, prompt_embeds)
|
||||
t_emb = self.timestep_embedder(timestep, dtype=latents.dtype).view(1, 1, -1)
|
||||
image = self.image_embedder(rearrange(latents, "B C H W -> B (H W) C"))
|
||||
text = self.text_embedder(prompt_embeds)
|
||||
emb = torch.concat([image, text], dim=1)
|
||||
for block_id, block in enumerate(self.blocks):
|
||||
emb = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
emb=emb,
|
||||
pos_emb=pos_emb,
|
||||
t_emb=t_emb,
|
||||
)
|
||||
emb = emb[:, :latents.shape[-1] * latents.shape[-2]]
|
||||
emb = self.proj_out(emb)
|
||||
emb = rearrange(emb, "B (H W) C -> B C H W", W=latents.shape[-1])
|
||||
return emb
|
||||
|
||||
|
||||
class AAAImagePipeline(BasePipeline):
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: AAADiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
self.tokenizer: AutoProcessor = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
AAAUnit_PromptEmbedder(),
|
||||
AAAUnit_NoiseInitializer(),
|
||||
AAAUnit_InputImageEmbedder(),
|
||||
]
|
||||
self.model_fn = model_fn_aaa
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
tokenizer_config: ModelConfig = None,
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = AAAImagePipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("z_image_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("aaa_dit")
|
||||
pipe.vae = model_pool.fetch_model("flux2_vae")
|
||||
if tokenizer_config is not None:
|
||||
tokenizer_config.download_if_necessary()
|
||||
pipe.tokenizer = AutoTokenizer.from_pretrained(tokenizer_config.path)
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
# Prompt
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 1.0,
|
||||
# Image
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
# Shape
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
# Randomness
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
# Steps
|
||||
num_inference_steps: int = 30,
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {"prompt": prompt}
|
||||
inputs_nega = {"negative_prompt": negative_prompt}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image, "denoising_strength": denoising_strength,
|
||||
"height": height, "width": width,
|
||||
"seed": seed, "rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
noise_pred = self.cfg_guided_model_fn(
|
||||
self.model_fn, cfg_scale,
|
||||
inputs_shared, inputs_posi, inputs_nega,
|
||||
**models, timestep=timestep, progress_id=progress_id
|
||||
)
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"])
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class AAAUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
output_params=("prompt_embeds",),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
self.hidden_states_layers = (-1,)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
text = pipe.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False,
|
||||
)
|
||||
inputs = pipe.tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(pipe.device)
|
||||
output = pipe.text_encoder(**inputs, output_hidden_states=True, use_cache=False)
|
||||
prompt_embeds = torch.concat([output.hidden_states[k] for k in self.hidden_states_layers], dim=-1)
|
||||
return {"prompt_embeds": prompt_embeds}
|
||||
|
||||
|
||||
class AAAUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 128, height//16, width//16), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
|
||||
|
||||
class AAAUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: AAAImagePipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(image)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
|
||||
|
||||
def model_fn_aaa(
|
||||
dit: AAADiT,
|
||||
latents=None,
|
||||
prompt_embeds=None,
|
||||
timestep=None,
|
||||
use_gradient_checkpointing=False,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
**kwargs,
|
||||
):
|
||||
model_output = dit(
|
||||
latents,
|
||||
prompt_embeds,
|
||||
timestep,
|
||||
use_gradient_checkpointing=use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
|
||||
)
|
||||
return model_output
|
||||
|
||||
|
||||
class AAATrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
self.pipe = AAAImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen3-0.6B", origin_file_pattern="./"),
|
||||
)
|
||||
self.pipe.dit = AAADiT().to(dtype=torch.bfloat16, device=device)
|
||||
self.pipe.freeze_except(["dit"])
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
def forward(self, data):
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
"cfg_scale": 1,
|
||||
"use_gradient_checkpointing": False,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(gradient_accumulation_steps=1)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/images",
|
||||
metadata_path="data/metadata_merged.csv",
|
||||
max_data_items=10000000,
|
||||
data_file_keys=("image",),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(base_path="data/images", height=256, width=256)
|
||||
)
|
||||
model = AAATrainingModule(device=accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
"models/AAA/v1",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=2e-4,
|
||||
num_workers=4,
|
||||
save_steps=50000,
|
||||
num_epochs=999999,
|
||||
)
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。
|
||||
|
||||

|
||||
(图)
|
||||
|
||||
这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题:
|
||||
|
||||
@@ -28,7 +28,7 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像
|
||||
|
||||
那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。
|
||||
|
||||

|
||||
(图)
|
||||
|
||||
## 迭代去噪的计算是如何进行的?
|
||||
|
||||
@@ -40,6 +40,8 @@ Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像
|
||||
|
||||
其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。
|
||||
|
||||
(图)
|
||||
|
||||
而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。
|
||||
|
||||
接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$:
|
||||
@@ -87,6 +89,8 @@ $$
|
||||
|
||||
训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。
|
||||
|
||||
(图)
|
||||
|
||||
以下是训练过程的伪代码
|
||||
|
||||
> 从数据集获取数据样本 $x_0$ 和引导条件 $c$
|
||||
@@ -107,7 +111,7 @@ $$
|
||||
|
||||
从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。
|
||||
|
||||

|
||||
(图)
|
||||
|
||||
### 数据编解码器
|
||||
|
||||
|
||||
@@ -108,14 +108,7 @@ def test_flux():
|
||||
run_inference("examples/flux/model_training/validate_lora")
|
||||
|
||||
|
||||
def test_z_image():
|
||||
run_inference("examples/z_image/model_inference")
|
||||
run_inference("examples/z_image/model_inference_low_vram")
|
||||
run_train_multi_GPU("examples/z_image/model_training/full")
|
||||
run_inference("examples/z_image/model_training/validate_full")
|
||||
run_train_single_GPU("examples/z_image/model_training/lora")
|
||||
run_inference("examples/z_image/model_training/validate_lora")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_z_image()
|
||||
test_qwen_image()
|
||||
test_flux()
|
||||
test_wan()
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||
export CPU_AFFINITY_CONF=1
|
||||
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata_kontext.csv \
|
||||
--data_file_keys "image,kontext_images" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-Kontext-dev:flux1-kontext-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-Kontext-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--extra_inputs "kontext_images" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,15 +0,0 @@
|
||||
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
|
||||
export CPU_AFFINITY_CONF=1
|
||||
|
||||
accelerate launch --config_file examples/flux/model_training/full/accelerate_config_zero2offload.yaml examples/flux/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 400 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/*.safetensors,black-forest-labs/FLUX.1-dev:ae.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 1 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.1-dev_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
@@ -1,21 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-4B.jpg")
|
||||
@@ -1,21 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-9B.jpg")
|
||||
@@ -1,21 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-4B.jpg")
|
||||
@@ -1,21 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-9B.jpg")
|
||||
@@ -1,32 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-4B.jpg")
|
||||
@@ -1,32 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_FLUX.2-klein-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=4)
|
||||
image.save("image_edit_FLUX.2-klein-9B.jpg")
|
||||
@@ -1,32 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-4B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-4B.jpg")
|
||||
@@ -1,32 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Masterpiece, best quality. Anime-style portrait of a woman in a blue dress, underwater, surrounded by colorful bubbles."
|
||||
image = pipe(prompt, seed=0, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_FLUX.2-klein-base-9B.jpg")
|
||||
|
||||
prompt = "change the color of the clothes to red"
|
||||
image = pipe(prompt, edit_image=[image], seed=1, rand_device="cuda", num_inference_steps=50, cfg_scale=4)
|
||||
image.save("image_edit_FLUX.2-klein-base-9B.jpg")
|
||||
@@ -1,30 +0,0 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-4B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-4B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,31 +0,0 @@
|
||||
# This script is tested on 8*A100
|
||||
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-9B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-9B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,30 +0,0 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-4B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-4B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,31 +0,0 @@
|
||||
# This script is tested on 8*A100
|
||||
accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-9B_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch --config_file examples/flux2/model_training/full/accelerate_config.yaml examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-5 \
|
||||
# --num_epochs 2 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-9B_full" \
|
||||
# --trainable_models "dit" \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,22 +0,0 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -1,34 +0,0 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-4B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-4B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,34 +0,0 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-9B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-9B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,34 +0,0 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-4B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-4B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-4B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-4B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-4B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-4B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -1,34 +0,0 @@
|
||||
accelerate launch examples/flux2/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/FLUX.2-klein-base-9B_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing
|
||||
|
||||
# Edit
|
||||
# accelerate launch examples/flux2/model_training/train.py \
|
||||
# --dataset_base_path data/example_image_dataset \
|
||||
# --dataset_metadata_path data/example_image_dataset/metadata_qwen_imgae_edit_multi.json \
|
||||
# --data_file_keys "image,edit_image" \
|
||||
# --extra_inputs "edit_image" \
|
||||
# --max_pixels 1048576 \
|
||||
# --dataset_repeat 50 \
|
||||
# --model_id_with_origin_paths "black-forest-labs/FLUX.2-klein-9B:text_encoder/*.safetensors,black-forest-labs/FLUX.2-klein-base-9B:transformer/*.safetensors,black-forest-labs/FLUX.2-klein-9B:vae/diffusion_pytorch_model.safetensors" \
|
||||
# --tokenizer_path "black-forest-labs/FLUX.2-klein-9B:tokenizer/" \
|
||||
# --learning_rate 1e-4 \
|
||||
# --num_epochs 5 \
|
||||
# --remove_prefix_in_ckpt "pipe.dit." \
|
||||
# --output_path "./models/train/FLUX.2-klein-base-9B_lora" \
|
||||
# --lora_base_model "dit" \
|
||||
# --lora_target_modules "to_q,to_k,to_v,to_out.0,add_q_proj,add_k_proj,add_v_proj,to_add_out,linear_in,linear_out,to_qkv_mlp_proj,single_transformer_blocks.0.attn.to_out,single_transformer_blocks.1.attn.to_out,single_transformer_blocks.2.attn.to_out,single_transformer_blocks.3.attn.to_out,single_transformer_blocks.4.attn.to_out,single_transformer_blocks.5.attn.to_out,single_transformer_blocks.6.attn.to_out,single_transformer_blocks.7.attn.to_out,single_transformer_blocks.8.attn.to_out,single_transformer_blocks.9.attn.to_out,single_transformer_blocks.10.attn.to_out,single_transformer_blocks.11.attn.to_out,single_transformer_blocks.12.attn.to_out,single_transformer_blocks.13.attn.to_out,single_transformer_blocks.14.attn.to_out,single_transformer_blocks.15.attn.to_out,single_transformer_blocks.16.attn.to_out,single_transformer_blocks.17.attn.to_out,single_transformer_blocks.18.attn.to_out,single_transformer_blocks.19.attn.to_out,single_transformer_blocks.20.attn.to_out,single_transformer_blocks.21.attn.to_out,single_transformer_blocks.22.attn.to_out,single_transformer_blocks.23.attn.to_out" \
|
||||
# --lora_rank 32 \
|
||||
# --use_gradient_checkpointing
|
||||
@@ -24,7 +24,7 @@ class Flux2ImageTrainingModule(DiffusionTrainingModule):
|
||||
super().__init__()
|
||||
# Load models
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"))
|
||||
tokenizer_config = ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/") if tokenizer_path is None else ModelConfig(tokenizer_path)
|
||||
self.pipe = Flux2ImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -1,20 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -1,20 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-base-4B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -1,20 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
from diffsynth.core import load_state_dict
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-base-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
state_dict = load_state_dict("./models/train/FLUX.2-klein-base-9B_full/epoch-1.safetensors", torch_dtype=torch.bfloat16)
|
||||
pipe.dit.load_state_dict(state_dict)
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -1,18 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-4B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-4B_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
@@ -1,18 +0,0 @@
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="text_encoder/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="transformer/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-klein-9B", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.dit, "./models/train/FLUX.2-klein-9B_lora/epoch-4.safetensors")
|
||||
prompt = "a dog"
|
||||
image = pipe(prompt=prompt, seed=0, num_inference_steps=40, cfg_scale=4, height=768, width=768)
|
||||
image.save("image.jpg")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user