mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:48:12 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
79
docs/en/API_Reference/core/attention.md
Normal file
79
docs/en/API_Reference/core/attention.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# `diffsynth.core.attention`: Attention Mechanism Implementation
|
||||
|
||||
`diffsynth.core.attention` provides routing mechanisms for attention mechanism implementations, automatically selecting efficient attention implementations based on available packages in the `Python` environment and [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
|
||||
|
||||
## Attention Mechanism
|
||||
|
||||
The attention mechanism is a model structure proposed in the paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762). In the original paper, the attention mechanism is implemented according to the following formula:
|
||||
|
||||
$$
|
||||
\text{Attention}(Q, K, V) = \text{Softmax}\left(
|
||||
\frac{QK^T}{\sqrt{d_k}}
|
||||
\right)
|
||||
V.
|
||||
$$
|
||||
|
||||
In `PyTorch`, it can be implemented with the following code:
|
||||
```python
|
||||
import torch
|
||||
|
||||
def attention(query, key, value):
|
||||
scale_factor = 1 / query.size(-1)**0.5
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
return attn_weight @ value
|
||||
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
output_1 = attention(query, key, value)
|
||||
```
|
||||
|
||||
The dimensions of `query`, `key`, and `value` are $(b, n, s, d)$:
|
||||
* $b$: Batch size
|
||||
* $n$: Number of attention heads
|
||||
* $s$: Sequence length
|
||||
* $d$: Dimension of each attention head
|
||||
|
||||
This computation does not include any trainable parameters. Modern transformer architectures will pass through Linear layers before and after this computation, but the "attention mechanism" discussed in this article refers only to the computation in the above code, not including these calculations.
|
||||
|
||||
## More Efficient Implementations
|
||||
|
||||
Note that the dimension of the Attention Score in the attention mechanism ( $\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$ in the formula, `attn_weight` in the code) is $(b, n, s, s)$, where the sequence length $s$ is typically very large, causing the time and space complexity of computation to reach quadratic level. Taking image generation models as an example, when the width and height of the image increase to 2 times, the sequence length increases to 4 times, and the computational load and memory requirements increase to 16 times. To avoid high computational costs, more efficient attention mechanism implementations are needed, including:
|
||||
* Flash Attention 3: [GitHub](https://github.com/Dao-AILab/flash-attention), [Paper](https://arxiv.org/abs/2407.08608)
|
||||
* Flash Attention 2: [GitHub](https://github.com/Dao-AILab/flash-attention), [Paper](https://arxiv.org/abs/2307.08691)
|
||||
* Sage Attention: [GitHub](https://github.com/thu-ml/SageAttention), [Paper](https://arxiv.org/abs/2505.11594)
|
||||
* xFormers: [GitHub](https://github.com/facebookresearch/xformers), [Documentation](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)
|
||||
* PyTorch: [GitHub](https://github.com/pytorch/pytorch), [Documentation](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
To call attention implementations other than `PyTorch`, please follow the instructions on their GitHub pages to install the corresponding packages. `DiffSynth-Studio` will automatically route to the corresponding implementation based on available packages in the Python environment, or can be controlled through [environment variables](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation).
|
||||
|
||||
```python
|
||||
from diffsynth.core.attention import attention_forward
|
||||
import torch
|
||||
|
||||
def attention(query, key, value):
|
||||
scale_factor = 1 / query.size(-1)**0.5
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
return attn_weight @ value
|
||||
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
output_1 = attention(query, key, value)
|
||||
output_2 = attention_forward(query, key, value)
|
||||
print((output_1 - output_2).abs().mean())
|
||||
```
|
||||
|
||||
Please note that acceleration will introduce errors, but in most cases, the error is negligible.
|
||||
|
||||
## Developer Guide
|
||||
|
||||
When integrating new models into `DiffSynth-Studio`, developers can decide whether to call `attention_forward` in `diffsynth.core.attention`, but we expect models to prioritize calling this module as much as possible, so that new attention mechanism implementations can take effect directly on these models.
|
||||
|
||||
## Best Practices
|
||||
|
||||
**In most cases, we recommend directly using the native `PyTorch` implementation without installing any additional packages.** Although other attention mechanism implementations can accelerate, the acceleration effect is relatively limited, and in a few cases, compatibility and precision issues may arise.
|
||||
|
||||
In addition, efficient attention mechanism implementations will gradually be integrated into `PyTorch`. The `scaled_dot_product_attention` in `PyTorch` version 2.9.0 has already integrated Flash Attention 2. We still provide this interface in `DiffSynth-Studio` to allow some aggressive acceleration schemes to quickly move toward application, even though they still need time to be verified for stability.
|
||||
151
docs/en/API_Reference/core/data.md
Normal file
151
docs/en/API_Reference/core/data.md
Normal file
@@ -0,0 +1,151 @@
|
||||
# `diffsynth.core.data`: Data Processing Operators and Universal Dataset
|
||||
|
||||
## Data Processing Operators
|
||||
|
||||
### Available Data Processing Operators
|
||||
|
||||
`diffsynth.core.data` provides a series of data processing operators for data processing, including:
|
||||
|
||||
* Data format conversion operators
|
||||
* `ToInt`: Convert to int format
|
||||
* `ToFloat`: Convert to float format
|
||||
* `ToStr`: Convert to str format
|
||||
* `ToList`: Convert to list format, wrapping this data in a list
|
||||
* `ToAbsolutePath`: Convert relative paths to absolute paths
|
||||
* File loading operators
|
||||
* `LoadImage`: Read image files
|
||||
* `LoadVideo`: Read video files
|
||||
* `LoadAudio`: Read audio files
|
||||
* `LoadGIF`: Read GIF files
|
||||
* `LoadTorchPickle`: Read binary files saved by [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) [This operator may cause code injection attacks in binary files, please use with caution!]
|
||||
* Media file processing operators
|
||||
* `ImageCropAndResize`: Crop and resize images
|
||||
* Meta operators
|
||||
* `SequencialProcess`: Route each data in the sequence to an operator
|
||||
* `RouteByExtensionName`: Route to specific operators by file extension
|
||||
* `RouteByType`: Route to specific operators by data type
|
||||
|
||||
### Operator Usage
|
||||
|
||||
Data operators are connected with the `>>` symbol to form data processing pipelines, for example:
|
||||
|
||||
```python
|
||||
from diffsynth.core.data.operators import *
|
||||
|
||||
data = "image.jpg"
|
||||
data_pipeline = ToAbsolutePath(base_path="/data") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512)
|
||||
data = data_pipeline(data)
|
||||
```
|
||||
|
||||
After passing through each operator, the data is processed in sequence:
|
||||
|
||||
* `ToAbsolutePath(base_path="/data")`: `"/data/image.jpg"`
|
||||
* `LoadImage()`: `<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7F8E7AAEFC10>`
|
||||
* `ImageCropAndResize(max_pixels=512*512)`: `<PIL.Image.Image image mode=RGB size=512x512 at 0x7F8E7A936F20>`
|
||||
|
||||
We can compose functionally complete data pipelines, for example, the default video data operator for the universal dataset is:
|
||||
|
||||
```python
|
||||
RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||
(("gif",), LoadGIF(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
])),
|
||||
])
|
||||
```
|
||||
|
||||
It includes the following logic:
|
||||
|
||||
* If the data is of type `str`
|
||||
* If it's a `"jpg", "jpeg", "png", "webp"` type file
|
||||
* Load this image
|
||||
* Crop and scale to a specific resolution
|
||||
* Pack into a list, treating it as a single-frame video
|
||||
* If it's a `"gif"` type file
|
||||
* Load the GIF file content
|
||||
* Crop and scale each frame to a specific resolution
|
||||
* If it's a `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"` type file
|
||||
* Load the video file content
|
||||
* Crop and scale each frame to a specific resolution
|
||||
* If the data is not of type `str`, an error is reported
|
||||
|
||||
## Universal Dataset
|
||||
|
||||
`diffsynth.core.data` provides a unified dataset implementation. The dataset requires the following parameters:
|
||||
|
||||
* `base_path`: Root directory. If the dataset contains relative paths to image files, this field needs to be filled in to load the files pointed to by these paths
|
||||
* `metadata_path`: Metadata directory, records the file paths of all metadata, supports `csv`, `json`, `jsonl` formats
|
||||
* `repeat`: Data repetition count, defaults to 1, this parameter affects the number of training steps in an epoch
|
||||
* `data_file_keys`: Data field names that need to be loaded, for example `(image, edit_image)`
|
||||
* `main_data_operator`: Main loading operator, needs to assemble the data processing pipeline through data processing operators
|
||||
* `special_operator_map`: Special operator mapping, operator mappings built for fields that require special processing
|
||||
|
||||
### Metadata
|
||||
|
||||
The dataset's `metadata_path` points to a metadata file, supporting `csv`, `json`, `jsonl` formats. The following provides examples:
|
||||
|
||||
* `csv` format: High readability, does not support list data, small memory footprint
|
||||
|
||||
```csv
|
||||
image,prompt
|
||||
image_1.jpg,"a dog"
|
||||
image_2.jpg,"a cat"
|
||||
```
|
||||
|
||||
* `json` format: High readability, supports list data, large memory footprint
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"image": "image_1.jpg",
|
||||
"prompt": "a dog"
|
||||
},
|
||||
{
|
||||
"image": "image_2.jpg",
|
||||
"prompt": "a cat"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
* `jsonl` format: Low readability, supports list data, small memory footprint
|
||||
|
||||
```json
|
||||
{"image": "image_1.jpg", "prompt": "a dog"}
|
||||
{"image": "image_2.jpg", "prompt": "a cat"}
|
||||
```
|
||||
|
||||
How to choose the best metadata format?
|
||||
|
||||
* If the data volume is large, reaching tens of millions, since `json` file parsing requires additional memory, it's not available. Please use `csv` or `jsonl` format
|
||||
* If the dataset contains list data, such as edit models that require multiple images as input, since `csv` format cannot store list format data, it's not available. Please use `json` or `jsonl` format
|
||||
|
||||
### Data Loading Logic
|
||||
|
||||
When no additional settings are made, the dataset defaults to outputting data from the metadata set. Image and video file paths will be output in string format. To load these files, you need to set `data_file_keys`, `main_data_operator`, and `special_operator_map`.
|
||||
|
||||
In the data processing flow, processing is done according to the following logic:
|
||||
* If the field is in `special_operator_map`, call the corresponding operator in `special_operator_map` for processing
|
||||
* If the field is not in `special_operator_map`
|
||||
* If the field is in `data_file_keys`, call the `main_data_operator` operator for processing
|
||||
* If the field is not in `data_file_keys`, no processing is done
|
||||
|
||||
`special_operator_map` can be used to implement special data processing. For example, in the model [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B), the input character face video `animate_face_video` is processed at a fixed resolution, inconsistent with the output video. Therefore, this field is processed by a dedicated operator:
|
||||
|
||||
```python
|
||||
special_operator_map={
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
|
||||
}
|
||||
```
|
||||
|
||||
### Other Notes
|
||||
|
||||
When the data volume is too small, you can appropriately increase `repeat` to extend the training time of a single epoch, avoiding frequent model saving that generates considerable overhead.
|
||||
|
||||
When data volume * `repeat` exceeds $10^9$, we observe that the dataset speed becomes significantly slower. This seems to be a `PyTorch` bug, and we are not sure if newer versions of `PyTorch` have fixed this issue.
|
||||
69
docs/en/API_Reference/core/gradient.md
Normal file
69
docs/en/API_Reference/core/gradient.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# `diffsynth.core.gradient`: Gradient Checkpointing and Offload
|
||||
|
||||
`diffsynth.core.gradient` provides encapsulated gradient checkpointing and its Offload version for model training.
|
||||
|
||||
## Gradient Checkpointing
|
||||
|
||||
Gradient checkpointing is a technique used to reduce memory usage during training. We provide an example to help you understand this technique. Here is a simple model structure:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(x)
|
||||
|
||||
model = ToyModel()
|
||||
x = torch.randn((2, 3))
|
||||
y = model(x)
|
||||
```
|
||||
|
||||
In this model structure, the input parameter $x$ passes through the Sigmoid activation function to obtain the output value $y=\frac{1}{1+e^{-x}}$.
|
||||
|
||||
During the training process, assuming our loss function value is $\mathcal L$, when backpropagating gradients, we obtain $\frac{\partial \mathcal L}{\partial y}$. At this point, we need to calculate $\frac{\partial \mathcal L}{\partial x}$. It's not difficult to find that $\frac{\partial y}{\partial x}=y(1-y)$, and thus $\frac{\partial \mathcal L}{\partial x}=\frac{\partial \mathcal L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial \mathcal L}{\partial y}y(1-y)$. If we save the value of $y$ during the model's forward propagation and directly compute $y(1-y)$ during gradient backpropagation, this will avoid complex exp computations, speeding up the calculation. However, this requires additional memory to store the intermediate variable $y$.
|
||||
|
||||
When gradient checkpointing is not enabled, the training framework will default to storing all intermediate variables that assist gradient computation, thereby achieving optimal computational speed. When gradient checkpointing is enabled, intermediate variables are not stored, but the input parameter $x$ is still stored, reducing memory usage. During gradient backpropagation, these variables need to be recomputed, slowing down the calculation.
|
||||
|
||||
## Enabling Gradient Checkpointing and Its Offload
|
||||
|
||||
`gradient_checkpoint_forward` in `diffsynth.core.gradient` implements gradient checkpointing and its Offload. Refer to the following code for calling:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.core.gradient import gradient_checkpoint_forward
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(x)
|
||||
|
||||
model = ToyModel()
|
||||
x = torch.randn((2, 3))
|
||||
y = gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
x=x,
|
||||
)
|
||||
```
|
||||
|
||||
* When `use_gradient_checkpointing=False` and `use_gradient_checkpointing_offload=False`, the computation process is exactly the same as the original computation, not affecting the model's inference and training. You can directly integrate it into your code.
|
||||
* When `use_gradient_checkpointing=True` and `use_gradient_checkpointing_offload=False`, gradient checkpointing is enabled.
|
||||
* When `use_gradient_checkpointing_offload=True`, gradient checkpointing is enabled, and all gradient checkpoint input parameters are stored in memory, further reducing memory usage and slowing down computation.
|
||||
|
||||
## Best Practices
|
||||
|
||||
> Q: Where should gradient checkpointing be enabled?
|
||||
>
|
||||
> A: When enabling gradient checkpointing for the entire model, computational efficiency and memory usage are not optimal. We need to set fine-grained gradient checkpoints, but we don't want to add too much complicated code to the framework. Therefore, we recommend implementing it in the `model_fn` of `Pipeline`, for example, `model_fn_qwen_image` in `diffsynth/pipelines/qwen_image.py`, enabling gradient checkpointing at the Block level without modifying any code in the model structure.
|
||||
|
||||
> Q: When should gradient checkpointing be enabled?
|
||||
>
|
||||
> A: As model parameters become increasingly large, gradient checkpointing has become a necessary training technique. Gradient checkpointing usually needs to be enabled. Gradient checkpointing Offload should only be enabled in models where activation values occupy excessive memory (such as video generation models).
|
||||
141
docs/en/API_Reference/core/loader.md
Normal file
141
docs/en/API_Reference/core/loader.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# `diffsynth.core.loader`: Model Download and Loading
|
||||
|
||||
This document introduces the model download and loading functionalities in `diffsynth.core.loader`.
|
||||
|
||||
## ModelConfig
|
||||
|
||||
`ModelConfig` in `diffsynth.core.loader` is used to annotate model download sources, local paths, VRAM management configurations, and other information.
|
||||
|
||||
### Downloading and Loading Models from Remote Sources
|
||||
|
||||
Taking the model [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) as an example, after filling in `model_id` and `origin_file_pattern` in `ModelConfig`, the model can be automatically downloaded. By default, it downloads to the `./models` path, which can be modified through the [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
|
||||
|
||||
By default, even if the model has already been downloaded, the program will still query the remote for any missing files. To completely disable remote requests, set the [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(
|
||||
model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny",
|
||||
origin_file_pattern="model.safetensors",
|
||||
)
|
||||
# Download models
|
||||
config.download_if_necessary()
|
||||
print(config.path)
|
||||
```
|
||||
|
||||
After calling `download_if_necessary`, the model will be automatically downloaded, and the path will be returned to `config.path`.
|
||||
|
||||
### Loading Models from Local Paths
|
||||
|
||||
If loading models from local paths, you need to fill in `path`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(path="models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")
|
||||
```
|
||||
|
||||
If the model contains multiple shard files, input them in list form:
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(path=[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
])
|
||||
```
|
||||
|
||||
### VRAM Management Configuration
|
||||
|
||||
`ModelConfig` also contains VRAM management configuration information. See [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods) for details.
|
||||
|
||||
## Model File Loading
|
||||
|
||||
`diffsynth.core.loader` provides a unified `load_state_dict` for loading state dicts from model files.
|
||||
|
||||
Loading a single model file:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
state_dict = load_state_dict("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")
|
||||
```
|
||||
|
||||
Loading multiple model files (merged into one state dict):
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
state_dict = load_state_dict([
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
])
|
||||
```
|
||||
|
||||
## Model Hash
|
||||
|
||||
Model hash is used to determine the model type. The hash value can be obtained through `hash_model_file`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file
|
||||
|
||||
print(hash_model_file("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"))
|
||||
```
|
||||
|
||||
The hash value of multiple model files can also be calculated, which is equivalent to calculating the model hash value after merging the state dict:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file
|
||||
|
||||
print(hash_model_file([
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
]))
|
||||
```
|
||||
|
||||
The model hash value is only related to the keys and tensor shapes in the state dict of the model file, and is unrelated to the numerical values of the model parameters, file saving time, and other information. When calculating the model hash value of `.safetensors` format files, `hash_model_file` is almost instantly completed without reading the model parameters. However, when calculating the model hash value of `.bin`, `.pth`, `.ckpt`, and other binary files, all model parameters need to be read, so **we do not recommend developers to continue using these formats of files.**
|
||||
|
||||
By [writing model Config](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config) and filling in model hash value and other information into `diffsynth/configs/model_configs.py`, developers can let `DiffSynth-Studio` automatically identify the model type and load it.
|
||||
|
||||
## Model Loading
|
||||
|
||||
`load_model` is the external entry for loading models in `diffsynth.core.loader`. It will call [skip_model_initialization](/docs/en/API_Reference/core/vram.md#skipping-model-parameter-initialization) to skip model parameter initialization. If [Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is enabled, it calls [DiskMap](/docs/en/API_Reference/core/vram.md#state-dict-disk-mapping) for lazy loading. If Disk Offload is not enabled, it calls [load_state_dict](#model-file-loading) to load model parameters. If necessary, it will also call [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) for model format conversion. Finally, it calls `model.eval()` to switch to inference mode.
|
||||
|
||||
Here is a usage example with Disk Offload enabled:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
```
|
||||
66
docs/en/API_Reference/core/vram.md
Normal file
66
docs/en/API_Reference/core/vram.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# `diffsynth.core.vram`: VRAM Management
|
||||
|
||||
This document introduces the underlying VRAM management functionalities in `diffsynth.core.vram`. If you wish to use these functionalities in other codebases, you can refer to this document.
|
||||
|
||||
## Skipping Model Parameter Initialization
|
||||
|
||||
When loading models in `PyTorch`, model parameters default to occupying VRAM or memory and initializing parameters, but these parameters will be overwritten when loading pretrained weights, leading to redundant computations. `PyTorch` does not provide an interface to skip these redundant computations. We provide `skip_model_initialization` in `diffsynth.core.vram` to skip model parameter initialization.
|
||||
|
||||
Default model loading approach:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
model = QwenImageBlockWiseControlNet() # Slow
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu")
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
```
|
||||
|
||||
Model loading approach that skips parameter initialization:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
with skip_model_initialization():
|
||||
model = QwenImageBlockWiseControlNet() # Fast
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu")
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
```
|
||||
|
||||
In `DiffSynth-Studio`, all pretrained models follow this loading logic. After developers [integrate models](/docs/en/Developer_Guide/Integrating_Your_Model.md), they can directly load models quickly using this approach.
|
||||
|
||||
## State Dict Disk Mapping
|
||||
|
||||
For pretrained weight files of a model, if we only need to read a set of parameters rather than all parameters, State Dict Disk Mapping can accelerate this process. We provide `DiskMap` in `diffsynth.core.vram` for on-demand loading of model parameters.
|
||||
|
||||
Default weight loading approach:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu") # Slow
|
||||
print(state_dict["img_in.weight"])
|
||||
```
|
||||
|
||||
Using `DiskMap` to load only specific parameters:
|
||||
|
||||
```python
|
||||
from diffsynth.core import DiskMap
|
||||
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = DiskMap(path, device="cpu") # Fast
|
||||
print(state_dict["img_in.weight"])
|
||||
```
|
||||
|
||||
`DiskMap` is the basic component of Disk Offload in `DiffSynth-Studio`. After developers [configure fine-grained VRAM management schemes](/docs/en/Developer_Guide/Enabling_VRAM_management.md), they can directly enable Disk Offload.
|
||||
|
||||
`DiskMap` is a functionality implemented using the characteristics of `.safetensors` files. Therefore, when using `.bin`, `.pth`, `.ckpt`, and other binary files, model parameters are fully loaded, which causes Disk Offload to not support these formats of files. **We do not recommend developers to continue using these formats of files.**
|
||||
|
||||
## Replacable Modules for VRAM Management
|
||||
|
||||
When `DiffSynth-Studio`'s VRAM management is enabled, the modules inside the model will be replaced with replacable modules in `diffsynth.core.vram.layers`. For usage, see [Fine-grained VRAM Management Scheme](/docs/en/Developer_Guide/Enabling_VRAM_management.md#writing-fine-grained-vram-management-schemes).
|
||||
250
docs/en/Developer_Guide/Building_a_Pipeline.md
Normal file
250
docs/en/Developer_Guide/Building_a_Pipeline.md
Normal file
@@ -0,0 +1,250 @@
|
||||
# Building a Pipeline
|
||||
|
||||
After [integrating the required models for the Pipeline](/docs/en/Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.
|
||||
|
||||
The `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components:
|
||||
|
||||
* `__init__`
|
||||
* `from_pretrained`
|
||||
* `__call__`
|
||||
* `units`
|
||||
* `model_fn`
|
||||
|
||||
## `__init__`
|
||||
|
||||
In `__init__`, the `Pipeline` is initialized. Here is a simple implementation:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model
|
||||
|
||||
class NewDiffSynthPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.text_encoder: XXX_Model = None
|
||||
self.dit: YYY_Model = None
|
||||
self.vae: ZZZ_Model = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
NewDiffSynthPipelineUnit_xxx(),
|
||||
...
|
||||
]
|
||||
self.model_fn = model_fn_new
|
||||
```
|
||||
|
||||
This includes the following parts:
|
||||
|
||||
* `scheduler`: Scheduler, used to control the coefficients in the iterative formula during inference, controlling the noise content at each step.
|
||||
* `text_encoder`, `dit`, `vae`: Models. Since [Latent Diffusion](https://arxiv.org/abs/2112.10752) was proposed, this three-stage model architecture has become the mainstream Diffusion model architecture. However, this is not immutable, and any number of models can be added to the `Pipeline`.
|
||||
* `in_iteration_models`: Iteration models. This tuple marks which models will be called during iteration.
|
||||
* `units`: Pre-processing units for model iteration. See [`units`](#units) for details.
|
||||
* `model_fn`: The `forward` function of the denoising model during iteration. See [`model_fn`](#model_fn) for details.
|
||||
|
||||
> Q: Model loading does not occur in `__init__`, why initialize each model as `None` here?
|
||||
>
|
||||
> A: By annotating the type of each model here, the code editor can provide code completion prompts based on each model, facilitating subsequent development.
|
||||
|
||||
## `from_pretrained`
|
||||
|
||||
`from_pretrained` is responsible for loading the required models to make the `Pipeline` callable. Here is a simple implementation:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("xxx_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("yyy_dit")
|
||||
pipe.vae = model_pool.fetch_model("zzz_vae")
|
||||
# If necessary, load tokenizers here.
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
```
|
||||
|
||||
Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).
|
||||
|
||||
Some models also need to load `tokenizer`. Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models.
|
||||
|
||||
## `__call__`
|
||||
|
||||
`__call__` implements the entire generation process of the Pipeline. Below is a common generation process template. Developers can modify it based on their needs.
|
||||
|
||||
```python
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
num_inference_steps: int = 30,
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps,
|
||||
denoising_strength=denoising_strength
|
||||
)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image,
|
||||
"denoising_strength": denoising_strength,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"seed": seed,
|
||||
"rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"], device=self.device)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
```
|
||||
|
||||
## `units`
|
||||
|
||||
`units` contains all the preprocessing processes, such as: width/height checking, prompt encoding, initial noise generation, etc. In the entire model preprocessing process, data is abstracted into three mutually exclusive parts, stored in corresponding dictionaries:
|
||||
|
||||
* `inputs_shared`: Shared inputs, parameters unrelated to [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598) (CFG for short).
|
||||
* `inputs_posi`: Positive side inputs for Classifier-Free Guidance, containing content related to positive prompts.
|
||||
* `inputs_nega`: Negative side inputs for Classifier-Free Guidance, containing content related to negative prompts.
|
||||
|
||||
Pipeline Unit implementations include three types: direct mode, CFG separation mode, and takeover mode.
|
||||
|
||||
If some calculations are unrelated to CFG, direct mode can be used, for example, Qwen-Image's random noise initialization:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
```
|
||||
|
||||
If some calculations are related to CFG and need to separately process positive and negative prompts, but the input parameters on both sides are the same, CFG separation mode can be used, for example, Qwen-image's prompt encoding:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
# Do something
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
```
|
||||
|
||||
If some calculations need global information, takeover mode is required, for example, Qwen-Image's entity partition control:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
|
||||
output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
# Do something
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
```
|
||||
|
||||
The following are the parameter configurations required for Pipeline Unit:
|
||||
|
||||
* `seperate_cfg`: Whether to enable CFG separation mode
|
||||
* `take_over`: Whether to enable takeover mode
|
||||
* `input_params`: Shared input parameters
|
||||
* `output_params`: Output parameters
|
||||
* `input_params_posi`: Positive side input parameters
|
||||
* `input_params_nega`: Negative side input parameters
|
||||
* `onload_model_names`: Names of model components to be called
|
||||
|
||||
When designing `unit`, please try to follow these principles:
|
||||
|
||||
* Default fallback: For optional function `unit` input parameters, the default is `None` rather than `False` or other values. Please provide fallback processing for this default value.
|
||||
* Parameter triggering: Some Adapter models may not be loaded, such as ControlNet. The corresponding `unit` should control triggering based on whether the parameter input is `None` rather than whether the model is loaded. For example, when the user inputs `controlnet_image` but does not load the ControlNet model, the code should give an error rather than ignore these input parameters and continue execution.
|
||||
* Simplicity first: Use direct mode as much as possible, only use takeover mode when the function cannot be implemented.
|
||||
* VRAM efficiency: When calling models in `unit`, please use `pipe.load_models_to_device(self.onload_model_names)` to activate the corresponding models. Do not call other models outside `onload_model_names`. After `unit` calculation is completed, do not manually release VRAM with `pipe.load_models_to_device([])`.
|
||||
|
||||
> Q: Some parameters are not called during the inference process, such as `output_params`. Is it still necessary to configure them?
|
||||
>
|
||||
> A: These parameters will not affect the inference process, but they will affect some experimental features. Therefore, we recommend configuring them properly. For example, "split training" - we can complete the preprocessing offline during training, but some model calculations that require gradient backpropagation cannot be split. These parameters are used to build computational graphs to infer which calculations can be split.
|
||||
|
||||
## `model_fn`
|
||||
|
||||
`model_fn` is the unified `forward` interface during iteration. For models where the open-source ecosystem is not yet formed, you can directly use the denoising model's `forward`, for example:
|
||||
|
||||
```python
|
||||
def model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):
|
||||
return dit(latents, prompt_emb, timestep)
|
||||
```
|
||||
|
||||
For models with rich open-source ecosystems, `model_fn` usually contains complex and chaotic cross-model inference. Taking `diffsynth/pipelines/qwen_image.py` as an example, the additional calculations implemented in this function include: entity partition control, three types of ControlNet, Gradient Checkpointing, etc. Developers need to be extra careful when implementing this part to avoid conflicts between module functions.
|
||||
455
docs/en/Developer_Guide/Enabling_VRAM_management.md
Normal file
455
docs/en/Developer_Guide/Enabling_VRAM_management.md
Normal file
@@ -0,0 +1,455 @@
|
||||
# Fine-Grained VRAM Management Scheme
|
||||
|
||||
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
|
||||
|
||||
## How Much VRAM Does a 20B Model Need?
|
||||
|
||||
Taking Qwen-Image's DiT model as an example, this model has reached 20B parameters. The following code will load this model and perform inference, requiring about 40G VRAM. This model obviously cannot run on consumer-grade GPUs with smaller VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
snapshot_download(
|
||||
model_id="Qwen/Qwen-Image",
|
||||
local_dir="models/Qwen/Qwen-Image",
|
||||
allow_file_pattern="transformer/*"
|
||||
)
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
## Writing Fine-Grained VRAM Management Scheme
|
||||
|
||||
To write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure:
|
||||
|
||||
```
|
||||
QwenImageDiT(
|
||||
(pos_embed): QwenEmbedRope()
|
||||
(time_text_embed): TimestepEmbeddings(
|
||||
(time_proj): TemporalTimesteps()
|
||||
(timestep_embedder): DiffusersCompatibleTimestepProj(
|
||||
(linear_1): Linear(in_features=256, out_features=3072, bias=True)
|
||||
(act): SiLU()
|
||||
(linear_2): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_norm): RMSNorm()
|
||||
(img_in): Linear(in_features=64, out_features=3072, bias=True)
|
||||
(txt_in): Linear(in_features=3584, out_features=3072, bias=True)
|
||||
(transformer_blocks): ModuleList(
|
||||
(0-59): 60 x QwenImageTransformerBlock(
|
||||
(img_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(attn): QwenDoubleStreamAttention(
|
||||
(to_q): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_k): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_v): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_q): RMSNorm()
|
||||
(norm_k): RMSNorm()
|
||||
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_added_q): RMSNorm()
|
||||
(norm_added_k): RMSNorm()
|
||||
(to_out): Sequential(
|
||||
(0): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(img_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(norm_out): AdaLayerNorm(
|
||||
(linear): Linear(in_features=3072, out_features=6144, bias=True)
|
||||
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
)
|
||||
(proj_out): Linear(in_features=3072, out_features=64, bias=True)
|
||||
)
|
||||
```
|
||||
|
||||
In VRAM management, we only care about layers containing parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU` and other Layers do not contain parameters. `LayerNorm` also does not contain parameters because `elementwise_affine=False` is set. Layers containing parameters are only `Linear` and `RMSNorm`.
|
||||
|
||||
`diffsynth.core.vram` provides two replacement modules for VRAM management:
|
||||
* `AutoWrappedLinear`: Used to replace `Linear` layers
|
||||
* `AutoWrappedModule`: Used to replace any other layer
|
||||
|
||||
Write a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules:
|
||||
|
||||
```python
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
}
|
||||
```
|
||||
|
||||
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
|
||||
|
||||
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
|
||||
enable_vram_management(
|
||||
model,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
The above code only requires 2G VRAM to run the `forward` of a 20B model.
|
||||
|
||||
## Disk Offload
|
||||
|
||||
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
|
||||
|
||||
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.
|
||||
|
||||
## Writing Default Configuration
|
||||
|
||||
To make it easier for users to use the VRAM management function, we write the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration information for the above model is:
|
||||
|
||||
```python
|
||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
}
|
||||
```# Fine-Grained VRAM Management Scheme
|
||||
|
||||
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
|
||||
|
||||
## How Much VRAM Does a 20B Model Need?
|
||||
|
||||
Taking Qwen-Image's DiT model as an example, this model has reached 20B parameters. The following code will load this model and perform inference, requiring about 40G VRAM. This model obviously cannot run on consumer-grade GPUs with smaller VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
snapshot_download(
|
||||
model_id="Qwen/Qwen-Image",
|
||||
local_dir="models/Qwen/Qwen-Image",
|
||||
allow_file_pattern="transformer/*"
|
||||
)
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
## Writing Fine-Grained VRAM Management Scheme
|
||||
|
||||
To write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure:
|
||||
|
||||
```
|
||||
QwenImageDiT(
|
||||
(pos_embed): QwenEmbedRope()
|
||||
(time_text_embed): TimestepEmbeddings(
|
||||
(time_proj): TemporalTimesteps()
|
||||
(timestep_embedder): DiffusersCompatibleTimestepProj(
|
||||
(linear_1): Linear(in_features=256, out_features=3072, bias=True)
|
||||
(act): SiLU()
|
||||
(linear_2): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_norm): RMSNorm()
|
||||
(img_in): Linear(in_features=64, out_features=3072, bias=True)
|
||||
(txt_in): Linear(in_features=3584, out_features=3072, bias=True)
|
||||
(transformer_blocks): ModuleList(
|
||||
(0-59): 60 x QwenImageTransformerBlock(
|
||||
(img_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(attn): QwenDoubleStreamAttention(
|
||||
(to_q): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_k): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_v): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_q): RMSNorm()
|
||||
(norm_k): RMSNorm()
|
||||
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_added_q): RMSNorm()
|
||||
(norm_added_k): RMSNorm()
|
||||
(to_out): Sequential(
|
||||
(0): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(img_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(norm_out): AdaLayerNorm(
|
||||
(linear): Linear(in_features=3072, out_features=6144, bias=True)
|
||||
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
)
|
||||
(proj_out): Linear(in_features=3072, out_features=64, bias=True)
|
||||
)
|
||||
```
|
||||
|
||||
In VRAM management, we only care about layers containing parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU` and other Layers do not contain parameters. `LayerNorm` also does not contain parameters because `elementwise_affine=False` is set. Layers containing parameters are only `Linear` and `RMSNorm`.
|
||||
|
||||
`diffsynth.core.vram` provides two replacement modules for VRAM management:
|
||||
* `AutoWrappedLinear`: Used to replace `Linear` layers
|
||||
* `AutoWrappedModule`: Used to replace any other layer
|
||||
|
||||
Write a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules:
|
||||
|
||||
```python
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
}
|
||||
```
|
||||
|
||||
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
|
||||
|
||||
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
|
||||
enable_vram_management(
|
||||
model,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
The above code only requires 2G VRAM to run the `forward` of a 20B model.
|
||||
|
||||
## Disk Offload
|
||||
|
||||
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
|
||||
|
||||
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.
|
||||
|
||||
## Writing Default Configuration
|
||||
|
||||
To make it easier for users to use the VRAM management function, we write the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration information for the above model is:
|
||||
|
||||
```python
|
||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
}
|
||||
```
|
||||
186
docs/en/Developer_Guide/Integrating_Your_Model.md
Normal file
186
docs/en/Developer_Guide/Integrating_Your_Model.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Integrating Model Architecture
|
||||
|
||||
This document introduces how to integrate models into the `DiffSynth-Studio` framework for use by modules such as `Pipeline`.
|
||||
|
||||
## Step 1: Integrate Model Architecture Code
|
||||
|
||||
All model architecture implementations in `DiffSynth-Studio` are unified in `diffsynth/models`. Each `.py` code file implements a model architecture, and all models are loaded through `ModelPool` in `diffsynth/models/model_loader.py`. When integrating new model architectures, please create a new `.py` file under this path.
|
||||
|
||||
```shell
|
||||
diffsynth/models/
|
||||
├── general_modules.py
|
||||
├── model_loader.py
|
||||
├── qwen_image_controlnet.py
|
||||
├── qwen_image_dit.py
|
||||
├── qwen_image_text_encoder.py
|
||||
├── qwen_image_vae.py
|
||||
└── ...
|
||||
```
|
||||
|
||||
In most cases, we recommend integrating models in native `PyTorch` code form, with the model architecture class directly inheriting from `torch.nn.Module`, for example:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class NewDiffSynthModel(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(dim, dim)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
```
|
||||
|
||||
If the model architecture implementation contains additional dependencies, we strongly recommend removing them, otherwise this will cause heavy package dependency issues. In our existing models, Qwen-Image's Blockwise ControlNet is integrated in this way. The code is lightweight, please refer to `diffsynth/models/qwen_image_controlnet.py`.
|
||||
|
||||
If the model has been integrated by Huggingface Library ([`transformers`](https://huggingface.co/docs/transformers/main/index), [`diffusers`](https://huggingface.co/docs/diffusers/main/index), etc.), we can integrate the model in a simpler way:
|
||||
|
||||
<details>
|
||||
<summary>Integrating Huggingface Library Style Model Architecture Code</summary>
|
||||
|
||||
The loading method for these models in Huggingface Library is:
|
||||
|
||||
```python
|
||||
from transformers import XXX_Model
|
||||
|
||||
model = XXX_Model.from_pretrained("path_to_your_model")
|
||||
```
|
||||
|
||||
`DiffSynth-Studio` does not support loading models through `from_pretrained` because this conflicts with VRAM management and other functions. Please rewrite the model architecture in the following format:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class DiffSynth_XXX_Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import XXX_Config, XXX_Model
|
||||
config = XXX_Config(**{
|
||||
"architectures": ["XXX_Model"],
|
||||
"other_configs": "Please copy and paste the other configs here.",
|
||||
})
|
||||
self.model = XXX_Model(config)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self.model(x)
|
||||
return outputs
|
||||
```
|
||||
|
||||
Where `XXX_Config` is the Config class corresponding to the model. For example, the Config class for `Qwen2_5_VLModel` is `Qwen2_5_VLConfig`, which can be found by consulting its source code. The content inside Config can usually be found in the `config.json` file in the model library. `DiffSynth-Studio` will not read the `config.json` file, so the content needs to be copied and pasted into the code.
|
||||
|
||||
In rare cases, version updates of `transformers` and `diffusers` may cause some models to be unable to import. Therefore, if possible, we still recommend using the model integration method in Step 1.1.
|
||||
|
||||
In our existing models, Qwen-Image's Text Encoder is integrated in this way. The code is lightweight, please refer to `diffsynth/models/qwen_image_text_encoder.py`.
|
||||
|
||||
</details>
|
||||
|
||||
## Step 2: Model File Format Conversion
|
||||
|
||||
Due to the variety of model file formats provided by developers in the open-source community, we sometimes need to convert model file formats to form correctly formatted [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html). This is common in the following situations:
|
||||
|
||||
* Model files built by different code libraries, for example [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) and [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers).
|
||||
* Models modified during integration, for example, the Text Encoder of [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) adds a `model.` prefix in `diffsynth/models/qwen_image_text_encoder.py`.
|
||||
* Model files containing multiple models, for example, the VACE Adapter and base DiT model of [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) are mixed and stored in the same set of model files.
|
||||
|
||||
In our development philosophy, we hope to respect the wishes of model authors as much as possible. If we repackage the model files, for example [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI), although we can call the model more conveniently, traffic (model page views and downloads, etc.) will be directed elsewhere, and the original author of the model will also lose the power to delete the model. Therefore, we have added the `diffsynth/utils/state_dict_converters` module to the framework for file format conversion during model loading.
|
||||
|
||||
This part of logic is very simple. Taking Qwen-Image's Text Encoder as an example, only 10 lines of code are needed:
|
||||
|
||||
```python
|
||||
def QwenImageTextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for k in state_dict:
|
||||
v = state_dict[k]
|
||||
if k.startswith("visual."):
|
||||
k = "model." + k
|
||||
elif k.startswith("model."):
|
||||
k = k.replace("model.", "model.language_model.")
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
```
|
||||
|
||||
## Step 3: Writing Model Config
|
||||
|
||||
Model Config is located in `diffsynth/configs/model_configs.py`, used to identify model types and load them. The following fields need to be filled in:
|
||||
|
||||
* `model_hash`: Model file hash value, which can be obtained through the `hash_model_file` function. This hash value is only related to the keys and tensor shapes in the model file's state dict, and is unrelated to other information in the file.
|
||||
* `model_name`: Model name, used for `Pipeline` to identify the required model. If different structured models play the same role in `Pipeline`, the same `model_name` can be used. When integrating new models, just ensure that `model_name` is different from other existing functional models. The corresponding model is fetched through `model_name` in the `Pipeline`'s `from_pretrained`.
|
||||
* `model_class`: Model architecture import path, pointing to the model architecture class implemented in Step 1, for example `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`.
|
||||
* `state_dict_converter`: Optional parameter. If model file format conversion is needed, the import path of the model conversion logic needs to be filled in, for example `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`.
|
||||
* `extra_kwargs`: Optional parameter. If additional parameters need to be passed when initializing the model, these parameters need to be filled in. For example, models [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) and [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) both adopt the `QwenImageBlockWiseControlNet` structure in `diffsynth/models/qwen_image_controlnet.py`, but the latter also needs additional configuration `additional_in_dim=4`. Therefore, this configuration information needs to be filled in the `extra_kwargs` field.
|
||||
|
||||
We provide a piece of code to quickly understand how models are loaded through this configuration information:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter
|
||||
import torch
|
||||
|
||||
model_hash = "8004730443f55db63092006dd9f7110e"
|
||||
model_name = "qwen_image_text_encoder"
|
||||
model_class = QwenImageTextEncoder
|
||||
state_dict_converter = QwenImageTextEncoderStateDictConverter
|
||||
extra_kwargs = {}
|
||||
|
||||
model_path = [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
]
|
||||
if hash_model_file(model_path) == model_hash:
|
||||
with skip_model_initialization():
|
||||
model = model_class(**extra_kwargs)
|
||||
state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
print("Done!")
|
||||
```
|
||||
|
||||
> Q: The logic of the above code looks very simple, why is this part of code in `DiffSynth-Studio` extremely complex?
|
||||
>
|
||||
> A: Because we provide aggressive VRAM management functions that are coupled with the model loading logic, this leads to the complexity of the framework structure. We have tried our best to simplify the interface exposed to developers.
|
||||
|
||||
The `model_hash` in `diffsynth/configs/model_configs.py` is not uniquely existing. Multiple models may exist in the same model file. For this situation, please use multiple model Configs to load each model separately, and write the corresponding `state_dict_converter` to separate the parameters required by each model.
|
||||
|
||||
## Step 4: Verifying Whether the Model Can Be Recognized and Loaded
|
||||
|
||||
After model integration, the following code can be used to verify whether the model can be correctly recognized and loaded. The following code will attempt to load the model into memory:
|
||||
|
||||
```python
|
||||
from diffsynth.models.model_loader import ModelPool
|
||||
|
||||
model_pool = ModelPool()
|
||||
model_pool.auto_load_model(
|
||||
[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
If the model can be recognized and loaded, you will see the following output:
|
||||
|
||||
```
|
||||
Loading models from: [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
]
|
||||
Loaded model: {
|
||||
"model_name": "qwen_image_text_encoder",
|
||||
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
||||
"extra_kwargs": null
|
||||
}
|
||||
```
|
||||
|
||||
## Step 5: Writing Model VRAM Management Scheme
|
||||
|
||||
`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) for details.
|
||||
66
docs/en/Developer_Guide/Training_Diffusion_Models.md
Normal file
66
docs/en/Developer_Guide/Training_Diffusion_Models.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Integrating Model Training
|
||||
|
||||
After [integrating models](/docs/en/Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.
|
||||
|
||||
## Training-Inference Consistent Pipeline Modification
|
||||
|
||||
To ensure strict consistency between training and inference processes, we will use most of the inference code during training, but still need to make minor modifications.
|
||||
|
||||
First, add extra logic during inference to switch the image-to-image/video-to-video logic based on the `scheduler` state. Taking Qwen-Image as an example:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
```
|
||||
|
||||
Then, enable Gradient Checkpointing in `model_fn`, which will significantly reduce the VRAM required for training at the cost of computational speed. This is not mandatory, but we strongly recommend doing so.
|
||||
|
||||
Taking Qwen-Image as an example, before modification:
|
||||
|
||||
```python
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
After modification:
|
||||
|
||||
```python
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
## Writing Training Scripts
|
||||
|
||||
`DiffSynth-Studio` does not strictly encapsulate the training framework, but exposes the script content to developers. This approach makes it more convenient to modify training scripts to implement additional functions. Developers can refer to existing training scripts, such as `examples/qwen_image/model_training/train.py`, for modification to adapt to new model training.
|
||||
201
docs/en/Model_Details/FLUX.md
Normal file
201
docs/en/Model_Details/FLUX.md
Normal file
@@ -0,0 +1,201 @@
|
||||
# FLUX
|
||||
|
||||

|
||||
|
||||
FLUX is an image generation model series developed and open-sourced by Black Forest Labs.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
|
||||
],
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
|
||||
)
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
image = pipe(prompt=prompt, seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Lineage</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
|
||||
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
|
||||
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
|
||||
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
|
||||
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
|
||||
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
|
||||
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
|
||||
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
|
||||
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - | - |
|
||||
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
|
||||
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
|
||||
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
|
||||
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
|
||||
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
|
||||
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
|
||||
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
|
||||
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
|
||||
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
|
||||
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
|
||||
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
|
||||
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
|
||||
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
|
||||
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded via `FluxImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
|
||||
|
||||
Input parameters for `FluxImagePipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the image.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 1. When set to a value greater than 1, CFG is enabled.
|
||||
* `height`: Image height, must be a multiple of 16.
|
||||
* `width`: Image width, must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 30.
|
||||
* `embedded_guidance`: Embedded guidance parameter, default value is 3.5.
|
||||
* `t5_sequence_length`: Sequence length of the T5 text encoder, default is 512.
|
||||
* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`.
|
||||
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
|
||||
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
|
||||
* `controlnet_inputs`: ControlNet model inputs, type is `ControlNetInput` list.
|
||||
* `ipadapter_images`: IP-Adapter model input image list.
|
||||
* `ipadapter_scale`: Guidance strength of the IP-Adapter model.
|
||||
* `infinityou_id_image`: InfiniteYou model input image.
|
||||
* `infinityou_guidance`: Guidance strength of the InfiniteYou model.
|
||||
* `kontext_images`: Kontext model input images.
|
||||
* `eligen_entity_prompts`: EliGen partition control prompt list.
|
||||
* `eligen_entity_masks`: EliGen partition control region mask image list.
|
||||
* `eligen_enable_on_negative`: Whether to enable EliGen partition control on the negative side of CFG.
|
||||
* `eligen_enable_inpaint`: Whether to enable EliGen partition control inpainting function.
|
||||
* `lora_encoder_inputs`: LoRA encoder input image list.
|
||||
* `lora_encoder_scale`: Guidance strength of the LoRA encoder.
|
||||
* `step1x_reference_image`: Step1X model reference image.
|
||||
* `flex_inpaint_image`: Flex model image to be inpainted.
|
||||
* `flex_inpaint_mask`: Flex model inpainting mask.
|
||||
* `flex_control_image`: Flex model control image.
|
||||
* `flex_control_strength`: Flex model control strength.
|
||||
* `flex_control_stop`: Flex model control stop timestep.
|
||||
* `nexus_gen_reference_image`: Nexus-Gen model reference image.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||
|
||||
## Model Training
|
||||
|
||||
FLUX series models are uniformly trained through [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py), and the script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Basic Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Metadata file path of the dataset.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* `--dataset_num_workers`: Number of processes for each DataLoader.
|
||||
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths of models to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors"`. Separated by commas.
|
||||
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., `controlnet_inputs` when training ControlNet models, separated by `,`.
|
||||
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
|
||||
* Training Basic Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
|
||||
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
|
||||
* Output Configuration
|
||||
* `--output_path`: Model saving path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
|
||||
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
|
||||
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
|
||||
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)
|
||||
* `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.
|
||||
* FLUX Specific Parameters
|
||||
* `--tokenizer_1_path`: Path of the CLIP tokenizer, leave blank to automatically download from remote.
|
||||
* `--tokenizer_2_path`: Path of the T5 tokenizer, leave blank to automatically download from remote.
|
||||
* `--align_to_opensource_format`: Whether to align LoRA format to open-source format, only applicable to DiT's LoRA.
|
||||
|
||||
We have built a sample image dataset for your testing. You can download this dataset with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
138
docs/en/Model_Details/FLUX2.md
Normal file
138
docs/en/Model_Details/FLUX2.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# FLUX.2
|
||||
|
||||
FLUX.2 is an image generation model trained and open-sourced by Black Forest Labs.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 10GB VRAM is required to run.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
|
||||
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - |
|
||||
| [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) | [code](/examples/flux2/model_inference/FLUX.2-dev.py) | [code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py) | [code](/examples/flux2/model_training/lora/FLUX.2-dev.sh) | [code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py) |
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/flux/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/flux/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded via `Flux2ImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
|
||||
|
||||
Input parameters for `Flux2ImagePipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the image.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 1. When set to a value greater than 1, CFG is enabled.
|
||||
* `height`: Image height, must be a multiple of 16.
|
||||
* `width`: Image width, must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 30.
|
||||
* `embedded_guidance`: Embedded guidance parameter, default value is 3.5.
|
||||
* `t5_sequence_length`: Sequence length of the T5 text encoder, default is 512.
|
||||
* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`.
|
||||
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
|
||||
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||
|
||||
## Model Training
|
||||
|
||||
FLUX.2 series models are uniformly trained through [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py), and the script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Basic Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Metadata file path of the dataset.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* `--dataset_num_workers`: Number of processes for each DataLoader.
|
||||
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths of models to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors"`. Separated by commas.
|
||||
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., `controlnet_inputs` when training ControlNet models, separated by `,`.
|
||||
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
|
||||
* Training Basic Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
|
||||
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
|
||||
* Output Configuration
|
||||
* `--output_path`: Model saving path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
|
||||
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
|
||||
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
|
||||
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)
|
||||
* `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.
|
||||
* FLUX.2 Specific Parameters
|
||||
* `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote.
|
||||
|
||||
We have built a sample image dataset for your testing. You can download this dataset with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
291
docs/en/Model_Details/Overview.md
Normal file
291
docs/en/Model_Details/Overview.md
Normal file
@@ -0,0 +1,291 @@
|
||||
# Model Directory
|
||||
|
||||
## Qwen-Image
|
||||
|
||||
Documentation: [./Qwen-Image.md](/docs/en/Model_Details/Qwen-Image.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Effect Preview</summary>
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Lineage</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
|
||||
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
|
||||
Qwen/Qwen-Image-->EliGen-Series;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
|
||||
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
|
||||
Qwen/Qwen-Image-->Distill-Series;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
|
||||
Qwen/Qwen-Image-->ControlNet-Series;
|
||||
ControlNet-Series-->Blockwise-ControlNet-Series;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
|
||||
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
|
||||
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|
||||
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
||||
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|
||||
|
||||
## FLUX Series
|
||||
|
||||
Documentation: [./FLUX.md](/docs/en/Model_Details/FLUX.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Effect Preview</summary>
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Lineage</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
|
||||
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
|
||||
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
|
||||
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
|
||||
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
|
||||
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
|
||||
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
|
||||
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
|
||||
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - | - |
|
||||
| [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py) |
|
||||
| [black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) | | [code](/examples/flux/model_inference/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py) |
|
||||
| [black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev) | `kontext_images` | [code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py) | [code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py) |
|
||||
| [alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py) |
|
||||
| [InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py) |
|
||||
| [jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler) | `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py) |
|
||||
| [InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter) | `ipadapter_images`, `ipadapter_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py) |
|
||||
| [ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou) | `infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs` | [code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py) | [code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py) |
|
||||
| [DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) | `eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint` | [code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py) | - | - | [code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh) | [code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py) |
|
||||
| [DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev) | `lora_encoder_inputs`, `lora_encoder_scale` | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py) | [code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh) | [code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py) | - | - |
|
||||
| [DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev) | | [code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py) | - | - | - | - | - |
|
||||
| [stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit) | `step1x_reference_image` | [code](/examples/flux/model_inference/Step1X-Edit.py) | [code](/examples/flux/model_inference_low_vram/Step1X-Edit.py) | [code](/examples/flux/model_training/full/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_full/Step1X-Edit.py) | [code](/examples/flux/model_training/lora/Step1X-Edit.sh) | [code](/examples/flux/model_training/validate_lora/Step1X-Edit.py) |
|
||||
| [ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview) | `flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop` | [code](/examples/flux/model_inference/FLEX.2-preview.py) | [code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py) | [code](/examples/flux/model_training/full/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_full/FLEX.2-preview.py) | [code](/examples/flux/model_training/lora/FLEX.2-preview.sh) | [code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py) |
|
||||
| [DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2) | `nexus_gen_reference_image` | [code](/examples/flux/model_inference/Nexus-Gen-Editing.py) | [code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py) | [code](/examples/flux/model_training/full/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_full/Nexus-Gen.py) | [code](/examples/flux/model_training/lora/Nexus-Gen.sh) | [code](/examples/flux/model_training/validate_lora/Nexus-Gen.py) |
|
||||
|
||||
## Wan Series
|
||||
|
||||
Documentation: [./Wan.md](/docs/en/Model_Details/Wan.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Effect Preview</summary>
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Quick Start</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Lineage</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Wan-Series-->Wan2.1-Series;
|
||||
Wan-Series-->Wan2.2-Series;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
|
||||
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
|
||||
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
|
||||
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
|
||||
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
|
||||
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
|
||||
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
|
||||
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
|
||||
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
|
||||
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
|
||||
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
|
||||
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
|
||||
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
|
||||
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
|
||||
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
|
||||
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
|
||||
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
|
||||
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
|
||||
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
|
||||
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
|
||||
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
|
||||
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
|
||||
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
|
||||
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
|
||||
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
|
||||
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
|
||||
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
|
||||
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
|
||||
191
docs/en/Model_Details/Qwen-Image.md
Normal file
191
docs/en/Model_Details/Qwen-Image.md
Normal file
@@ -0,0 +1,191 @@
|
||||
# Qwen-Image
|
||||
|
||||

|
||||
|
||||
Qwen-Image is an image generation model trained and open-sourced by the Tongyi Lab Qwen Team of Alibaba.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Lineage</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
|
||||
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
|
||||
Qwen/Qwen-Image-->EliGen-Series;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
|
||||
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
|
||||
Qwen/Qwen-Image-->Distill-Series;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
|
||||
Qwen/Qwen-Image-->ControlNet-Series;
|
||||
ControlNet-Series-->Blockwise-ControlNet-Series;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
|
||||
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
|
||||
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
| [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) | [code](/examples/qwen_image/model_inference/Qwen-Image.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py) |
|
||||
| [Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py) |
|
||||
| [Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster) | [code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA) | [code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) | [code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py) | [code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union) | [code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py) | - | - | [code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh) | [code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py) |
|
||||
| [DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix) | [code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py) | [code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py) | - | - | - | - |
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/qwen_image/model_training/special/differential_training/)
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/qwen_image/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/qwen_image/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded via `QwenImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
|
||||
|
||||
Input parameters for `QwenImagePipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the image.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 4. When set to 1, it no longer takes effect.
|
||||
* `input_image`: Input image for image-to-image generation, used in conjunction with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated image is similar to the input image; when the value approaches 1, the generated image differs more from the input image. When `input_image` parameter is not provided, do not set this to a non-1 value.
|
||||
* `inpaint_mask`: Image inpainting mask image.
|
||||
* `inpaint_blur_size`: Edge blur width for image inpainting.
|
||||
* `inpaint_blur_sigma`: Edge blur strength for image inpainting.
|
||||
* `height`: Image height, must be a multiple of 16.
|
||||
* `width`: Image width, must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 30.
|
||||
* `exponential_shift_mu`: Fixed parameter used in sampling timesteps. Leave blank to sample based on image width and height.
|
||||
* `blockwise_controlnet_inputs`: Blockwise ControlNet model inputs.
|
||||
* `eligen_entity_prompts`: EliGen partition control prompts.
|
||||
* `eligen_entity_masks`: EliGen partition control region mask images.
|
||||
* `eligen_enable_on_negative`: Whether to enable EliGen partition control on the negative side of CFG.
|
||||
* `edit_image`: Edit model images to be edited, supports multiple images.
|
||||
* `edit_image_auto_resize`: Whether to automatically scale edit images.
|
||||
* `edit_rope_interpolation`: Whether to enable ROPE interpolation on low-resolution edit images.
|
||||
* `context_image`: In-Context Control input image.
|
||||
* `tiled`: Whether to enable VAE tiling inference, default is `False`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding stages, default is 128, only effective when `tiled=True`.
|
||||
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is 64, only effective when `tiled=True`, must be less than or equal to `tile_size`.
|
||||
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||
|
||||
## Model Training
|
||||
|
||||
Qwen-Image series models are uniformly trained through [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py), and the script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Basic Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Metadata file path of the dataset.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* `--dataset_num_workers`: Number of processes for each DataLoader.
|
||||
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths of models to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`. Separated by commas.
|
||||
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters `edit_image` when training image editing model Qwen-Image-Edit, separated by `,`.
|
||||
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
|
||||
* Training Basic Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
|
||||
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
|
||||
* Output Configuration
|
||||
* `--output_path`: Model saving path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
|
||||
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
|
||||
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
|
||||
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)
|
||||
* `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.
|
||||
* Qwen-Image Specific Parameters
|
||||
* `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote.
|
||||
* `--processor_path`: Path of the processor, applicable to image editing models, leave blank to automatically download from remote.
|
||||
|
||||
We have built a sample image dataset for your testing. You can download this dataset with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
252
docs/en/Model_Details/Wan.md
Normal file
252
docs/en/Model_Details/Wan.md
Normal file
@@ -0,0 +1,252 @@
|
||||
# Wan
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
Wan is a video generation model series developed by the Tongyi Wanxiang Team of Alibaba Tongyi Lab.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) model and perform inference. VRAM management is enabled, and the framework will automatically control model parameter loading based on remaining VRAM. Minimum 8GB VRAM is required to run.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Model Lineage</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Wan-Series-->Wan2.1-Series;
|
||||
Wan-Series-->Wan2.2-Series;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
|
||||
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
|
||||
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
|
||||
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
|
||||
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
|
||||
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
| Model ID | Extra Parameters | Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
| [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py) |
|
||||
| [Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B) | | [code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py) |
|
||||
| [Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py) |
|
||||
| [Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py) |
|
||||
| [Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py) |
|
||||
| [iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py) |
|
||||
| [Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py) |
|
||||
| [Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py) |
|
||||
| [PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control) | `control_video` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py) |
|
||||
| [PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py) |
|
||||
| [DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1) | `motion_bucket_id` | [code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py) | [code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py) |
|
||||
| [krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video) | | [code](/examples/wanvideo/model_inference/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/full/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py) | [code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh) | [code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py) |
|
||||
| [meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video) | `longcat_video` | [code](/examples/wanvideo/model_inference/LongCat-Video.py) | [code](/examples/wanvideo/model_training/full/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py) | [code](/examples/wanvideo/model_training/lora/LongCat-Video.sh) | [code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py) |
|
||||
| [ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B) | `vap_video`, `vap_prompt` | [code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py) | [code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py) |
|
||||
| [Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B) | | [code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py) |
|
||||
| [Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py) |
|
||||
| [Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B) | `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py) |
|
||||
| [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) | `input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video` | [code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py) |
|
||||
| [Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B) | `input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video` | [code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py) |
|
||||
| [PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B) | `vace_control_video`, `vace_reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py) |
|
||||
| [PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP) | `input_image`, `end_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py) |
|
||||
| [PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control) | `control_video`, `reference_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py) |
|
||||
| [PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera) | `control_camera_video`, `input_image` | [code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py) | [code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh) | [code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py) |
|
||||
|
||||
* FP8 Precision Training: [doc](/docs/en/Training/FP8_Precision.md), [code](/examples/wanvideo/model_training/special/fp8_training/)
|
||||
* Two-stage Split Training: [doc](/docs/en/Training/Split_Training.md), [code](/examples/wanvideo/model_training/special/split_training/)
|
||||
* End-to-end Direct Distillation: [doc](/docs/en/Training/Direct_Distill.md), [code](/examples/wanvideo/model_training/special/direct_distill/)
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded via `WanVideoPipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
|
||||
|
||||
Input parameters for `WanVideoPipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the video.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the video, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 5. When set to 1, it no longer takes effect.
|
||||
* `input_image`: Input image for image-to-video generation, used in conjunction with `denoising_strength`.
|
||||
* `end_image`: End image for first-and-last frame video generation.
|
||||
* `input_video`: Input video for video-to-video generation, used in conjunction with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated video is similar to the input video; when the value approaches 1, the generated video differs more from the input video.
|
||||
* `control_video`: Control video for controlling the video generation process.
|
||||
* `reference_image`: Reference image for maintaining consistency of certain features in the generated video.
|
||||
* `camera_control_direction`: Camera control direction, optional values are `"Left"`, `"Right"`, `"Up"`, `"Down"`, `"LeftUp"`, `"LeftDown"`, `"RightUp"`, `"RightDown"`.
|
||||
* `camera_control_speed`: Camera control speed, default value is 1/54.
|
||||
* `vace_video`: VACE control video.
|
||||
* `vace_video_mask`: VACE control video mask.
|
||||
* `vace_reference_image`: VACE reference image.
|
||||
* `vace_scale`: VACE control strength, default value is 1.0.
|
||||
* `animate_pose_video`: `animate` model pose video.
|
||||
* `animate_face_video`: `animate` model face video.
|
||||
* `animate_inpaint_video`: `animate` model local editing video.
|
||||
* `animate_mask_video`: `animate` model mask video.
|
||||
* `vap_video`: `video-as-prompt` input video.
|
||||
* `vap_prompt`: `video-as-prompt` text description.
|
||||
* `negative_vap_prompt`: `video-as-prompt` negative text description.
|
||||
* `input_audio`: Input audio for speech-to-video generation.
|
||||
* `audio_embeds`: Audio embedding vectors.
|
||||
* `audio_sample_rate`: Audio sampling rate, default value is 16000.
|
||||
* `s2v_pose_video`: S2V model pose video.
|
||||
* `motion_video`: S2V model motion video.
|
||||
* `height`: Video height, must be a multiple of 16.
|
||||
* `width`: Video width, must be a multiple of 16.
|
||||
* `num_frames`: Number of video frames, default value is 81, must be a multiple of 4 + 1.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 50.
|
||||
* `motion_bucket_id`: Motion control parameter, the larger the value, the greater the motion amplitude.
|
||||
* `longcat_video`: LongCat input video.
|
||||
* `tiled`: Whether to enable VAE tiling inference, default is `True`. Setting to `True` can significantly reduce VRAM usage during VAE encoding/decoding stages, producing slight errors and slightly longer inference time.
|
||||
* `tile_size`: Tile size during VAE encoding/decoding stages, default is `(30, 52)`, only effective when `tiled=True`.
|
||||
* `tile_stride`: Tile stride during VAE encoding/decoding stages, default is `(15, 26)`, only effective when `tiled=True`, must be less than or equal to `tile_size`.
|
||||
* `switch_DiT_boundary`: Time boundary for switching DiT models, default value is 0.875.
|
||||
* `sigma_shift`: Timestep offset parameter, default value is 5.0.
|
||||
* `sliding_window_size`: Sliding window size.
|
||||
* `sliding_window_stride`: Sliding window stride.
|
||||
* `tea_cache_l1_thresh`: L1 threshold for TeaCache.
|
||||
* `tea_cache_model_id`: Model ID used by TeaCache.
|
||||
* `progress_bar_cmd`: Progress bar, default is `tqdm.tqdm`. Can be disabled by setting to `lambda x:x`.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||
|
||||
## Model Training
|
||||
|
||||
Wan series models are uniformly trained through [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py), and the script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Basic Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Metadata file path of the dataset.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* `--dataset_num_workers`: Number of processes for each DataLoader.
|
||||
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths of models to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`. Separated by commas.
|
||||
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.
|
||||
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
|
||||
* Training Basic Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
|
||||
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
|
||||
* Output Configuration
|
||||
* `--output_path`: Model saving path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
|
||||
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
|
||||
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
|
||||
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Video Width/Height Configuration
|
||||
* `--height`: Height of the video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--width`: Width of the video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area of video frames. When dynamic resolution is enabled, video frames with resolution larger than this value will be downscaled, and video frames with resolution smaller than this value will remain unchanged.
|
||||
* `--num_frames`: Number of frames in the video.
|
||||
* Wan Series Specific Parameters
|
||||
* `--tokenizer_path`: Path of the tokenizer, applicable to text-to-video models, leave blank to automatically download from remote.
|
||||
* `--audio_processor_path`: Path of the audio processor, applicable to speech-to-video models, leave blank to automatically download from remote.
|
||||
|
||||
We have built a sample video dataset for your testing. You can download this dataset with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
141
docs/en/Model_Details/Z-Image.md
Normal file
141
docs/en/Model_Details/Z-Image.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Z-Image
|
||||
|
||||
Z-Image is an image generation model trained and open-sourced by the Multimodal Interaction Team of Alibaba Tongyi Lab.
|
||||
|
||||
## Installation
|
||||
|
||||
Before using this project for model inference and training, please install DiffSynth-Studio first.
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
For more information about installation, please refer to [Install Dependencies](/docs/en/Pipeline_Usage/Setup.md).
|
||||
|
||||
## Quick Start
|
||||
|
||||
Run the following code to quickly load the [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) model and perform inference. FP8 precision quantization causes noticeable image quality degradation, so it is not recommended to enable any quantization on the Z-Image Turbo model. Only CPU Offload is recommended, minimum 8GB VRAM is required to run.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Model Overview
|
||||
|
||||
| Model ID | Inference | Low VRAM Inference | Full Training | Validation After Full Training | LoRA Training | Validation After LoRA Training |
|
||||
| - | - | - | - | - | - | - |
|
||||
| [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) | [code](/examples/z_image/model_inference/Z-Image-Turbo.py) | [code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/full/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py) | [code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh) | [code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py) |
|
||||
|
||||
Special Training Scripts:
|
||||
|
||||
* Differential LoRA Training: [doc](/docs/en/Training/Differential_LoRA.md), [code](/examples/z_image/model_training/special/differential_training/)
|
||||
* Trajectory Imitation Distillation Training (Experimental Feature): [code](/examples/z_image/model_training/special/trajectory_imitation/)
|
||||
|
||||
## Model Inference
|
||||
|
||||
Models are loaded via `ZImagePipeline.from_pretrained`, see [Loading Models](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models).
|
||||
|
||||
Input parameters for `ZImagePipeline` inference include:
|
||||
|
||||
* `prompt`: Prompt describing the content appearing in the image.
|
||||
* `negative_prompt`: Negative prompt describing content that should not appear in the image, default value is `""`.
|
||||
* `cfg_scale`: Classifier-free guidance parameter, default value is 1.
|
||||
* `input_image`: Input image for image-to-image generation, used in conjunction with `denoising_strength`.
|
||||
* `denoising_strength`: Denoising strength, range is 0~1, default value is 1. When the value approaches 0, the generated image is similar to the input image; when the value approaches 1, the generated image differs more from the input image. When `input_image` parameter is not provided, do not set this to a non-1 value.
|
||||
* `height`: Image height, must be a multiple of 16.
|
||||
* `width`: Image width, must be a multiple of 16.
|
||||
* `seed`: Random seed. Default is `None`, meaning completely random.
|
||||
* `rand_device`: Computing device for generating random Gaussian noise matrix, default is `"cpu"`. When set to `cuda`, different GPUs will produce different generation results.
|
||||
* `num_inference_steps`: Number of inference steps, default value is 8.
|
||||
|
||||
If VRAM is insufficient, please enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md). We provide recommended low VRAM configurations for each model in the example code, see the table in the "Model Overview" section above.
|
||||
|
||||
## Model Training
|
||||
|
||||
Z-Image series models are uniformly trained through [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py), and the script parameters include:
|
||||
|
||||
* General Training Parameters
|
||||
* Dataset Basic Configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Metadata file path of the dataset.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* `--dataset_num_workers`: Number of processes for each DataLoader.
|
||||
* `--data_file_keys`: Field names to be loaded from metadata, usually image or video file paths, separated by `,`.
|
||||
* Model Loading Configuration
|
||||
* `--model_paths`: Paths of models to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, e.g., `"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors"`. Separated by commas.
|
||||
* `--extra_inputs`: Extra input parameters required by the model Pipeline, e.g., extra parameters when training image editing models, separated by `,`.
|
||||
* `--fp8_models`: Models loaded in FP8 format, consistent with `--model_paths` or `--model_id_with_origin_paths` format. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
|
||||
* Training Basic Configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, e.g., `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
|
||||
* `--weight_decay`: Weight decay size, see [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html).
|
||||
* `--task`: Training task, default is `sft`. Some models support more training modes, please refer to the documentation of each specific model.
|
||||
* Output Configuration
|
||||
* `--output_path`: Model saving path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefix in the state dict of the model file.
|
||||
* `--save_steps`: Interval of training steps to save the model. If this parameter is left blank, the model is saved once per epoch.
|
||||
* LoRA Configuration
|
||||
* `--lora_base_model`: Which model to add LoRA to.
|
||||
* `--lora_target_modules`: Which layers to add LoRA to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path of the LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
|
||||
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
|
||||
* `--preset_lora_model`: Model that the preset LoRA is merged into, e.g., `dit`.
|
||||
* Gradient Configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Image Width/Height Configuration (Applicable to Image Generation and Video Generation Models)
|
||||
* `--height`: Height of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--width`: Width of image or video. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area of image or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be downscaled, and images with resolution smaller than this value will remain unchanged.
|
||||
* Z-Image Specific Parameters
|
||||
* `--tokenizer_path`: Path of the tokenizer, applicable to text-to-image models, leave blank to automatically download from remote.
|
||||
|
||||
We have built a sample image dataset for your testing. You can download this dataset with the following command:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
We have written recommended training scripts for each model, please refer to the table in the "Model Overview" section above. For how to write model training scripts, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md); for more advanced training algorithms, please refer to [Training Framework Detailed Explanation](/docs/Training/).
|
||||
|
||||
Training Tips:
|
||||
|
||||
* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) is a distilled acceleration model. Therefore, direct training will quickly cause the model to lose its acceleration capability. The effect of inference with "acceleration configuration" (`num_inference_steps=8`, `cfg_scale=1`) becomes worse, while the effect of inference with "no acceleration configuration" (`num_inference_steps=30`, `cfg_scale=2`) becomes better. The following training and inference schemes can be adopted:
|
||||
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + No Acceleration Configuration Inference
|
||||
* Differential LoRA Training ([code](/examples/z_image/model_training/special/differential_training/)) + Acceleration Configuration Inference
|
||||
* An additional LoRA needs to be loaded in differential LoRA training, e.g., [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
|
||||
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Trajectory Imitation Distillation Training ([code](/examples/z_image/model_training/special/trajectory_imitation/)) + Acceleration Configuration Inference
|
||||
* Standard SFT Training ([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + Load Distillation Acceleration LoRA During Inference ([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + Acceleration Configuration Inference
|
||||
39
docs/en/Pipeline_Usage/Environment_Variables.md
Normal file
39
docs/en/Pipeline_Usage/Environment_Variables.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# Environment Variables
|
||||
|
||||
`DiffSynth-Studio` can control some settings through environment variables.
|
||||
|
||||
In `Python` code, you can set environment variables using `os.environ`. Please note that environment variables must be set before `import diffsynth`.
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["DIFFSYNTH_MODEL_BASE_PATH"] = "./path_to_my_models"
|
||||
import diffsynth
|
||||
```
|
||||
|
||||
On Linux operating systems, you can also temporarily set environment variables from the command line:
|
||||
|
||||
```shell
|
||||
DIFFSYNTH_MODEL_BASE_PATH="./path_to_my_models" python xxx.py
|
||||
```
|
||||
|
||||
Below are the environment variables supported by `DiffSynth-Studio`.
|
||||
|
||||
## `DIFFSYNTH_SKIP_DOWNLOAD`
|
||||
|
||||
Whether to skip model downloads. Can be set to `True`, `true`, `False`, `false`. If `skip_download` is not set in `ModelConfig`, this environment variable will determine whether to skip model downloads.
|
||||
|
||||
## `DIFFSYNTH_MODEL_BASE_PATH`
|
||||
|
||||
Model download root directory. Can be set to any local path. If `local_model_path` is not set in `ModelConfig`, model files will be downloaded to the path pointed to by this environment variable. If neither is set, model files will be downloaded to `./models`.
|
||||
|
||||
## `DIFFSYNTH_ATTENTION_IMPLEMENTATION`
|
||||
|
||||
Attention mechanism implementation method. Can be set to `flash_attention_3`, `flash_attention_2`, `sage_attention`, `xformers`, or `torch`. See [`./core/attention.md`](/docs/en/API_Reference/core/attention.md) for details.
|
||||
|
||||
## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE`
|
||||
|
||||
Buffer size in disk mapping. Default is 1B (1000000000). Larger values occupy more memory but result in faster speeds.
|
||||
|
||||
## `DIFFSYNTH_DOWNLOAD_SOURCE`
|
||||
|
||||
Remote model download source. Can be set to `modelscope` or `huggingface` to control the source of model downloads. Default value is `modelscope`.
|
||||
105
docs/en/Pipeline_Usage/Model_Inference.md
Normal file
105
docs/en/Pipeline_Usage/Model_Inference.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# Model Inference
|
||||
|
||||
This document uses the Qwen-Image model as an example to introduce how to use `DiffSynth-Studio` for model inference.
|
||||
|
||||
## Loading Models
|
||||
|
||||
Models are loaded through `from_pretrained`:
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
```
|
||||
|
||||
Where `torch_dtype` and `device` are computation precision and computation device (not model precision and device). `model_configs` can be configured in multiple ways for model paths. For how models are loaded internally in this project, please refer to [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Download and load models from remote sources</summary>
|
||||
|
||||
> `DiffSynth-Studio` downloads and loads models from [ModelScope](https://www.modelscope.cn/) by default. You need to fill in `model_id` and `origin_file_pattern`, for example:
|
||||
>
|
||||
> ```python
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
> ```
|
||||
>
|
||||
> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Load models from local file paths</summary>
|
||||
|
||||
> Fill in `path`, for example:
|
||||
>
|
||||
> ```python
|
||||
> ModelConfig(path="models/xxx.safetensors")
|
||||
> ```
|
||||
>
|
||||
> For models loaded from multiple files, use a list, for example:
|
||||
>
|
||||
> ```python
|
||||
> ModelConfig(path=[
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
> ])
|
||||
> ```
|
||||
|
||||
</details>
|
||||
|
||||
By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
|
||||
|
||||
```shell
|
||||
import os
|
||||
os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True"
|
||||
import diffsynth
|
||||
```
|
||||
|
||||
To download models from [HuggingFace](https://huggingface.co/), set [environment variable DIFFSYNTH_DOWNLOAD_SOURCE](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) to `huggingface`.
|
||||
|
||||
```shell
|
||||
import os
|
||||
os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface"
|
||||
import diffsynth
|
||||
```
|
||||
|
||||
## Starting Inference
|
||||
|
||||
Input a prompt to start the inference process and generate an image.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
Each model `Pipeline` has different input parameters. Please refer to the documentation for each model.
|
||||
|
||||
If the model parameters are too large, causing insufficient VRAM, please enable [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md).
|
||||
247
docs/en/Pipeline_Usage/Model_Training.md
Normal file
247
docs/en/Pipeline_Usage/Model_Training.md
Normal file
@@ -0,0 +1,247 @@
|
||||
# Model Training
|
||||
|
||||
This document introduces how to use `DiffSynth-Studio` for model training.
|
||||
|
||||
## Script Parameters
|
||||
|
||||
Training scripts typically include the following parameters:
|
||||
|
||||
* Dataset base configuration
|
||||
* `--dataset_base_path`: Root directory of the dataset.
|
||||
* `--dataset_metadata_path`: Metadata file path of the dataset.
|
||||
* `--dataset_repeat`: Number of times the dataset is repeated in each epoch.
|
||||
* `--dataset_num_workers`: Number of processes for each Dataloader.
|
||||
* `--data_file_keys`: Field names that need to be loaded from metadata, usually image or video file paths, separated by `,`.
|
||||
* Model loading configuration
|
||||
* `--model_paths`: Paths of models to be loaded. JSON format.
|
||||
* `--model_id_with_origin_paths`: Model IDs with original paths, for example `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`. Separated by commas.
|
||||
* `--extra_inputs`: Extra input parameters required by the model Pipeline, for example, training image editing model Qwen-Image-Edit requires extra parameter `edit_image`, separated by `,`.
|
||||
* `--fp8_models`: Models loaded in FP8 format, consistent with the format of `--model_paths` or `--model_id_with_origin_paths`. Currently only supports models whose parameters are not updated by gradients (no gradient backpropagation, or gradients only update their LoRA).
|
||||
* Training base configuration
|
||||
* `--learning_rate`: Learning rate.
|
||||
* `--num_epochs`: Number of epochs.
|
||||
* `--trainable_models`: Trainable models, for example `dit`, `vae`, `text_encoder`.
|
||||
* `--find_unused_parameters`: Whether there are unused parameters in DDP training. Some models contain redundant parameters that do not participate in gradient calculation, and this setting needs to be enabled to avoid errors in multi-GPU training.
|
||||
* `--weight_decay`: Weight decay size. See [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html) for details.
|
||||
* `--task`: Training task, default is `sft`. Some models support more training modes. Please refer to the documentation for each specific model.
|
||||
* Output configuration
|
||||
* `--output_path`: Model save path.
|
||||
* `--remove_prefix_in_ckpt`: Remove prefixes in the state dict of model files.
|
||||
* `--save_steps`: Interval of training steps for saving models. If this parameter is left blank, the model will be saved once per epoch.
|
||||
* LoRA configuration
|
||||
* `--lora_base_model`: Which model LoRA is added to.
|
||||
* `--lora_target_modules`: Which layers LoRA is added to.
|
||||
* `--lora_rank`: Rank of LoRA.
|
||||
* `--lora_checkpoint`: Path of LoRA checkpoint. If this path is provided, LoRA will be loaded from this checkpoint.
|
||||
* `--preset_lora_path`: Preset LoRA checkpoint path. If this path is provided, this LoRA will be loaded in the form of being merged into the base model. This parameter is used for LoRA differential training.
|
||||
* `--preset_lora_model`: Model that preset LoRA is merged into, for example `dit`.
|
||||
* Gradient configuration
|
||||
* `--use_gradient_checkpointing`: Whether to enable gradient checkpointing.
|
||||
* `--use_gradient_checkpointing_offload`: Whether to offload gradient checkpointing to memory.
|
||||
* `--gradient_accumulation_steps`: Number of gradient accumulation steps.
|
||||
* Image dimension configuration (applicable to image generation models and video generation models)
|
||||
* `--height`: Height of images or videos. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--width`: Width of images or videos. Leave `height` and `width` blank to enable dynamic resolution.
|
||||
* `--max_pixels`: Maximum pixel area of images or video frames. When dynamic resolution is enabled, images with resolution larger than this value will be scaled down, and images with resolution smaller than this value will remain unchanged.
|
||||
|
||||
Some models' training scripts also contain additional parameters. See the documentation for each model for details.
|
||||
|
||||
## Preparing Datasets
|
||||
|
||||
`DiffSynth-Studio` adopts a universal dataset format. The dataset contains a series of data files (images, videos, etc.) and annotated metadata files. We recommend organizing dataset files as follows:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
├── metadata.csv
|
||||
├── image_1.jpg
|
||||
└── image_2.jpg
|
||||
```
|
||||
|
||||
Where `image_1.jpg`, `image_2.jpg` are training image data, and `metadata.csv` is the metadata list, for example:
|
||||
|
||||
```
|
||||
image,prompt
|
||||
image_1.jpg,"a dog"
|
||||
image_2.jpg,"a cat"
|
||||
```
|
||||
|
||||
We have built sample datasets for your testing. To understand how the universal dataset architecture is implemented, please refer to [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md).
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Sample Image Dataset</summary>
|
||||
|
||||
> ```shell
|
||||
> modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
> ```
|
||||
>
|
||||
> Applicable to training of image generation models such as Qwen-Image and FLUX.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Sample Video Dataset</summary>
|
||||
|
||||
> ```shell
|
||||
> modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
|
||||
> ```
|
||||
>
|
||||
> Applicable to training of video generation models such as Wan.
|
||||
|
||||
</details>
|
||||
|
||||
## Loading Models
|
||||
|
||||
Similar to [model loading during inference](/docs/en/Pipeline_Usage/Model_Inference.md#loading-models), we support multiple ways to configure model paths, and the two methods can be mixed.
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Download and load models from remote sources</summary>
|
||||
|
||||
> If we load models during inference through the following settings:
|
||||
>
|
||||
> ```python
|
||||
> model_configs=[
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
> ]
|
||||
> ```
|
||||
>
|
||||
> Then during training, fill in the following parameters to load the corresponding models:
|
||||
>
|
||||
> ```shell
|
||||
> --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors"
|
||||
> ```
|
||||
>
|
||||
> Model files are downloaded to the `./models` path by default, which can be modified through [environment variable DIFFSYNTH_MODEL_BASE_PATH](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path).
|
||||
>
|
||||
> By default, even after models have been downloaded, the program will still query remotely for missing files. To completely disable remote requests, set [environment variable DIFFSYNTH_SKIP_DOWNLOAD](/docs/en/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) to `True`.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Load models from local file paths</summary>
|
||||
|
||||
> If loading models from local files during inference, for example:
|
||||
>
|
||||
> ```python
|
||||
> model_configs=[
|
||||
> ModelConfig([
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
||||
> ]),
|
||||
> ModelConfig([
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
> ]),
|
||||
> ModelConfig("models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors")
|
||||
> ]
|
||||
> ```
|
||||
>
|
||||
> Then during training, set to:
|
||||
>
|
||||
> ```shell
|
||||
> --model_paths '[
|
||||
> [
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
||||
> ],
|
||||
> [
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
> ],
|
||||
> "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors"
|
||||
> ]' \
|
||||
> ```
|
||||
>
|
||||
> Note that `--model_paths` is in JSON format, and extra `,` cannot appear in it, otherwise it cannot be parsed normally.
|
||||
|
||||
</details>
|
||||
|
||||
## Setting Trainable Modules
|
||||
|
||||
The training framework supports training of any model. Taking Qwen-Image as an example, to fully train the DiT model, set to:
|
||||
|
||||
```shell
|
||||
--trainable_models "dit"
|
||||
```
|
||||
|
||||
To train LoRA of the DiT model, set to:
|
||||
|
||||
```shell
|
||||
--lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32
|
||||
```
|
||||
|
||||
We hope to leave enough room for technical exploration, so the framework supports training any number of modules simultaneously. For example, to train the text encoder, controlnet, and LoRA of the DiT simultaneously:
|
||||
|
||||
```shell
|
||||
--trainable_models "text_encoder,controlnet" --lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32
|
||||
```
|
||||
|
||||
Additionally, since the training script loads multiple modules (text encoder, dit, vae, etc.), prefixes need to be removed when saving model files. For example, when fully training the DiT part or training the LoRA model of the DiT part, please set `--remove_prefix_in_ckpt pipe.dit.`. If multiple modules are trained simultaneously, developers need to write code to split the state dict in the model file after training is completed.
|
||||
|
||||
## Starting the Training Program
|
||||
|
||||
The training framework is built on [`accelerate`](https://huggingface.co/docs/accelerate/index). Training commands are written in the following format:
|
||||
|
||||
```shell
|
||||
accelerate launch xxx/train.py \
|
||||
--xxx yyy \
|
||||
--xxxx yyyy
|
||||
```
|
||||
|
||||
We have written preset training scripts for each model. See the documentation for each model for details.
|
||||
|
||||
By default, `accelerate` will train according to the configuration in `~/.cache/huggingface/accelerate/default_config.yaml`. Use `accelerate config` to configure interactively in the terminal, including multi-GPU training, [`DeepSpeed`](https://www.deepspeed.ai/), etc.
|
||||
|
||||
We provide recommended `accelerate` configuration files for some models, which can be set through `--config_file`. For example, full training of the Qwen-Image model:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
```
|
||||
|
||||
## Training Considerations
|
||||
|
||||
* In addition to the `csv` format, dataset metadata also supports `json` and `jsonl` formats. For how to choose the best metadata format, please refer to [/docs/en/API_Reference/core/data.md#metadata](/docs/en/API_Reference/core/data.md#metadata)
|
||||
* Training effectiveness is usually strongly correlated with training steps and weakly correlated with epoch count. Therefore, we recommend using the `--save_steps` parameter to save model files at training step intervals.
|
||||
* When data volume * `dataset_repeat` exceeds $10^9$, we observed that the dataset speed becomes significantly slower, which seems to be a `PyTorch` bug. We are not sure if newer versions of `PyTorch` have fixed this issue.
|
||||
* For learning rate `--learning_rate`, it is recommended to set to `1e-4` in LoRA training and `1e-5` in full training.
|
||||
* The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](/docs/en/QA.md#why-doesnt-the-training-framework-support-batch-size--1)
|
||||
* Some models contain redundant parameters. For example, the text encoding part of the last layer of Qwen-Image's DiT part. When training these models, `--find_unused_parameters` needs to be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.
|
||||
* The loss function value of Diffusion models has little relationship with actual effects. Therefore, we do not record loss function values during training. We recommend setting `--num_epochs` to a sufficiently large value, testing while training, and manually closing the training program after the effect converges.
|
||||
* `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md) for details.
|
||||
21
docs/en/Pipeline_Usage/Setup.md
Normal file
21
docs/en/Pipeline_Usage/Setup.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# Installing Dependencies
|
||||
|
||||
Install from source (recommended):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Install from PyPI (there may be delays in version updates; for latest features, install from source):
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
If you encounter issues during installation, they may be caused by upstream dependency packages. Please refer to the documentation for these packages:
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
206
docs/en/Pipeline_Usage/VRAM_management.md
Normal file
206
docs/en/Pipeline_Usage/VRAM_management.md
Normal file
@@ -0,0 +1,206 @@
|
||||
# VRAM Management
|
||||
|
||||
VRAM management is a distinctive feature of `DiffSynth-Studio` that enables GPUs with low VRAM to run inference with large parameter models. This document uses Qwen-Image as an example to introduce how to use the VRAM management solution.
|
||||
|
||||
## Basic Inference
|
||||
|
||||
The following code does not enable any VRAM management, occupying 56G VRAM as a reference.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## CPU Offload
|
||||
|
||||
Since the model `Pipeline` consists of multiple components that are not called simultaneously, we can move some components to memory when they are not needed for computation, reducing VRAM usage. The following code implements this logic, occupying 40G VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## FP8 Quantization
|
||||
|
||||
Building upon CPU Offload, we further enable FP8 quantization to reduce VRAM requirements. The following code allows model parameters to be stored in VRAM with FP8 precision and temporarily converted to BF16 precision for computation during inference, occupying 21G VRAM. However, this quantization scheme has minor image quality degradation issues.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
> Q: Why temporarily convert to BF16 precision during inference instead of computing with FP8 precision?
|
||||
>
|
||||
> A: Native FP8 computation is only supported on Hopper architecture GPUs (such as H20) and has significant computational errors. We currently do not enable FP8 precision computation. The current FP8 quantization only reduces VRAM usage but does not improve computation speed.
|
||||
|
||||
## Dynamic VRAM Management
|
||||
|
||||
In CPU Offload, we control model components. In fact, we support Layer-level Offload, splitting a model into multiple Layers, keeping some resident in VRAM and storing others in memory for on-demand transfer to VRAM for computation. This feature requires model developers to provide detailed VRAM management solutions for each model. Related configurations are in `diffsynth/configs/vram_management_module_maps.py`.
|
||||
|
||||
By adding the `vram_limit` parameter to the `Pipeline`, the framework can automatically sense the remaining VRAM of the device and decide how to split the model between VRAM and memory. The smaller the `vram_limit`, the less VRAM occupied, but slower the speed.
|
||||
* When `vram_limit=None`, the default state, the framework assumes unlimited VRAM and dynamic VRAM management is disabled
|
||||
* When `vram_limit=10`, the framework will limit the model after VRAM usage exceeds 10G, moving the excess parts to memory storage
|
||||
* When `vram_limit=0`, the framework will do its best to reduce VRAM usage, storing all model parameters in memory and transferring them to VRAM for computation only when necessary
|
||||
|
||||
When VRAM is insufficient to run model inference, the framework will attempt to exceed the `vram_limit` restriction to keep the model inference running. Therefore, the VRAM management framework cannot always guarantee that VRAM usage will be less than `vram_limit`. We recommend setting it to slightly less than the actual available VRAM. For example, when GPU VRAM is 16G, set it to `vram_limit=15.5`. In `PyTorch`, you can use `torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)` to get the GPU's VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Disk Offload
|
||||
|
||||
In more extreme cases, when memory is also insufficient to store the entire model, the Disk Offload feature allows lazy loading of model parameters, meaning each Layer of the model only reads the corresponding parameters from disk when the forward function is called. When enabling this feature, we recommend using high-speed SSD drives.
|
||||
|
||||
Disk Offload is a very special VRAM management solution that only supports `.safetensors` format files, not `.bin`, `.pth`, `.ckpt`, or other binary files, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=10,
|
||||
)
|
||||
prompt = "Exquisite portrait, underwater girl, blue dress flowing, hair floating, translucent light, bubbles surrounding, peaceful face, intricate details, dreamy and ethereal."
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## More Usage Methods
|
||||
|
||||
Information in `vram_config` can be filled in manually, for example, Disk Offload without FP8 quantization:
|
||||
|
||||
```python
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
```
|
||||
|
||||
Specifically, the VRAM management module divides model Layers into the following four states:
|
||||
|
||||
* Offload: This model will not be called in the short term. This state is controlled by switching `Pipeline`
|
||||
* Onload: This model will be called at any time soon. This state is controlled by switching `Pipeline`
|
||||
* Preparing: Intermediate state between Onload and Computation. A temporary storage state when VRAM allows. This state is controlled by the VRAM management mechanism and enters this state if and only if [vram_limit is set to unlimited] or [vram_limit is set and there is spare VRAM]
|
||||
* Computation: The model is being computed. This state is controlled by the VRAM management mechanism and is temporarily entered only during `forward`
|
||||
|
||||
If you are a model developer and want to control the VRAM management granularity of a specific model, please refer to [../Developer_Guide/Enabling_VRAM_management.md](/docs/en/Developer_Guide/Enabling_VRAM_management.md).
|
||||
|
||||
## Best Practices
|
||||
|
||||
* Sufficient VRAM -> Use [Basic Inference](#basic-inference)
|
||||
* Insufficient VRAM
|
||||
* Sufficient memory -> Use [Dynamic VRAM Management](#dynamic-vram-management)
|
||||
* Insufficient memory -> Use [Disk Offload](#disk-offload)
|
||||
28
docs/en/QA.md
Normal file
28
docs/en/QA.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# Frequently Asked Questions
|
||||
|
||||
## Why doesn't the training framework support batch size > 1?
|
||||
|
||||
* **Larger batch sizes no longer achieve significant acceleration**: Due to acceleration technologies such as flash attention that have fully improved GPU utilization, larger batch sizes will only bring greater VRAM usage without significant acceleration. The experience with small models like Stable Diffusion 1.5 is no longer applicable to the latest large models.
|
||||
* **Larger batch sizes can be achieved through other solutions**: Multi-GPU training and Gradient Accumulation can both mathematically equivalently achieve larger batch sizes.
|
||||
* **Larger batch sizes contradict the framework's general design**: We hope to build a general training framework. Many models cannot accommodate larger batch sizes, such as text encodings of different lengths and images of different resolutions, which cannot be merged into larger batches.
|
||||
|
||||
## Why aren't redundant parameters removed from certain models?
|
||||
|
||||
In some models, redundant parameters exist. For example, in Qwen-Image's DiT model, the text portion of the last layer does not participate in any calculations. This is a minor bug left by the model developers. Setting it as trainable directly will also cause errors in multi-GPU training.
|
||||
|
||||
To maintain compatibility with other models in the open-source community, we have decided to retain these parameters. These redundant parameters can avoid errors in multi-GPU training through the `--find_unused_parameters` parameter.
|
||||
|
||||
## Why does FP8 quantization show no acceleration effect?
|
||||
|
||||
Native FP8 computation relies on Hopper architecture GPUs and has significant precision errors. It is currently immature technology, so this project does not support native FP8 computation.
|
||||
|
||||
FP8 computation in VRAM management refers to storing model parameters in memory or VRAM with FP8 precision and temporarily converting them to other precisions when needed for computation. Therefore, it can only reduce VRAM usage without acceleration effects.
|
||||
|
||||
## Why doesn't the training framework support native FP8 precision training?
|
||||
|
||||
Even with suitable hardware conditions, we currently have no plans to support native FP8 precision training.
|
||||
|
||||
* The main challenge of native FP8 precision training is precision overflow caused by gradient explosion. To ensure training stability, the model structure needs to be redesigned accordingly. However, no model developers are willing to do so at present.
|
||||
* Additionally, models trained with native FP8 precision can only be computed with BF16 precision during inference without Hopper architecture GPUs, theoretically resulting in generation quality inferior to FP8.
|
||||
|
||||
Therefore, native FP8 precision training technology is extremely immature. We will observe the technological developments in the open-source community.
|
||||
88
docs/en/README.md
Normal file
88
docs/en/README.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# DiffSynth-Studio Documentation
|
||||
|
||||
Welcome to the magical world of Diffusion models! `DiffSynth-Studio` is an open-source Diffusion model engine developed and maintained by the [ModelScope Community](https://www.modelscope.cn/). We aim to build a universal Diffusion model framework that fosters technological innovation through framework construction, aggregates the power of the open-source community, and explores the boundaries of generative model technology!
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Documentation Reading Guide</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
I_want_to_use_models_for_inference_and_training-->sec1[Section 1: Getting Started];
|
||||
I_want_to_use_models_for_inference_and_training-->sec2[Section 2: Model Details];
|
||||
I_want_to_use_models_for_inference_and_training-->sec3[Section 3: Training Framework];
|
||||
I_want_to_develop_based_on_this_framework-->sec3[Section 3: Training Framework];
|
||||
I_want_to_develop_based_on_this_framework-->sec4[Section 4: Model Integration];
|
||||
I_want_to_develop_based_on_this_framework-->sec5[Section 5: API Reference];
|
||||
I_want_to_explore_new_technologies_based_on_this_project-->sec4[Section 4: Model Integration];
|
||||
I_want_to_explore_new_technologies_based_on_this_project-->sec5[Section 5: API Reference];
|
||||
I_want_to_explore_new_technologies_based_on_this_project-->sec6[Section 6: Academic Guide];
|
||||
I_encountered_a_problem-->sec7[Section 7: Frequently Asked Questions];
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Section 1: Getting Started
|
||||
|
||||
This section introduces the basic usage of `DiffSynth-Studio`, including how to enable VRAM management for inference on GPUs with extremely low VRAM, and how to train various base models, LoRAs, ControlNets, and other models.
|
||||
|
||||
* [Installation Dependencies](/docs/en/Pipeline_Usage/Setup.md)
|
||||
* [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md)
|
||||
* [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md)
|
||||
* [Model Training](/docs/en/Pipeline_Usage/Model_Training.md)
|
||||
* [Environment Variables](/docs/en/Pipeline_Usage/Environment_Variables.md)
|
||||
|
||||
## Section 2: Model Details
|
||||
|
||||
This section introduces the Diffusion models supported by `DiffSynth-Studio`. Some model pipelines feature special functionalities such as controllable generation and parallel acceleration.
|
||||
|
||||
* [FLUX.1](/docs/en/Model_Details/FLUX.md)
|
||||
* [Wan](/docs/en/Model_Details/Wan.md)
|
||||
* [Qwen-Image](/docs/en/Model_Details/Qwen-Image.md)
|
||||
* [FLUX.2](/docs/en/Model_Details/FLUX2.md)
|
||||
* [Z-Image](/docs/en/Model_Details/Z-Image.md)
|
||||
|
||||
## Section 3: Training Framework
|
||||
|
||||
This section introduces the design philosophy of the training framework in `DiffSynth-Studio`, helping developers understand the principles of Diffusion model training algorithms.
|
||||
|
||||
* [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md)
|
||||
* [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md)
|
||||
* [Enabling FP8 Precision in Training](/docs/en/Training/FP8_Precision.md)
|
||||
* [End-to-End Distillation Accelerated Training](/docs/en/Training/Direct_Distill.md)
|
||||
* [Two-Stage Split Training](/docs/en/Training/Split_Training.md)
|
||||
* [Differential LoRA Training](/docs/en/Training/Differential_LoRA.md)
|
||||
|
||||
## Section 4: Model Integration
|
||||
|
||||
This section introduces how to integrate models into `DiffSynth-Studio` to utilize the framework's basic functions, helping developers provide support for new models in this project or perform inference and training of private models.
|
||||
|
||||
* [Integrating Model Architecture](/docs/en/Developer_Guide/Integrating_Your_Model.md)
|
||||
* [Building a Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md)
|
||||
* [Enabling Fine-Grained VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md)
|
||||
* [Model Training Integration](/docs/en/Developer_Guide/Training_Diffusion_Models.md)
|
||||
|
||||
## Section 5: API Reference
|
||||
|
||||
This section introduces the independent core module `diffsynth.core` in `DiffSynth-Studio`, explaining how internal functions are designed and operate. Developers can use these functional modules in other codebase developments if needed.
|
||||
|
||||
* [`diffsynth.core.attention`](/docs/en/API_Reference/core/attention.md): Attention mechanism implementation
|
||||
* [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md): Data processing operators and general datasets
|
||||
* [`diffsynth.core.gradient`](/docs/en/API_Reference/core/gradient.md): Gradient checkpointing
|
||||
* [`diffsynth.core.loader`](/docs/en/API_Reference/core/loader.md): Model download and loading
|
||||
* [`diffsynth.core.vram`](/docs/en/API_Reference/core/vram.md): VRAM management
|
||||
|
||||
## Section 6: Academic Guide
|
||||
|
||||
This section introduces how to use `DiffSynth-Studio` to train new models, helping researchers explore new model technologies.
|
||||
|
||||
* Training models from scratch 【coming soon】
|
||||
* Inference improvement techniques 【coming soon】
|
||||
* Designing controllable generation models 【coming soon】
|
||||
* Creating new training paradigms 【coming soon】
|
||||
|
||||
## Section 7: Frequently Asked Questions
|
||||
|
||||
This section summarizes common developer questions. If you encounter issues during usage or development, please refer to this section. If you still cannot resolve the problem, please submit an issue on GitHub.
|
||||
|
||||
* [Frequently Asked Questions](/docs/en/QA.md)
|
||||
38
docs/en/Training/Differential_LoRA.md
Normal file
38
docs/en/Training/Differential_LoRA.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Differential LoRA Training
|
||||
|
||||
Differential LoRA training is a special form of LoRA training designed to enable models to learn differences between images.
|
||||
|
||||
## Training Approach
|
||||
|
||||
We were unable to identify the original proposer of differential LoRA training, as this technique has been circulating in the open-source community for a long time.
|
||||
|
||||
Assume we have two similar-content images: Image 1 and Image 2. For example, both images contain a car, but Image 1 has fewer details while Image 2 has more details. In differential LoRA training, we perform two-step training:
|
||||
|
||||
* Train LoRA 1 using Image 1 as training data with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md)
|
||||
* Train LoRA 2 using Image 2 as training data, after integrating LoRA 1 into the base model, with [standard supervised training](/docs/en/Training/Supervised_Fine_Tuning.md)
|
||||
|
||||
In the first training step, since there is only one training image, the LoRA model easily overfits. Therefore, after training, LoRA 1 will cause the model to generate Image 1 without hesitation, regardless of the random seed. In the second training step, the LoRA model overfits again. Thus, after training, with the combined effect of LoRA 1 and LoRA 2, the model will generate Image 2 without hesitation. In short:
|
||||
|
||||
* LoRA 1 = Generate Image 1
|
||||
* LoRA 1 + LoRA 2 = Generate Image 2
|
||||
|
||||
At this point, discarding LoRA 1 and using only LoRA 2, the model will understand the difference between Image 1 and Image 2, making the generated content tend toward "less like Image 1, more like Image 2."
|
||||
|
||||
Single training data can ensure the model overfits to the training data, but lacks stability. To improve stability, we can train with multiple image pairs and average the trained LoRA 2 models to obtain a more stable LoRA.
|
||||
|
||||
Using this training approach, some functionally unique LoRA models can be trained. For example, using ugly and beautiful image pairs to train LoRAs that enhance image aesthetics; using low-detail and high-detail image pairs to train LoRAs that increase image detail.
|
||||
|
||||
## Model Effects
|
||||
|
||||
We have trained several aesthetic enhancement LoRAs using differential LoRA training techniques. You can visit the corresponding model pages to view the generation effects.
|
||||
|
||||
* [DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1)
|
||||
* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)
|
||||
|
||||
## Using Differential LoRA Training in the Training Framework
|
||||
|
||||
The first step of training is identical to ordinary LoRA training. In the second step's training command, fill in the path of the first step's LoRA model file through the `--preset_lora_path` parameter, and set `--preset_lora_model` to the same parameters as `lora_base_model` to load LoRA 1 into the base model.
|
||||
|
||||
## Framework Design Concept
|
||||
|
||||
In the training framework, the model pointed to by `--preset_lora_path` is loaded in the `switch_pipe_to_training_mode` of `DiffusionTrainingModule`.
|
||||
97
docs/en/Training/Direct_Distill.md
Normal file
97
docs/en/Training/Direct_Distill.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# End-to-End Distillation Accelerated Training
|
||||
|
||||
## Distillation Accelerated Training
|
||||
|
||||
The inference process of Diffusion models typically requires multi-step iterations, which improves generation quality but also makes the generation process slow. Through distillation accelerated training, the number of steps required to generate clear content can be reduced. The essence of distillation accelerated training technology is to align the generation effects of a small number of steps with those of a large number of steps.
|
||||
|
||||
There are diverse methods for distillation accelerated training, such as:
|
||||
|
||||
* Adversarial training ADD (Adversarial Diffusion Distillation)
|
||||
* Paper: https://arxiv.org/abs/2311.17042
|
||||
* Model: [stabilityai/sdxl-turbo](https://modelscope.cn/models/stabilityai/sdxl-turbo)
|
||||
* Progressive training Hyper-SD
|
||||
* Paper: https://arxiv.org/abs/2404.13686
|
||||
* Model: [ByteDance/Hyper-SD](https://www.modelscope.cn/models/ByteDance/Hyper-SD)
|
||||
|
||||
## Direct Distillation
|
||||
|
||||
At the framework level, supporting these distillation accelerated training schemes is extremely difficult. In the design of the training framework, we need to ensure that the training scheme meets the following conditions:
|
||||
|
||||
* Generality: The training scheme applies to most Diffusion models supported within the framework, rather than only working for a specific model, which is a basic requirement for code framework construction.
|
||||
* Stability: The training scheme must ensure stable training effects without requiring manual fine-tuning of parameters. Adversarial training in ADD cannot guarantee stability.
|
||||
* Simplicity: The training scheme does not introduce additional complex modules. According to Occam's Razor principle, complex solutions may introduce potential risks. The Human Feedback Learning in Hyper-SD makes the training process overly complex.
|
||||
|
||||
Therefore, in the training framework of `DiffSynth-Studio`, we designed an end-to-end distillation accelerated training scheme, which we call Direct Distillation. The pseudocode for the training process is as follows:
|
||||
|
||||
```
|
||||
seed = xxx
|
||||
with torch.no_grad():
|
||||
image_1 = pipe(prompt, steps=50, seed=seed, cfg=4)
|
||||
image_2 = pipe(prompt, steps=4, seed=seed, cfg=1)
|
||||
loss = torch.nn.functional.mse_loss(image_1, image_2)
|
||||
```
|
||||
|
||||
Yes, it's a very end-to-end training scheme that produces immediate results with minimal training.
|
||||
|
||||
## Models Trained with Direct Distillation
|
||||
|
||||
We trained two models based on Qwen-Image using this scheme:
|
||||
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): Full distillation training
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA distillation training
|
||||
|
||||
Click on the model links to go to the model pages and view the model effects.
|
||||
|
||||
## Using Distillation Accelerated Training in the Training Framework
|
||||
|
||||
First, you need to generate training data. Please refer to the [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) section to write inference code and generate training data with a sufficient number of inference steps.
|
||||
|
||||
Taking Qwen-Image as an example, the following code can generate an image:
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
Then, we compile the necessary information into [metadata files](/docs/en/API_Reference/core/data.md#metadata):
|
||||
|
||||
```csv
|
||||
image,prompt,seed,rand_device,num_inference_steps,cfg_scale
|
||||
distill_qwen/image.jpg,"精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。",0,cpu,4,1
|
||||
```
|
||||
|
||||
This sample dataset can be downloaded directly:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
Then start LoRA distillation accelerated training:
|
||||
|
||||
```shell
|
||||
bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh
|
||||
```
|
||||
|
||||
Please note that in the [training script parameters](/docs/en/Pipeline_Usage/Model_Training.md#script-parameters), the image resolution setting for the dataset should avoid triggering scaling processing. When setting `--height` and `--width` to enable fixed resolution, all training data must be generated with exactly the same width and height. When setting `--max_pixels` to enable dynamic resolution, the value of `--max_pixels` must be greater than or equal to the pixel area of any training image.
|
||||
|
||||
## Framework Design Concept
|
||||
|
||||
Compared to [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md), Direct Distillation only differs in the training loss function. The loss function for Direct Distillation is `DirectDistillLoss` in `diffsynth.diffusion.loss`.
|
||||
|
||||
## Future Work
|
||||
|
||||
Direct Distillation is a highly general acceleration scheme, but it may not be the best-performing scheme. Therefore, we have not yet published this technology in paper form. We hope to leave this problem to the academic and open-source communities to solve together, and we look forward to developers providing more complete general training schemes.
|
||||
20
docs/en/Training/FP8_Precision.md
Normal file
20
docs/en/Training/FP8_Precision.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# Enabling FP8 Precision in Training
|
||||
|
||||
Although `DiffSynth-Studio` supports [VRAM management](/docs/en/Pipeline_Usage/VRAM_management.md) in model inference, most of the techniques for reducing VRAM usage are not suitable for training. Offloading would cause extremely slow training processes.
|
||||
|
||||
FP8 precision is the only VRAM management strategy that can be enabled during training. However, this framework currently does not support native FP8 precision training. For reasons, see [Q&A: Why doesn't the training framework support native FP8 precision training?](/docs/en/QA.md#why-doesnt-the-training-framework-support-native-fp8-precision-training). It only supports storing models whose parameters are not updated by gradients (models that do not require gradient backpropagation, or whose gradients only update their LoRA) in FP8 precision.
|
||||
|
||||
## Enabling FP8
|
||||
|
||||
In our provided training scripts, you can quickly set models to be stored in FP8 precision through the `--fp8_models` parameter. Taking Qwen-Image LoRA training as an example, we provide a script for enabling FP8 training located at [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh). After training is completed, you can verify the training results with the script [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py).
|
||||
|
||||
Please note that this FP8 VRAM management strategy does not support gradient updates. When a model is set to be trainable, FP8 precision cannot be enabled for that model. Models that support FP8 include two types:
|
||||
|
||||
* Parameters are not trainable, such as VAE models
|
||||
* Gradients do not update their parameters, such as DiT models in LoRA training
|
||||
|
||||
Experimental verification shows that LoRA training with FP8 enabled does not cause significant image quality degradation. However, theoretical errors do exist. If you encounter training results inferior to BF16 precision training when using this feature, please provide feedback through GitHub issues.
|
||||
|
||||
## Training Framework Design Concept
|
||||
|
||||
The training framework completely reuses the inference VRAM management, and only parses VRAM management configurations through `parse_model_configs` in `DiffusionTrainingModule` during training.
|
||||
97
docs/en/Training/Split_Training.md
Normal file
97
docs/en/Training/Split_Training.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# Two-Stage Split Training
|
||||
|
||||
This document introduces split training, which can automatically divide the training process into two stages, reducing VRAM usage while accelerating training speed.
|
||||
|
||||
(Split training is an experimental feature that has not yet undergone large-scale validation. If you encounter any issues while using it, please submit an issue on GitHub.)
|
||||
|
||||
## Split Training
|
||||
|
||||
In the training process of most models, a large amount of computation occurs in "preprocessing," i.e., "computations unrelated to the denoising model," including VAE encoding, text encoding, etc. When the corresponding model parameters are fixed, the results of these computations are repetitive. For each data sample, the computational results are identical across multiple epochs. Therefore, we provide a "split training" feature that can automatically analyze and split the training process.
|
||||
|
||||
For standard supervised training of ordinary text-to-image models, the splitting process is straightforward. It only requires splitting the computation of all [`Pipeline Units`](/docs/en/Developer_Guide/Building_a_Pipeline.md#units) into the first stage, storing the computational results to disk, and then reading these results from disk in the second stage for subsequent computations. However, if gradient backpropagation is required during preprocessing, the situation becomes extremely complex. To address this, we introduced a computational graph splitting algorithm to analyze how to split the computation.
|
||||
|
||||
## Computational Graph Splitting Algorithm
|
||||
|
||||
> (We will supplement the detailed specifics of the computational graph splitting algorithm in future document updates)
|
||||
|
||||
## Using Split Training
|
||||
|
||||
Split training already supports [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md) and [Direct Distillation Training](/docs/en/Training/Direct_Distill.md). The `--task` parameter in the training command controls this. Taking LoRA training of the Qwen-Image model as an example, the pre-split training command is:
|
||||
|
||||
```shell
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
```
|
||||
|
||||
After splitting, in the first stage, make the following modifications:
|
||||
|
||||
* Change `--dataset_repeat` to 1 to avoid redundant computation
|
||||
* Change `--output_path` to the path where the first-stage computation results are saved
|
||||
* Add the additional parameter `--task "sft:data_process"`
|
||||
* Remove the DiT model from `--model_id_with_origin_paths`
|
||||
|
||||
```shell
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 1 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-LoRA-splited-cache" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--task "sft:data_process"
|
||||
```
|
||||
|
||||
In the second stage, make the following modifications:
|
||||
|
||||
* Change `--dataset_base_path` to the `--output_path` of the first stage
|
||||
* Remove `--dataset_metadata_path`
|
||||
* Add the additional parameter `--task "sft:train"`
|
||||
* Remove the Text Encoder and VAE models from `--model_id_with_origin_paths`
|
||||
|
||||
```shell
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-LoRA-splited" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--task "sft:train"
|
||||
```
|
||||
|
||||
We provide sample training scripts and validation scripts located at `examples/qwen_image/model_training/special/split_training`.
|
||||
|
||||
## Training Framework Design Concept
|
||||
|
||||
The training framework splits the computational units in the `Pipeline` through the `split_pipeline_units` method of `DiffusionTrainingModule`.
|
||||
129
docs/en/Training/Supervised_Fine_Tuning.md
Normal file
129
docs/en/Training/Supervised_Fine_Tuning.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# Standard Supervised Training
|
||||
|
||||
After understanding the [Basic Principles of Diffusion Models](/docs/en/Training/Understanding_Diffusion_models.md), this document introduces how the framework implements Diffusion model training. This document explains the framework's principles to help developers write new training code. If you want to use our provided default training functions, please refer to [Model Training](/docs/en/Pipeline_Usage/Model_Training.md).
|
||||
|
||||
Recalling the model training pseudocode from earlier, when we actually write code, the situation becomes extremely complex. Some models require additional guidance conditions and preprocessing, such as ControlNet; some models require cross-computation with the denoising model, such as VACE; some models require Gradient Checkpointing due to excessive VRAM demands, such as Qwen-Image's DiT.
|
||||
|
||||
To achieve strict consistency between inference and training, we abstractly encapsulate components like `Pipeline`, reusing inference code extensively during training. Please refer to [Integrating Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md) to understand the design of `Pipeline` components. Next, we'll introduce how the training framework utilizes `Pipeline` components to build training algorithms.
|
||||
|
||||
## Framework Design Concept
|
||||
|
||||
The training module is encapsulated on top of the `Pipeline`, inheriting `DiffusionTrainingModule` from `diffsynth.diffusion.training_module`. We need to provide the necessary `__init__` and `forward` methods for the training module. Taking Qwen-Image's LoRA training as an example, we provide a simple script containing only basic training functions in `examples/qwen_image/model_training/special/simple/train.py` to help developers understand the design concept of the training module.
|
||||
|
||||
```python
|
||||
class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
# Initialize models here.
|
||||
pass
|
||||
|
||||
def forward(self, data):
|
||||
# Compute loss here.
|
||||
return loss
|
||||
```
|
||||
|
||||
### `__init__`
|
||||
|
||||
In `__init__`, model initialization is required. First load the model, then switch it to training mode.
|
||||
|
||||
```python
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
# Load the pipeline
|
||||
self.pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
# Switch to training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe,
|
||||
lora_base_model="dit",
|
||||
lora_target_modules="to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj",
|
||||
lora_rank=32,
|
||||
)
|
||||
```
|
||||
|
||||
The logic for loading models is basically consistent with inference, supporting loading models from remote and local paths. See [Model Inference](/docs/en/Pipeline_Usage/Model_Inference.md) for details, but please note not to enable [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
|
||||
|
||||
`switch_pipe_to_training_mode` can switch the model to training mode. See `switch_pipe_to_training_mode` for details.
|
||||
|
||||
### `forward`
|
||||
|
||||
In `forward`, the loss function value needs to be calculated. First perform preprocessing, then compute the loss function through the `Pipeline`'s [`model_fn`](/docs/en/Developer_Guide/Building_a_Pipeline.md#model_fn).
|
||||
|
||||
```python
|
||||
def forward(self, data):
|
||||
# Preprocess
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
# Assume you are using this pipeline for inference,
|
||||
# please fill in the input parameters.
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": True,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
# Loss
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
```
|
||||
|
||||
The preprocessing process is consistent with the inference phase. Developers only need to assume they are using the `Pipeline` for inference and fill in the input parameters.
|
||||
|
||||
The loss function calculation reuses `FlowMatchSFTLoss` from `diffsynth.diffusion.loss`.
|
||||
|
||||
### Starting Training
|
||||
|
||||
The training framework requires other modules, including:
|
||||
|
||||
* accelerator: Training launcher provided by `accelerate`, see [`accelerate`](https://huggingface.co/docs/accelerate/index) for details
|
||||
* dataset: Generic dataset, see [`diffsynth.core.data`](/docs/en/API_Reference/core/data.md) for details
|
||||
* model_logger: Model logger, see `diffsynth.diffusion.logger` for details
|
||||
|
||||
```python
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/example_image_dataset",
|
||||
metadata_path="data/example_image_dataset/metadata.csv",
|
||||
repeat=50,
|
||||
data_file_keys="image",
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path="data/example_image_dataset",
|
||||
height=512,
|
||||
width=512,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
)
|
||||
)
|
||||
model = QwenImageTrainingModule(accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
output_path="models/toy_model",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=1e-5, num_epochs=1,
|
||||
)
|
||||
```
|
||||
|
||||
Assembling all the above code results in `examples/qwen_image/model_training/special/simple/train.py`. Use the following command to start training:
|
||||
|
||||
```
|
||||
accelerate launch examples/qwen_image/model_training/special/simple/train.py
|
||||
```
|
||||
145
docs/en/Training/Understanding_Diffusion_models.md
Normal file
145
docs/en/Training/Understanding_Diffusion_models.md
Normal file
@@ -0,0 +1,145 @@
|
||||
# Basic Principles of Diffusion Models
|
||||
|
||||
This document introduces the basic principles of Diffusion models to help you understand how the training framework is constructed. To make these complex mathematical theories easier for readers to understand, we have reconstructed the theoretical framework of Diffusion models, abandoning complex stochastic differential equations and presenting them in a more concise and understandable form.
|
||||
|
||||
## Introduction
|
||||
|
||||
Diffusion models generate clear images or video content through iterative denoising. We start by explaining the generation process of a data sample $x_0$. Intuitively, in a complete round of denoising, we start from random Gaussian noise $x_T$ and iteratively obtain $x_{T-1}$, $x_{T-2}$, $x_{T-3}$, $\cdots$, gradually reducing the noise content at each step until we finally obtain the noise-free data sample $x_0$.
|
||||
|
||||
(Figure)
|
||||
|
||||
This process is intuitive, but to understand the details, we need to answer several questions:
|
||||
|
||||
* How is the noise content at each step defined?
|
||||
* How is the iterative denoising computation performed?
|
||||
* How to train such Diffusion models?
|
||||
* What is the architecture of modern Diffusion models?
|
||||
* How does this project encapsulate and implement model training?
|
||||
|
||||
## How is the noise content at each step defined?
|
||||
|
||||
In the theoretical system of Diffusion models, the noise content is determined by a series of parameters $\sigma_T$, $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma_0$. Where:
|
||||
|
||||
* $\sigma_T=1$, corresponding to $x_T$ as pure Gaussian noise
|
||||
* $\sigma_T>\sigma_{T-1}>\sigma_{T-2}>\cdots>x_0$, the noise content gradually decreases during iteration
|
||||
* $\sigma_0=0$, corresponding to $x_0$ as a data sample without any noise
|
||||
|
||||
As for the intermediate values $\sigma_{T-1}$, $\sigma_{T-2}$, $\cdots$, $\sigma_1$, they are not fixed and only need to satisfy the decreasing condition.
|
||||
|
||||
At an intermediate step, we can directly synthesize noisy data samples $x_t=(1-\sigma_t)x_0+\sigma_t x_T$.
|
||||
|
||||
(Figure)
|
||||
|
||||
## How is the iterative denoising computation performed?
|
||||
|
||||
Before understanding the iterative denoising computation, we need to clarify what the input and output of the denoising model are. We abstract the model as a symbol $\hat \epsilon$, whose input typically consists of three parts:
|
||||
|
||||
* Time step $t$, the model needs to understand which stage of the denoising process it is currently in
|
||||
* Noisy data sample $x_t$, the model needs to understand what data to denoise
|
||||
* Guidance condition $c$, the model needs to understand what kind of data sample to generate through denoising
|
||||
|
||||
Among these, the guidance condition $c$ is a newly introduced parameter that is input by the user. It can be text describing the image content or a sketch outlining the image structure.
|
||||
|
||||
(Figure)
|
||||
|
||||
The model's output $\hat \epsilon(x_t,c,t)$ approximately equals $x_T-x_0$, which is the direction of the entire diffusion process (the reverse process of denoising).
|
||||
|
||||
Next, we analyze the computation occurring in one iteration. At time step $t$, after the model computes an approximation of $x_T-x_0$, we calculate the next $x_{t-1}$:
|
||||
|
||||
$$
|
||||
\begin{aligned}
|
||||
x_{t-1}&=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)\\
|
||||
&\approx x_t + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\
|
||||
&=(1-\sigma_t)x_0+\sigma_t x_T + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\
|
||||
&=(1-\sigma_{t-1})x_0+\sigma_{t-1}x_T
|
||||
\end{aligned}
|
||||
$$
|
||||
|
||||
Perfect! It perfectly matches the noise content definition at time step $t-1$.
|
||||
|
||||
> (This part might be a bit difficult to understand. Don't worry; it's recommended to skip this part on first reading without affecting the rest of the document.)
|
||||
>
|
||||
> After completing this somewhat complex formula derivation, let's consider a question: why should the model's output approximately equal $x_T-x_0$? Can it be set to other values?
|
||||
>
|
||||
> Actually, Diffusion models rely on two definitions to form a complete theory. From the above formulas, we can extract these two definitions and derive the iterative formula:
|
||||
>
|
||||
> * Data definition: $x_t=(1-\sigma_t)x_0+\sigma_t x_T$
|
||||
> * Model definition: $\hat \epsilon(x_t,c,t)=x_T-x_0$
|
||||
> * Derived iterative formula: $x_{t-1}=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)$
|
||||
>
|
||||
> These three mathematical formulas are complete. For example, in the previous derivation, substituting the data definition and model definition into the iterative formula yields $x_{t-1}$ that matches the data definition.
|
||||
>
|
||||
> These are two definitions built on Flow Matching theory, but Diffusion models can also be implemented with other definitions. For example, early models based on DDPM (Denoising Diffusion Probabilistic Models) have their two definitions and derived iterative formulas as:
|
||||
>
|
||||
> * Data definition: $x_t=\sqrt{\alpha_t}x_0+\sqrt{1-\alpha_t}x_T$
|
||||
> * Model definition: $\hat \epsilon(x_t,c,t)=x_T$
|
||||
> * Derived iterative formula: $x_{t-1}=\sqrt{\alpha_{t-1}}\left(\frac{x_t-\sqrt{1-\alpha_t}\hat \epsilon(x_t,c,t)}{\sqrt{\sigma_t}}\right)+\sqrt{1-\alpha_{t-1}}\hat \epsilon(x_t,c,t)$
|
||||
>
|
||||
> More generally, we describe the derivation process of the iterative formula using matrices. For any data definition and model definition:
|
||||
>
|
||||
> * Data definition: $x_t=C_T(x_0,x_T)^T$
|
||||
> * Model definition: $\hat \epsilon(x_t,c,t)=C_T^{[\epsilon]}(x_0,x_T)^T$
|
||||
> * Derived iterative formula: $x_{t-1}=C_{t-1}(C_t,C_t^{[\epsilon]})^{-T}(x_t,\hat \epsilon(x_t,c,t))^T$
|
||||
>
|
||||
> Where $C_t$ and $C_t^{[\epsilon]}$ are $1\times 2$ coefficient matrices. It's not difficult to see that when constructing the two definitions, the matrix $(C_t,C_t^{[\epsilon]})^T$ must be invertible.
|
||||
>
|
||||
> Although Flow Matching and DDPM have been widely verified by numerous pre-trained models, this doesn't mean they are optimal solutions. We encourage developers to design new Diffusion model theories for better training results.
|
||||
|
||||
## How to train such Diffusion models?
|
||||
|
||||
After understanding the iterative denoising process, we next consider how to train such Diffusion models.
|
||||
|
||||
The training process differs from the generation process. If we retain multi-step iterations during training, the gradient would need to backpropagate through multiple steps, bringing catastrophic time and space complexity. To improve computational efficiency, we randomly select a time step $t$ for training.
|
||||
|
||||
(Figure)
|
||||
|
||||
The following is pseudocode for the training process:
|
||||
|
||||
> Obtain data sample $x_0$ and guidance condition $c$ from the dataset
|
||||
>
|
||||
> Randomly sample time step $t\in(0,T]$
|
||||
>
|
||||
> Randomly sample Gaussian noise $x_T\in \mathcal N(O,I)$
|
||||
>
|
||||
> $x_t=(1-\sigma_t)x_0+\sigma_t x_T$
|
||||
>
|
||||
> $\hat \epsilon(x_t,c,t)$
|
||||
>
|
||||
> Loss function $\mathcal L=||\hat \epsilon(x_t,c,t)-(x_T-x_0)||_2^2$
|
||||
>
|
||||
> Backpropagate gradients and update model parameters
|
||||
|
||||
## What is the architecture of modern Diffusion models?
|
||||
|
||||
From theory to practice, more details need to be filled in. Modern Diffusion model architectures have matured, with mainstream architectures following the "three-stage" architecture proposed by Latent Diffusion, including data encoder-decoder, guidance condition encoder, and denoising model.
|
||||
|
||||
(Figure)
|
||||
|
||||
### Data Encoder-Decoder
|
||||
|
||||
In the previous text, we consistently referred to $x_0$ as a "data sample" rather than an image or video because modern Diffusion models typically don't process images or videos directly. Instead, they use an Encoder-Decoder architecture model, usually a VAE (Variational Auto-Encoders) model, to encode images or videos into Embedding tensors, obtaining $x_0$.
|
||||
|
||||
After data is encoded by the encoder and then decoded by the decoder, the reconstructed content is approximately consistent with the original, with minor errors. So why process on the encoded Embedding tensor instead of directly on images or videos? The main reasons are twofold:
|
||||
|
||||
* Encoding compresses the data simultaneously, reducing computational load during processing.
|
||||
* Encoded data distribution is more similar to Gaussian distribution, making it easier for denoising models to model the data.
|
||||
|
||||
During generation, the encoder part doesn't participate in computation. After iteration completes, the decoder part decodes $x_0$ to obtain clear images or videos. During training, the decoder part doesn't participate in computation; only the encoder is used to compute $x_0$.
|
||||
|
||||
### Guidance Condition Encoder
|
||||
|
||||
User-input guidance conditions $c$ can be complex and diverse, requiring specialized encoder models to process them into Embedding tensors. According to the type of guidance condition, we classify guidance condition encoders into the following categories:
|
||||
|
||||
* Text type, such as CLIP, Qwen-VL
|
||||
* Image type, such as ControlNet, IP-Adapter
|
||||
* Video type, such as VAE
|
||||
|
||||
> The model $\hat \epsilon$ mentioned in the previous text refers to the entirety of all guidance condition encoders and the denoising model. We list guidance condition encoders separately because these models are typically frozen during Diffusion training, and their output values are independent of time step $t$, allowing guidance condition encoder computations to be performed offline.
|
||||
|
||||
### Denoising Model
|
||||
|
||||
The denoising model is the true essence of Diffusion models, with diverse model structures such as UNet and DiT. Model developers can freely innovate on these structures.
|
||||
|
||||
## How does this project encapsulate and implement model training?
|
||||
|
||||
Please read the next document: [Standard Supervised Training](/docs/en/Training/Supervised_Fine_Tuning.md)
|
||||
79
docs/zh/API_Reference/core/attention.md
Normal file
79
docs/zh/API_Reference/core/attention.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# `diffsynth.core.attention`: 注意力机制实现
|
||||
|
||||
`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。
|
||||
|
||||
## 注意力机制
|
||||
|
||||
注意力机制是在论文[《Attention Is All You Need》](https://arxiv.org/abs/1706.03762)中提出的模型结构,在原论文中,注意力机制按照如下公式实现:
|
||||
|
||||
$$
|
||||
\text{Attention}(Q, K, V) = \text{Softmax}\left(
|
||||
\frac{QK^T}{\sqrt{d_k}}
|
||||
\right)
|
||||
V.
|
||||
$$
|
||||
|
||||
在 `PyTorch` 中,可以用如下代码实现:
|
||||
```python
|
||||
import torch
|
||||
|
||||
def attention(query, key, value):
|
||||
scale_factor = 1 / query.size(-1)**0.5
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
return attn_weight @ value
|
||||
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
output_1 = attention(query, key, value)
|
||||
```
|
||||
|
||||
其中 `query`、`key`、`value` 的维度是 $(b, n, s, d)$:
|
||||
* $b$:Batch size
|
||||
* $n$: Attention head 的数量
|
||||
* $s$: 序列长度
|
||||
* $d$: 每个 Attention head 的维数
|
||||
|
||||
这部分计算是不包含任何可训练参数的,现代 transformer 架构的模型会在进行这一计算前后经过 Linear 层,本文讨论的“注意力机制”不包含这些计算,仅包含以上代码的计算。
|
||||
|
||||
## 更高效的实现
|
||||
|
||||
注意到,注意力机制中 Attention Score(公式中的 $\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$,代码中的 `attn_weight`)的维度为 $(b, n, s, s)$,其中序列长度 $s$ 通常非常大,这导致计算的时间和空间复杂度达到平方级。以图像生成模型为例,图像的宽度和高度每增加到 2 倍,序列长度增加到 4 倍,计算量和显存需求增加到 16 倍。为了避免高昂的计算成本,需采用更高效的注意力机制实现,包括
|
||||
* Flash Attention 3:[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2407.08608)
|
||||
* Flash Attention 2:[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2307.08691)
|
||||
* Sage Attention:[GitHub](https://github.com/thu-ml/SageAttention)、[论文](https://arxiv.org/abs/2505.11594)
|
||||
* xFormers:[GitHub](https://github.com/facebookresearch/xformers)、[文档](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)
|
||||
* PyTorch:[GitHub](https://github.com/pytorch/pytorch)、[文档](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。
|
||||
|
||||
```python
|
||||
from diffsynth.core.attention import attention_forward
|
||||
import torch
|
||||
|
||||
def attention(query, key, value):
|
||||
scale_factor = 1 / query.size(-1)**0.5
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
return attn_weight @ value
|
||||
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
output_1 = attention(query, key, value)
|
||||
output_2 = attention_forward(query, key, value)
|
||||
print((output_1 - output_2).abs().mean())
|
||||
```
|
||||
|
||||
请注意,加速的同时会引入误差,但在大多数情况下误差是可以忽略不计的。
|
||||
|
||||
## 开发者导引
|
||||
|
||||
在为 `DiffSynth-Studio` 接入新模型时,开发者可自行决定是否调用 `diffsynth.core.attention` 中的 `attention_forward`,但我们期望模型能够尽可能优先调用这一模块,以便让新的注意力机制实现能够在这些模型上直接生效。
|
||||
|
||||
## 最佳实践
|
||||
|
||||
**在大多数情况下,我们建议直接使用 `PyTorch` 原生的实现,无需安装任何额外的包。** 虽然其他注意力机制实现可以加速,但加速效果是较为有限的,在少数情况下会出现兼容性和精度不足的问题。
|
||||
|
||||
此外,高效的注意力机制实现会逐步集成到 `PyTorch` 中,`PyTorch` 的 `2.9.0` 版本中的 `scaled_dot_product_attention` 已经集成了 Flash Attention 2。我们仍在 `DiffSynth-Studio` 提供这一接口,是为了让一些激进的加速方案能够快速走向应用,尽管它们在稳定性上还需要时间验证。
|
||||
151
docs/zh/API_Reference/core/data.md
Normal file
151
docs/zh/API_Reference/core/data.md
Normal file
@@ -0,0 +1,151 @@
|
||||
# `diffsynth.core.data`: 数据处理算子与通用数据集
|
||||
|
||||
## 数据处理算子
|
||||
|
||||
### 可用数据处理算子
|
||||
|
||||
`diffsynth.core.data` 提供了一系列数据处理算子,用于进行数据处理,包括:
|
||||
|
||||
* 数据格式转换算子
|
||||
* `ToInt`: 转换为 int 格式
|
||||
* `ToFloat`: 转换为 float 格式
|
||||
* `ToStr`: 转换为 str 格式
|
||||
* `ToList`: 转换为列表格式,以列表包裹此数据
|
||||
* `ToAbsolutePath`: 将相对路径转换为绝对路径
|
||||
* 文件加载算子
|
||||
* `LoadImage`: 读取图片文件
|
||||
* `LoadVideo`: 读取视频文件
|
||||
* `LoadAudio`: 读取音频文件
|
||||
* `LoadGIF`: 读取 GIF 文件
|
||||
* `LoadTorchPickle`: 读取由 [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) 保存的二进制文件【该算子可能导致二进制文件中的代码注入攻击,请谨慎使用!】
|
||||
* 媒体文件处理算子
|
||||
* `ImageCropAndResize`: 对图像进行裁剪和拉伸
|
||||
* Meta 算子
|
||||
* `SequencialProcess`: 将序列中的每个数据路由到一个算子
|
||||
* `RouteByExtensionName`: 按照文件扩展名路由到特定算子
|
||||
* `RouteByType`: 按照数据类型路由到特定算子
|
||||
|
||||
### 算子使用
|
||||
|
||||
数据算子之间以 `>>` 符号连接形成数据处理流水线,例如:
|
||||
|
||||
```python
|
||||
from diffsynth.core.data.operators import *
|
||||
|
||||
data = "image.jpg"
|
||||
data_pipeline = ToAbsolutePath(base_path="/data") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512)
|
||||
data = data_pipeline(data)
|
||||
```
|
||||
|
||||
在经过每个算子后,数据被依次处理
|
||||
|
||||
* `ToAbsolutePath(base_path="/data")`: `"/data/image.jpg"`
|
||||
* `LoadImage()`: `<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7F8E7AAEFC10>`
|
||||
* `ImageCropAndResize(max_pixels=512*512)`: `<PIL.Image.Image image mode=RGB size=512x512 at 0x7F8E7A936F20>`
|
||||
|
||||
我们可以组合出功能完备的数据流水线,例如通用数据集的默认视频数据算子为
|
||||
|
||||
```python
|
||||
RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||
(("gif",), LoadGIF(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
])),
|
||||
])
|
||||
```
|
||||
|
||||
它包含如下逻辑:
|
||||
|
||||
* 如果是 `str` 类型的数据
|
||||
* 如果是 `"jpg", "jpeg", "png", "webp"` 类型文件
|
||||
* 加载这张图片
|
||||
* 裁剪并缩放到特定分辨率
|
||||
* 打包进列表,视为单帧视频
|
||||
* 如果是 `"gif"` 类型文件
|
||||
* 加载 gif 文件内容
|
||||
* 将每一帧裁剪和缩放到特定分辨率
|
||||
* 如果是 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"` 类型文件
|
||||
* 加载 gif 文件内容
|
||||
* 将每一帧裁剪和缩放到特定分辨率
|
||||
* 如果不是 `str` 类型的数据,报错
|
||||
|
||||
## 通用数据集
|
||||
|
||||
`diffsynth.core.data` 提供了统一的数据集实现,数据集需输入以下参数:
|
||||
|
||||
* `base_path`: 根目录,若数据集中包含图片文件的相对路径,则需填入此字段用于加载这些路径指向的文件
|
||||
* `metadata_path`: 元数据目录,记录所有元数据的文件路径,支持 `csv`、`json`、`jsonl` 格式
|
||||
* `repeat`: 数据重复次数,默认为 1,该参数影响一个 epoch 的训练步数
|
||||
* `data_file_keys`: 需进行加载的数据字段名,例如 `(image, edit_image)`
|
||||
* `main_data_operator`: 主加载算子,需通过数据处理算子组装好数据处理流水线
|
||||
* `special_operator_map`: 特殊算子映射,对需要特殊处理的字段构建的算子映射
|
||||
|
||||
### 元数据
|
||||
|
||||
数据集的 `metadata_path` 指向元数据文件,支持 `csv`、`json`、`jsonl` 格式,以下提供了样例
|
||||
|
||||
* `csv` 格式:可读性高、不支持列表数据、内存占用小
|
||||
|
||||
```csv
|
||||
image,prompt
|
||||
image_1.jpg,"a dog"
|
||||
image_2.jpg,"a cat"
|
||||
```
|
||||
|
||||
* `json` 格式:可读性高、支持列表数据、内存占用大
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"image": "image_1.jpg",
|
||||
"prompt": "a dog"
|
||||
},
|
||||
{
|
||||
"image": "image_2.jpg",
|
||||
"prompt": "a cat"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
* `jsonl` 格式:可读性低、支持列表数据、内存占用小
|
||||
|
||||
```json
|
||||
{"image": "image_1.jpg", "prompt": "a dog"}
|
||||
{"image": "image_2.jpg", "prompt": "a cat"}
|
||||
```
|
||||
|
||||
如何选择最佳的元数据格式?
|
||||
|
||||
* 如果数据量大,达到千万级的数据量,由于 `json` 文件解析时需要额外内存,此时不可用,请使用 `csv` 或 `jsonl` 格式
|
||||
* 如果数据集中包含列表数据,例如编辑模型需输入多张图,由于 `csv` 格式无法存储列表格式数据,此时不可用,请使用 `json` 或 `jsonl` 格式
|
||||
|
||||
### 数据加载逻辑
|
||||
|
||||
在没有进行额外设置时,数据集默认输出元数据集中的数据,图片和视频文件的路径会以字符串的格式输出,若要加载这些文件,则需要设置 `data_file_keys`、`main_data_operator`、`special_operator_map`。
|
||||
|
||||
在数据处理流程中,按如下逻辑进行处理:
|
||||
* 如果字段位于 `special_operator_map`,则调用 `special_operator_map` 中的对应算子进行处理
|
||||
* 如果字段不位于 `special_operator_map`
|
||||
* 如果字段位于 `data_file_keys`,则调用 `main_data_operator` 算子进行处理
|
||||
* 如果字段不位于 `data_file_keys`,则不进行处理
|
||||
|
||||
`special_operator_map` 可用于实现特殊的数据处理,例如模型 [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) 中输入的人物面部视频 `animate_face_video` 是以固定分辨率处理的,与输出视频不一致,因此这一字段由专门的算子处理:
|
||||
|
||||
```python
|
||||
special_operator_map={
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
|
||||
}
|
||||
```
|
||||
|
||||
### 其他注意事项
|
||||
|
||||
当数据量过少时,可适当增加 `repeat`,延长单个 epoch 的训练时间,避免频繁保存模型产生较多耗时。
|
||||
|
||||
当数据量 * `repeat` 超过 $10^9$ 时,我们观测到数据集的速度明显变慢,这似乎是 `PyTorch` 的 bug,我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。
|
||||
69
docs/zh/API_Reference/core/gradient.md
Normal file
69
docs/zh/API_Reference/core/gradient.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# `diffsynth.core.gradient`: 梯度检查点及其 Offload
|
||||
|
||||
`diffsynth.core.gradient` 中提供了封装好的梯度检查点及其 Offload 版本,用于模型训练。
|
||||
|
||||
## 梯度检查点
|
||||
|
||||
梯度检查点是用于减少训练时显存占用的技术。我们提供一个例子来帮助你理解这一技术,以下是一个简单的模型结构
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(x)
|
||||
|
||||
model = ToyModel()
|
||||
x = torch.randn((2, 3))
|
||||
y = model(x)
|
||||
```
|
||||
|
||||
在这个模型结构中,输入的参数 $x$ 经过 Sigmoid 激活函数得到输出值 $y=\frac{1}{1+e^{-x}}$。
|
||||
|
||||
在训练过程中,假定我们的损失函数值为 $\mathcal L$,在梯度反响传播时,我们得到 $\frac{\partial \mathcal L}{\partial y}$,此时我们需计算 $\frac{\partial \mathcal L}{\partial x}$,不难发现 $\frac{\partial y}{\partial x}=y(1-y)$,进而有 $\frac{\partial \mathcal L}{\partial x}=\frac{\partial \mathcal L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial \mathcal L}{\partial y}y(1-y)$。如果在模型前向传播时保存 $y$ 的数值,并在梯度反向传播时直接计算 $y(1-y)$,这将避免复杂的 exp 计算,加快计算速度,但这会导致我们需要额外的显存来存储中间变量 $y$。
|
||||
|
||||
不启用梯度检查点时,训练框架会默认存储所有辅助梯度计算的中间变量,从而达到最佳的计算速度。开启梯度检查点时,中间变量则不会存储,但输入参数 $x$ 仍会存储,减少显存占用,在梯度反向传播时需重新计算这些变量,减慢计算速度。
|
||||
|
||||
## 启用梯度检查点及其 Offload
|
||||
|
||||
`diffsynth.core.gradient` 中的 `gradient_checkpoint_forward` 实现了梯度检查点及其 Offload,可参考以下代码调用:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.core.gradient import gradient_checkpoint_forward
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(x)
|
||||
|
||||
model = ToyModel()
|
||||
x = torch.randn((2, 3))
|
||||
y = gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
x=x,
|
||||
)
|
||||
```
|
||||
|
||||
* 当 `use_gradient_checkpointing=False` 且 `use_gradient_checkpointing_offload=False` 时,计算过程与原始计算完全相同,不影响模型的推理和训练,你可以直接将其集成到代码中。
|
||||
* 当 `use_gradient_checkpointing=True` 且 `use_gradient_checkpointing_offload=False` 时,启用梯度检查点。
|
||||
* 当 `use_gradient_checkpointing_offload=True` 时,启用梯度检查点,所有梯度检查点的输入参数存储在内存中,进一步降低显存占用和减慢计算速度。
|
||||
|
||||
## 最佳实践
|
||||
|
||||
> Q: 应当在何处启用梯度检查点?
|
||||
>
|
||||
> A: 对整个模型启用梯度检查点时,计算效率和显存占用并不是最优的,我们需要设置细粒度的梯度检查点,但同时不希望为框架增加过多繁杂的代码。因此我们建议在 `Pipeline` 的 `model_fn` 中实现,例如 `diffsynth/pipelines/qwen_image.py` 中的 `model_fn_qwen_image`,在 Block 层级启用梯度检查点,不需要修改模型结构的任何代码。
|
||||
|
||||
> Q: 什么情况下需要启用梯度检查点?
|
||||
>
|
||||
> A: 随着模型参数量越来越大,梯度检查点已成为必要的训练技术,梯度检查点通常是需要启用的。梯度检查点的 Offload 则仅需在激活值占用显存过大的模型(例如视频生成模型)中启用。
|
||||
141
docs/zh/API_Reference/core/loader.md
Normal file
141
docs/zh/API_Reference/core/loader.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# `diffsynth.core.loader`: 模型下载与加载
|
||||
|
||||
本文档介绍 `diffsynth.core.loader` 中模型下载与加载相关的功能。
|
||||
|
||||
## ModelConfig
|
||||
|
||||
`diffsynth.core.loader` 中的 `ModelConfig` 用于标注模型下载来源、本地路径、显存管理配置等信息。
|
||||
|
||||
### 从远程下载并加载模型
|
||||
|
||||
以模型[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 为例,在 `ModelConfig` 中填写 `model_id` 和 `origin_file_pattern` 后即可自动下载模型。默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。
|
||||
|
||||
默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(
|
||||
model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny",
|
||||
origin_file_pattern="model.safetensors",
|
||||
)
|
||||
# Download models
|
||||
config.download_if_necessary()
|
||||
print(config.path)
|
||||
```
|
||||
|
||||
调用 `download_if_necessary` 后,模型会自动下载,并将路径返回到 `config.path` 中。
|
||||
|
||||
### 从本地路径加载模型
|
||||
|
||||
如果从本地路径加载模型,则需要填入 `path`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(path="models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")
|
||||
```
|
||||
|
||||
如果模型包含多个分片文件,以列表的形式输入即可:
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(path=[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
])
|
||||
```
|
||||
|
||||
### 显存管理配置
|
||||
|
||||
`ModelConfig` 也包含了显存管理配置信息,详见[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)。
|
||||
|
||||
## 模型文件加载
|
||||
|
||||
`diffsynth.core.loader` 提供了统一的 `load_state_dict`,用于加载模型文件中的 state dict。
|
||||
|
||||
加载单个模型文件:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
state_dict = load_state_dict("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")
|
||||
```
|
||||
|
||||
加载多个模型文件(合并为一个 state dict):
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
state_dict = load_state_dict([
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
])
|
||||
```
|
||||
|
||||
## 模型哈希
|
||||
|
||||
模型哈希是用于判断模型类型的,哈希值可通过 `hash_model_file` 获取:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file
|
||||
|
||||
print(hash_model_file("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"))
|
||||
```
|
||||
|
||||
也可计算多个模型文件的哈希值,等价于合并 state dict 后计算模型哈希值:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file
|
||||
|
||||
print(hash_model_file([
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
]))
|
||||
```
|
||||
|
||||
模型哈希值只与模型文件中 state dict 的 keys 和 tensor shape 有关,与模型参数的数值、文件保存时间等信息无关。在计算 `.safetensors` 格式文件的模型哈希值时,`hash_model_file` 是几乎瞬间完成的,无需读取模型的参数;但在计算 `.bin`、`.pth`、`.ckpt` 等二进制文件的模型哈希值时,则需要读取全部模型参数,因此**我们不建议开发者继续使用这些格式的文件。**
|
||||
|
||||
通过[编写模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config)并将模型哈希值等信息填入 `diffsynth/configs/model_configs.py`,开发者可以让 `DiffSynth-Studio` 自动识别模型类型并加载。
|
||||
|
||||
## 模型加载
|
||||
|
||||
`load_model` 是 `diffsynth.core.loader` 中加载模型的外部入口,它会调用 [skip_model_initialization](/docs/zh/API_Reference/core/vram.md#跳过模型参数初始化) 跳过模型参数初始化。如果启用了 [Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload),则调用 [DiskMap](/docs/zh/API_Reference/core/vram.md#state-dict-硬盘映射) 进行惰性加载;如果没有启用 Disk Offload,则调用 [load_state_dict](#模型文件加载) 加载模型参数。如果需要的话,还会调用 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) 进行模型格式转换。最后调用 `model.eval()` 将其切换到推理模式。
|
||||
|
||||
以下是一个启用了 Disk Offload 的使用案例:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
```
|
||||
66
docs/zh/API_Reference/core/vram.md
Normal file
66
docs/zh/API_Reference/core/vram.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# `diffsynth.core.vram`: 显存管理
|
||||
|
||||
本文档介绍 `diffsynth.core.vram` 中的显存管理底层功能,如果你希望将这些功能用于其他的代码库中,可参考本文档。
|
||||
|
||||
## 跳过模型参数初始化
|
||||
|
||||
在 `PyTorch` 中加载模型时,模型的参数默认会占用显存或内存并进行参数初始化,而这些参数会在加载预训练权重后被覆盖掉,这导致了冗余的计算。`PyTorch` 中没有提供接口来跳过这些冗余的计算,我们在 `diffsynth.core.vram` 中提供了 `skip_model_initialization` 用于跳过模型参数初始化。
|
||||
|
||||
默认的模型加载方式:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
model = QwenImageBlockWiseControlNet() # Slow
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu")
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
```
|
||||
|
||||
跳过参数初始化的模型加载方式:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
with skip_model_initialization():
|
||||
model = QwenImageBlockWiseControlNet() # Fast
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu")
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
```
|
||||
|
||||
在 `DiffSynth-Studio` 中,所有预训练模型都遵循这一加载逻辑。开发者在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)完毕后即可直接以这种方式快速加载模型。
|
||||
|
||||
## State Dict 硬盘映射
|
||||
|
||||
对于某个模型的预训练权重文件,如果我们只需要读取其中的一组参数,而非全部参数,State Dict 硬盘映射可以加速这一过程。我们在 `diffsynth.core.vram` 中提供了 `DiskMap` 用于按需加载模型参数。
|
||||
|
||||
默认的权重加载方式:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu") # Slow
|
||||
print(state_dict["img_in.weight"])
|
||||
```
|
||||
|
||||
使用 `DiskMap` 只加载特定参数:
|
||||
|
||||
```python
|
||||
from diffsynth.core import DiskMap
|
||||
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = DiskMap(path, device="cpu") # Fast
|
||||
print(state_dict["img_in.weight"])
|
||||
```
|
||||
|
||||
`DiskMap` 是 `DiffSynth-Studio` 中 Disk Offload 的基本组件,开发者在[配置细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)后即可直接启用 Disk Offload。
|
||||
|
||||
`DiskMap` 是利用 `.safetensors` 文件的特性实现的功能,因此在使用 `.bin`、`.pth`、`.ckpt` 等二进制文件时,模型的参数是全量加载的,这也导致 Disk Offload 不支持这些格式的文件。**我们不建议开发者继续使用这些格式的文件。**
|
||||
|
||||
## 显存管理可替换模块
|
||||
|
||||
在启用 `DiffSynth-Studio` 的显存管理后,模型内部的模块会被替换为 `diffsynth.core.vram.layers` 中的可替换模块,其使用方式详见[细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md#编写细粒度显存管理方案)。
|
||||
250
docs/zh/Developer_Guide/Building_a_Pipeline.md
Normal file
250
docs/zh/Developer_Guide/Building_a_Pipeline.md
Normal file
@@ -0,0 +1,250 @@
|
||||
# 接入 Pipeline
|
||||
|
||||
在[将 Pipeline 所需的模型接入](/docs/zh/Developer_Guide/Integrating_Your_Model.md)之后,还需构建 `Pipeline` 用于模型推理,本文档提供 `Pipeline` 构建的标准化流程,开发者也可参考现有的 `Pipeline` 进行构建。
|
||||
|
||||
`Pipeline` 的实现位于 `diffsynth/pipelines`,每个 `Pipeline` 包含以下必要的关键组件:
|
||||
|
||||
* `__init__`
|
||||
* `from_pretrained`
|
||||
* `__call__`
|
||||
* `units`
|
||||
* `model_fn`
|
||||
|
||||
## `__init__`
|
||||
|
||||
在 `__init__` 中,`Pipeline` 进行初始化,以下是一个简易的实现:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model
|
||||
|
||||
class NewDiffSynthPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.text_encoder: XXX_Model = None
|
||||
self.dit: YYY_Model = None
|
||||
self.vae: ZZZ_Model = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
NewDiffSynthPipelineUnit_xxx(),
|
||||
...
|
||||
]
|
||||
self.model_fn = model_fn_new
|
||||
```
|
||||
|
||||
其中包括以下几部分
|
||||
|
||||
* `scheduler`: 调度器,用于控制推理的迭代公式中的系数,控制每一步的噪声含量。
|
||||
* `text_encoder`、`dit`、`vae`: 模型,自 [Latent Diffusion](https://arxiv.org/abs/2112.10752) 被提出以来,这种三段式模型架构已成为主流的 Diffusion 模型架构,但这并不是一成不变的,`Pipeline` 中可添加任意多个模型。
|
||||
* `in_iteration_models`: 迭代中模型,这个元组标注了在迭代中会调用哪些模型。
|
||||
* `units`: 模型迭代的前处理单元,详见[`units`](#units)。
|
||||
* `model_fn`: 迭代中去噪模型的 `forward` 函数,详见[`model_fn`](#model_fn)。
|
||||
|
||||
> Q: 模型加载并不发生在 `__init__`,为什么这里仍要将每个模型初始化为 `None`?
|
||||
>
|
||||
> A: 在这里标注每个模型的类型后,代码编辑器就可以根据每个模型提供代码补全提示,便于后续的开发。
|
||||
|
||||
## `from_pretrained`
|
||||
|
||||
`from_pretrained` 负责加载所需的模型,让 `Pipeline` 变成可调用的状态。以下是一个简易的实现:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("xxx_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("yyy_dit")
|
||||
pipe.vae = model_pool.fetch_model("zzz_vae")
|
||||
# If necessary, load tokenizers here.
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
```
|
||||
|
||||
开发者需要实现其中获取模型的逻辑,对应的模型名称即为[模型接入时填写的模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config) 中的 `"model_name"`。
|
||||
|
||||
部分模型还需要加载 `tokenizer`,可根据需要在 `from_pretrained` 上添加额外的 `tokenizer_config` 参数并在获取模型后实现这部分。
|
||||
|
||||
## `__call__`
|
||||
|
||||
`__call__` 实现了整个 Pipeline 的生成过程,以下是常见的生成过程模板,开发者可根据需要在此基础上修改。
|
||||
|
||||
```python
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
num_inference_steps: int = 30,
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps,
|
||||
denoising_strength=denoising_strength
|
||||
)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image,
|
||||
"denoising_strength": denoising_strength,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"seed": seed,
|
||||
"rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"], device=self.device)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
```
|
||||
|
||||
## `units`
|
||||
|
||||
`units` 包含了所有的前处理过程,例如:宽高检查、提示词编码、初始噪声生成等。在整个模型前处理过程中,数据被抽象为了互斥的三部分,分别存储在对应的字典中:
|
||||
|
||||
* `inputs_shard`: 共享输入,与 [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598)(简称 CFG)无关的参数。
|
||||
* `inputs_posi`: Classifier-Free Guidance 的 Positive 侧输入,包含与正向提示词相关的内容。
|
||||
* `inputs_nega`: Classifier-Free Guidance 的 Negative 侧输入,包含与负向提示词相关的内容。
|
||||
|
||||
Pipeline Unit 的实现包括三种:直接模式、CFG 分离模式、接管模式。
|
||||
|
||||
如果某些计算与 CFG 无关,可采用直接模式,例如 Qwen-Image 的随机噪声初始化:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
```
|
||||
|
||||
如果某些计算与 CFG 有关,需分别处理正向和负向提示词,但两侧的输入参数是相同的,可采用 CFG 分离模式,例如 Qwen-image 的提示词编码:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
# Do something
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
```
|
||||
|
||||
如果某些计算需要全局的信息,则需要接管模式,例如 Qwen-Image 的实体分区控制:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
|
||||
output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
# Do something
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
```
|
||||
|
||||
以下是 Pipeline Unit 所需的参数配置:
|
||||
|
||||
* `seperate_cfg`: 是否启用 CFG 分离模式
|
||||
* `take_over`: 是否启用接管模式
|
||||
* `input_params`: 共享输入参数
|
||||
* `output_params`: 输出参数
|
||||
* `input_params_posi`: Positive 侧输入参数
|
||||
* `input_params_nega`: Negative 侧输入参数
|
||||
* `onload_model_names`: 需调用的模型组件名
|
||||
|
||||
在设计 `unit` 时请尽量按照以下原则进行:
|
||||
|
||||
* 缺省兜底:可选功能的 `unit` 输入参数默认为 `None`,而不是 `False` 或其他数值,请对此默认值进行兜底处理。
|
||||
* 参数触发:部分 Adapter 模型可能是未被加载的,例如 ControlNet,对应的 `unit` 应当以参数输入是否为 `None` 来控制触发,而不是以模型是否被加载来控制触发。例如当用户输入了 `controlnet_image` 但没有加载 ControlNet 模型时,代码应当给出报错,而不是忽略这些输入参数继续执行。
|
||||
* 简洁优先:尽可能使用直接模式,仅当功能无法实现时,使用接管模式。
|
||||
* 显存高效:在 `unit` 中调用模型时,请使用 `pipe.load_models_to_device(self.onload_model_names)` 激活对应的模型,请不要调用 `onload_model_names` 之外的其他模型,`unit` 计算完成后,请不要使用 `pipe.load_models_to_device([])` 手动释放显存。
|
||||
|
||||
> Q: 部分参数并未在推理过程中调用,例如 `output_params`,是否仍有必要配置?
|
||||
>
|
||||
> A: 这些参数不会影响推理过程,但会影响一些实验性功能,因此我们建议将其配置好。例如“拆分训练”,我们可以将训练中的前处理离线完成,但部分需要梯度回传的模型计算无法拆分,这些参数用于构建计算图从而推断哪些计算是可以拆分的。
|
||||
|
||||
## `model_fn`
|
||||
|
||||
`model_fn` 是迭代中的统一 `forward` 接口,对于开源模型生态尚未形成的模型,直接沿用去噪模型的 `forward` 即可,例如:
|
||||
|
||||
```python
|
||||
def model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):
|
||||
return dit(latents, prompt_emb, timestep)
|
||||
```
|
||||
|
||||
对于开源生态丰富的模型,`model_fn` 通常包含复杂且混乱的跨模型推理,以 `diffsynth/pipelines/qwen_image.py` 为例,这个函数中实现的额外计算包括:实体分区控制、三种 ControlNet、Gradient Checkpointing 等,开发者在实现这一部分时要格外小心,避免模块功能之间的冲突。
|
||||
228
docs/zh/Developer_Guide/Enabling_VRAM_management.md
Normal file
228
docs/zh/Developer_Guide/Enabling_VRAM_management.md
Normal file
@@ -0,0 +1,228 @@
|
||||
# 细粒度显存管理方案
|
||||
|
||||
本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。
|
||||
|
||||
## 20B 模型需要多少显存?
|
||||
|
||||
以 Qwen-Image 的 DiT 模型为例,这一模型的参数量达到了 20B,以下代码会加载这一模型并进行推理,需要约 40G 显存,这个模型在显存较小的消费级 GPU 上显然是无法运行的。
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
snapshot_download(
|
||||
model_id="Qwen/Qwen-Image",
|
||||
local_dir="models/Qwen/Qwen-Image",
|
||||
allow_file_pattern="transformer/*"
|
||||
)
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
## 编写细粒度显存管理方案
|
||||
|
||||
为了编写细粒度的显存管理方案,我们需用 `print(model)` 观察和分析模型结构:
|
||||
|
||||
```
|
||||
QwenImageDiT(
|
||||
(pos_embed): QwenEmbedRope()
|
||||
(time_text_embed): TimestepEmbeddings(
|
||||
(time_proj): TemporalTimesteps()
|
||||
(timestep_embedder): DiffusersCompatibleTimestepProj(
|
||||
(linear_1): Linear(in_features=256, out_features=3072, bias=True)
|
||||
(act): SiLU()
|
||||
(linear_2): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_norm): RMSNorm()
|
||||
(img_in): Linear(in_features=64, out_features=3072, bias=True)
|
||||
(txt_in): Linear(in_features=3584, out_features=3072, bias=True)
|
||||
(transformer_blocks): ModuleList(
|
||||
(0-59): 60 x QwenImageTransformerBlock(
|
||||
(img_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(attn): QwenDoubleStreamAttention(
|
||||
(to_q): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_k): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_v): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_q): RMSNorm()
|
||||
(norm_k): RMSNorm()
|
||||
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_added_q): RMSNorm()
|
||||
(norm_added_k): RMSNorm()
|
||||
(to_out): Sequential(
|
||||
(0): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(img_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(norm_out): AdaLayerNorm(
|
||||
(linear): Linear(in_features=3072, out_features=6144, bias=True)
|
||||
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
)
|
||||
(proj_out): Linear(in_features=3072, out_features=64, bias=True)
|
||||
)
|
||||
```
|
||||
|
||||
在显存管理中,我们只关心包含参数的 Layer。在这个模型结构中,`QwenEmbedRope`、`TemporalTimesteps`、`SiLU` 等 Layer 都是不包含参数的,`LayerNorm` 也因为设置了 `elementwise_affine=False` 不包含参数。包含参数的 Layer 只有 `Linear` 和 `RMSNorm`。
|
||||
|
||||
`diffsynth.core.vram` 中提供了两个用于替换的模块用于显存管理:
|
||||
* `AutoWrappedLinear`: 用于替换 `Linear` 层
|
||||
* `AutoWrappedModule`: 用于替换其他任意层
|
||||
|
||||
编写一个 `module_map`,将模型中的 `Linear` 和 `RMSNorm` 映射到对应的模块上:
|
||||
|
||||
```python
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
}
|
||||
```
|
||||
|
||||
此外,还需要提供 `vram_config` 与 `vram_limit`,这两个参数在[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。
|
||||
|
||||
调用 `enable_vram_management` 即可启用显存管理,注意此时模型加载时的 `device` 为 `cpu`,与 `offload_device` 一致:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
|
||||
enable_vram_management(
|
||||
model,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
以上代码只需要 2G 显存就可以运行 20B 模型的 `forward`。
|
||||
|
||||
## Disk Offload
|
||||
|
||||
[Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案,需在模型加载过程中启用,而非模型加载完毕后。通常,在以上代码能够顺利运行的前提下,Disk Offload 可以直接启用:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。
|
||||
|
||||
如果出现非 Disk Offload 能正常运行但 Disk Offload 不能正常运行的情况,请在 GitHub 上给我们提 issue。
|
||||
|
||||
## 写入默认配置
|
||||
|
||||
为了让用户能够更方便地使用显存管理功能,我们将细粒度显存管理的配置写在 `diffsynth/configs/vram_management_module_maps.py` 中,上述模型的配置信息为:
|
||||
|
||||
```python
|
||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
}
|
||||
```
|
||||
186
docs/zh/Developer_Guide/Integrating_Your_Model.md
Normal file
186
docs/zh/Developer_Guide/Integrating_Your_Model.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# 接入模型结构
|
||||
|
||||
本文档介绍如何将模型接入到 `DiffSynth-Studio` 框架中,供 `Pipeline` 等模块调用。
|
||||
|
||||
## Step 1: 集成模型结构代码
|
||||
|
||||
`DiffSynth-Studio` 中的所有模型结构实现统一在 `diffsynth/models` 中,每个 `.py` 代码文件分别实现一个模型结构,所有模型通过 `diffsynth/models/model_loader.py` 中的 `ModelPool` 来加载。在接入新的模型结构时,请在这个路径下建立新的 `.py` 文件。
|
||||
|
||||
```shell
|
||||
diffsynth/models/
|
||||
├── general_modules.py
|
||||
├── model_loader.py
|
||||
├── qwen_image_controlnet.py
|
||||
├── qwen_image_dit.py
|
||||
├── qwen_image_text_encoder.py
|
||||
├── qwen_image_vae.py
|
||||
└── ...
|
||||
```
|
||||
|
||||
在大多数情况下,我们建议用 `PyTorch` 原生代码的形式集成模型,让模型结构类直接继承 `torch.nn.Module`,例如:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class NewDiffSynthModel(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(dim, dim)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
```
|
||||
|
||||
如果模型结构的实现中包含额外的依赖,我们强烈建议将其删除,否则这会导致沉重的包依赖问题。在我们现有的模型中,Qwen-Image 的 Blockwise ControlNet 是以这种方式集成的,其代码很轻量,请参考 `diffsynth/models/qwen_image_controlnet.py`。
|
||||
|
||||
如果模型已被 Huggingface Library ([`transformers`](https://huggingface.co/docs/transformers/main/index)、[`diffusers`](https://huggingface.co/docs/diffusers/main/index) 等)集成,我们能够以更简单的方式集成模型:
|
||||
|
||||
<details>
|
||||
<summary>集成 Huggingface Library 风格模型结构代码</summary>
|
||||
|
||||
这类模型在 Huggingface Library 中的加载方式为:
|
||||
|
||||
```python
|
||||
from transformers import XXX_Model
|
||||
|
||||
model = XXX_Model.from_pretrained("path_to_your_model")
|
||||
```
|
||||
|
||||
`DiffSynth-Studio` 不支持通过 `from_pretrained` 加载模型,因为这与显存管理等功能是冲突的,请将模型结构改写成以下格式:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class DiffSynth_XXX_Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import XXX_Config, XXX_Model
|
||||
config = XXX_Config(**{
|
||||
"architectures": ["XXX_Model"],
|
||||
"other_configs": "Please copy and paste the other configs here.",
|
||||
})
|
||||
self.model = XXX_Model(config)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self.model(x)
|
||||
return outputs
|
||||
```
|
||||
|
||||
其中 `XXX_Config` 为模型对应的 Config 类,例如 `Qwen2_5_VLModel` 的 Config 类为 `Qwen2_5_VLConfig`,可通过查阅其源代码找到。Config 内部的内容通常可以在模型库中的 `config.json` 中找到,`DiffSynth-Studio` 不会读取 `config.json` 文件,因此需要将其中的内容复制粘贴到代码中。
|
||||
|
||||
在少数情况下,`transformers` 和 `diffusers` 的版本更新会导致部分的模型无法导入,因此如果可能的话,我们仍建议使用 Step 1.1 中的模型集成方式。
|
||||
|
||||
在我们现有的模型中,Qwen-Image 的 Text Encoder 是以这种方式集成的,其代码很轻量,请参考 `diffsynth/models/qwen_image_text_encoder.py`。
|
||||
|
||||
</details>
|
||||
|
||||
## Step 2: 模型文件格式转换
|
||||
|
||||
由于开源社区中开发者提供的模型文件格式多种多样,因此我们有时需对模型文件格式进行转换,从而形成格式正确的 [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html),常见于以下几种情况:
|
||||
|
||||
* 模型文件由不同代码库构建,例如 [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 和 [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)。
|
||||
* 模型在接入中做了修改,例如 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 的 Text Encoder 在 `diffsynth/models/qwen_image_text_encoder.py` 中增加了 `model.` 前缀。
|
||||
* 模型文件包含多个模型,例如 [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) 的 VACE Adapter 和基础 DiT 模型混合存储在同一组模型文件中。
|
||||
|
||||
在我们的开发理念中,我们希望尽可能尊重模型原作者的意愿。如果对模型文件进行重新封装,例如 [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI),虽然我们可以更方便地调用模型,但流量(模型页面浏览量和下载量等)会被引向他处,模型的原作者也会失去删除模型的权力。因此,我们在框架中增加了 `diffsynth/utils/state_dict_converters` 这一模块,用于在模型加载过程中进行文件格式转换。
|
||||
|
||||
这部分逻辑是非常简单的,以 Qwen-Image 的 Text Encoder 为例,只需要 10 行代码即可:
|
||||
|
||||
```python
|
||||
def QwenImageTextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for k in state_dict:
|
||||
v = state_dict[k]
|
||||
if k.startswith("visual."):
|
||||
k = "model." + k
|
||||
elif k.startswith("model."):
|
||||
k = k.replace("model.", "model.language_model.")
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
```
|
||||
|
||||
## Step 3: 编写模型 Config
|
||||
|
||||
模型 Config 位于 `diffsynth/configs/model_configs.py`,用于识别模型类型并进行加载。需填入以下字段:
|
||||
|
||||
* `model_hash`:模型文件哈希值,可通过 `hash_model_file` 函数获取,此哈希值仅与模型文件中 state dict 的 keys 和张量 shape 有关,与文件中的其他信息无关。
|
||||
* `model_name`: 模型名称,用于给 `Pipeline` 识别所需模型。如果不同结构的模型在 `Pipeline` 中发挥的作用相同,则可以使用相同的 `model_name`。在接入新模型时,只需保证 `model_name` 与现有的其他功能模型不同即可。在 `Pipeline` 的 `from_pretrained` 中通过 `model_name` 获取对应的模型。
|
||||
* `model_class`: 模型结构导入路径,指向在 Step 1 中实现的模型结构类,例如 `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`。
|
||||
* `state_dict_converter`: 可选参数,如需进行模型文件格式转换,则需填入模型转换逻辑的导入路径,例如 `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`。
|
||||
* `extra_kwargs`: 可选参数,如果模型初始化时需传入额外参数,则需要填入这些参数,例如模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 与 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) 都采用了 `diffsynth/models/qwen_image_controlnet.py` 中的 `QwenImageBlockWiseControlNet` 结构,但后者还需额外的配置 `additional_in_dim=4`,因此这部分配置信息需填入 `extra_kwargs` 字段。
|
||||
|
||||
我们提供了一份代码,以便快速理解模型是如何通过这些配置信息加载的:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter
|
||||
import torch
|
||||
|
||||
model_hash = "8004730443f55db63092006dd9f7110e"
|
||||
model_name = "qwen_image_text_encoder"
|
||||
model_class = QwenImageTextEncoder
|
||||
state_dict_converter = QwenImageTextEncoderStateDictConverter
|
||||
extra_kwargs = {}
|
||||
|
||||
model_path = [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
]
|
||||
if hash_model_file(model_path) == model_hash:
|
||||
with skip_model_initialization():
|
||||
model = model_class(**extra_kwargs)
|
||||
state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
print("Done!")
|
||||
```
|
||||
|
||||
> Q: 上述代码的逻辑看起来很简单,为什么 `DiffSynth-Studio` 中的这部分代码极为复杂?
|
||||
>
|
||||
> A: 因为我们提供了激进的显存管理功能,与模型加载逻辑耦合,这导致框架结构的复杂性,我们已尽可能简化暴露给开发者的接口。
|
||||
|
||||
`diffsynth/configs/model_configs.py` 中的 `model_hash` 不是唯一存在的,同一模型文件中可能存在多个模型。对于这种情况,请使用多个模型 Config 分别加载每个模型,编写相应的 `state_dict_converter` 分离每个模型所需的参数。
|
||||
|
||||
## Step 4: 检验模型是否能被识别和加载
|
||||
|
||||
模型接入之后,可通过以下代码验证模型是否能够被正确识别和加载,以下代码会试图将模型加载到内存中:
|
||||
|
||||
```python
|
||||
from diffsynth.models.model_loader import ModelPool
|
||||
|
||||
model_pool = ModelPool()
|
||||
model_pool.auto_load_model(
|
||||
[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
如果模型能够被识别和加载,则会看到以下输出内容:
|
||||
|
||||
```
|
||||
Loading models from: [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
]
|
||||
Loaded model: {
|
||||
"model_name": "qwen_image_text_encoder",
|
||||
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
||||
"extra_kwargs": null
|
||||
}
|
||||
```
|
||||
|
||||
## Step 5: 编写模型显存管理方案
|
||||
|
||||
`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)。
|
||||
66
docs/zh/Developer_Guide/Training_Diffusion_Models.md
Normal file
66
docs/zh/Developer_Guide/Training_Diffusion_Models.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# 接入模型训练
|
||||
|
||||
在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)并[实现 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md)后,接下来接入模型训练功能。
|
||||
|
||||
## 训推一致的 Pipeline 改造
|
||||
|
||||
为了保证训练和推理过程严格的一致性,我们会在训练过程中沿用大部分推理代码,但仍需作出少量改造。
|
||||
|
||||
首先,在推理过程中添加额外的逻辑,让图生图/视频生视频逻辑根据 `scheduler` 状态进行切换。以 Qwen-Image 为例:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
```
|
||||
|
||||
然后,在 `model_fn` 中启用 Gradient Checkpointing,这将以计算速度为代价,大幅度减少训练所需的显存。这并不是必需的,但我们强烈建议这么做。
|
||||
|
||||
以 Qwen-Image 为例,修改前:
|
||||
|
||||
```python
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
修改后:
|
||||
|
||||
```python
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
## 编写训练脚本
|
||||
|
||||
`DiffSynth-Studio` 没有对训练框架做严格的封装,而是将脚本内容暴露给开发者,这种方式可以更方便地对训练脚本进行修改,实现额外的功能。开发者可参考现有的训练脚本,例如 `examples/qwen_image/model_training/train.py` 进行修改,从而适配新的模型训练。
|
||||
201
docs/zh/Model_Details/FLUX.md
Normal file
201
docs/zh/Model_Details/FLUX.md
Normal file
@@ -0,0 +1,201 @@
|
||||
# FLUX
|
||||
|
||||

|
||||
|
||||
FLUX 是由 Black Forest Labs 开发并开源的图像生成模型系列。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors", **vram_config),
|
||||
],
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 1,
|
||||
)
|
||||
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
|
||||
image = pipe(prompt=prompt, seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型血缘</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
|
||||
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
|
||||
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
|
||||
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
|
||||
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
|
||||
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
|
||||
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
|
||||
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
|
||||
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `FluxImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`FluxImagePipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于 1 的值时启用 CFG。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理次数,默认值为 30。
|
||||
* `embedded_guidance`: 嵌入式引导参数,默认值为 3.5。
|
||||
* `t5_sequence_length`: T5 文本编码器的序列长度,默认为 512。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。
|
||||
* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
* `controlnet_inputs`: ControlNet 模型的输入,类型为 `ControlNetInput` 列表。
|
||||
* `ipadapter_images`: IP-Adapter 模型的输入图像列表。
|
||||
* `ipadapter_scale`: IP-Adapter 模型的引导强度。
|
||||
* `infinityou_id_image`: InfiniteYou 模型的输入图像。
|
||||
* `infinityou_guidance`: InfiniteYou 模型的引导强度。
|
||||
* `kontext_images`: Kontext 模型的输入图像。
|
||||
* `eligen_entity_prompts`: EliGen 分区控制的提示词列表。
|
||||
* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像列表。
|
||||
* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。
|
||||
* `eligen_enable_inpaint`: 是否启用 EliGen 分区控制的局部重绘功能。
|
||||
* `lora_encoder_inputs`: LoRA 编码器的输入图像列表。
|
||||
* `lora_encoder_scale`: LoRA 编码器的引导强度。
|
||||
* `step1x_reference_image`: Step1X 模型的参考图像。
|
||||
* `flex_inpaint_image`: Flex 模型的待修复图像。
|
||||
* `flex_inpaint_mask`: Flex 模型的修复遮罩。
|
||||
* `flex_control_image`: Flex 模型的控制图像。
|
||||
* `flex_control_strength`: Flex 模型的控制强度。
|
||||
* `flex_control_stop`: Flex 模型的控制停止时间步。
|
||||
* `nexus_gen_reference_image`: Nexus-Gen 模型的参考图像。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
FLUX 系列模型统一通过 [`examples/flux/model_training/train.py`](/examples/flux/model_training/train.py) 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"black-forest-labs/FLUX.1-dev:flux1-dev.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
|
||||
* `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
|
||||
* `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 图像宽高配置(适用于图像生成模型和视频生成模型)
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。
|
||||
* FLUX 专有参数
|
||||
* `--tokenizer_1_path`: CLIP tokenizer 的路径,留空则自动从远程下载。
|
||||
* `--tokenizer_2_path`: T5 tokenizer 的路径,留空则自动从远程下载。
|
||||
* `--align_to_opensource_format`: 是否将 LoRA 格式对齐到开源格式,仅适用于 DiT 的 LoRA。
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
138
docs/zh/Model_Details/FLUX2.md
Normal file
138
docs/zh/Model_Details/FLUX2.md
Normal file
@@ -0,0 +1,138 @@
|
||||
# FLUX.2
|
||||
|
||||
FLUX.2 是由 Black Forest Labs 训练并开源的图像生成模型。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 10G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.flux2_image import Flux2ImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = Flux2ImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="black-forest-labs/FLUX.2-dev", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "High resolution. A dreamy underwater portrait of a serene young woman in a flowing blue dress. Her hair floats softly around her face, strands delicately suspended in the water. Clear, shimmering light filters through, casting gentle highlights, while tiny bubbles rise around her. Her expression is calm, her features finely detailed—creating a tranquil, ethereal scene."
|
||||
image = pipe(prompt, seed=42, rand_device="cuda", num_inference_steps=50)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.2-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.2-dev)|[code](/examples/flux2/model_inference/FLUX.2-dev.py)|[code](/examples/flux2/model_inference_low_vram/FLUX.2-dev.py)|[code](/examples/flux2/model_training/lora/FLUX.2-dev.sh)|[code](/examples/flux2/model_training/validate_lora/FLUX.2-dev.py)|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/flux/model_training/special/differential_training/)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/flux/model_training/special/fp8_training/)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/flux/model_training/special/split_training/)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/flux/model_training/lora/FLUX.1-dev-Distill-LoRA.sh)
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `Flux2ImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`Flux2ImagePipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1,当设置为大于 1 的值时启用 CFG。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理次数,默认值为 30。
|
||||
* `embedded_guidance`: 嵌入式引导参数,默认值为 3.5。
|
||||
* `t5_sequence_length`: T5 文本编码器的序列长度,默认为 512。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。
|
||||
* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
FLUX.2 系列模型统一通过 [`examples/flux2/model_training/train.py`](/examples/flux2/model_training/train.py) 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"black-forest-labs/FLUX.2-dev:text_encoder/*.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练 ControlNet 模型时需要额外参数 `controlnet_inputs`,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
|
||||
* `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
|
||||
* `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 图像宽高配置(适用于图像生成模型和视频生成模型)
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。
|
||||
* FLUX.2 专有参数
|
||||
* `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
288
docs/zh/Model_Details/Overview.md
Normal file
288
docs/zh/Model_Details/Overview.md
Normal file
@@ -0,0 +1,288 @@
|
||||
# 模型目录
|
||||
|
||||
## Qwen-Image
|
||||
|
||||
文档:[./Qwen-Image.md](/docs/zh/Model_Details/Qwen-Image.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>效果一览</summary>
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
from PIL import Image
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(
|
||||
prompt, seed=0, num_inference_steps=40,
|
||||
# edit_image=Image.open("xxx.jpg").resize((1328, 1328)) # For Qwen-Image-Edit
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型血缘</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
|
||||
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
|
||||
Qwen/Qwen-Image-->EliGen-Series;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
|
||||
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
|
||||
Qwen/Qwen-Image-->Distill-Series;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
|
||||
Qwen/Qwen-Image-->ControlNet-Series;
|
||||
ControlNet-Series-->Blockwise-ControlNet-Series;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
|
||||
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
|
||||
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
## FLUX 系列
|
||||
|
||||
文档:[./FLUX.md](/docs/zh/Model_Details/FLUX.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>效果一览</summary>
|
||||
|
||||

|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.pipelines.flux_image import FluxImagePipeline, ModelConfig
|
||||
|
||||
pipe = FluxImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="flux1-dev.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="text_encoder_2/*.safetensors"),
|
||||
ModelConfig(model_id="black-forest-labs/FLUX.1-dev", origin_file_pattern="ae.safetensors"),
|
||||
],
|
||||
)
|
||||
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型血缘</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Krea-dev;
|
||||
FLUX.1-Series-->black-forest-labs/FLUX.1-Kontext-dev;
|
||||
black-forest-labs/FLUX.1-dev-->FLUX.1-dev-ControlNet-Series;
|
||||
FLUX.1-dev-ControlNet-Series-->alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta;
|
||||
FLUX.1-dev-ControlNet-Series-->InstantX/FLUX.1-dev-Controlnet-Union-alpha;
|
||||
FLUX.1-dev-ControlNet-Series-->jasperai/Flux.1-dev-Controlnet-Upscaler;
|
||||
black-forest-labs/FLUX.1-dev-->InstantX/FLUX.1-dev-IP-Adapter;
|
||||
black-forest-labs/FLUX.1-dev-->ByteDance/InfiniteYou;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Eligen;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev;
|
||||
black-forest-labs/FLUX.1-dev-->ostris/Flex.2-preview;
|
||||
black-forest-labs/FLUX.1-dev-->stepfun-ai/Step1X-Edit;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->stepfun-ai/Step1X-Edit;
|
||||
black-forest-labs/FLUX.1-dev-->DiffSynth-Studio/Nexus-GenV2;
|
||||
Qwen/Qwen2.5-VL-7B-Instruct-->DiffSynth-Studio/Nexus-GenV2;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|模型 ID|额外参数|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|-|
|
||||
|[black-forest-labs/FLUX.1-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev)||[code](/examples/flux/model_inference/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Krea-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Krea-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Krea-dev.py)|
|
||||
|[black-forest-labs/FLUX.1-Kontext-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Kontext-dev)|`kontext_images`|[code](/examples/flux/model_inference/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/full/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-Kontext-dev.py)|[code](/examples/flux/model_training/lora/FLUX.1-Kontext-dev.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-Kontext-dev.py)|
|
||||
|[alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta](https://www.modelscope.cn/models/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Inpainting-Beta.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Inpainting-Beta.py)|
|
||||
|[InstantX/FLUX.1-dev-Controlnet-Union-alpha](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-Controlnet-Union-alpha)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Union-alpha.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Union-alpha.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Union-alpha.py)|
|
||||
|[jasperai/Flux.1-dev-Controlnet-Upscaler](https://www.modelscope.cn/models/jasperai/Flux.1-dev-Controlnet-Upscaler)|`controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-Controlnet-Upscaler.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-Controlnet-Upscaler.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-Controlnet-Upscaler.py)|
|
||||
|[InstantX/FLUX.1-dev-IP-Adapter](https://www.modelscope.cn/models/InstantX/FLUX.1-dev-IP-Adapter)|`ipadapter_images`, `ipadapter_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-IP-Adapter.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-IP-Adapter.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-IP-Adapter.py)|
|
||||
|[ByteDance/InfiniteYou](https://www.modelscope.cn/models/ByteDance/InfiniteYou)|`infinityou_id_image`, `infinityou_guidance`, `controlnet_inputs`|[code](/examples/flux/model_inference/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-InfiniteYou.py)|[code](/examples/flux/model_training/lora/FLUX.1-dev-InfiniteYou.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-InfiniteYou.py)|
|
||||
|[DiffSynth-Studio/Eligen](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen)|`eligen_entity_prompts`, `eligen_entity_masks`, `eligen_enable_on_negative`, `eligen_enable_inpaint`|[code](/examples/flux/model_inference/FLUX.1-dev-EliGen.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-EliGen.py)|-|-|[code](/examples/flux/model_training/lora/FLUX.1-dev-EliGen.sh)|[code](/examples/flux/model_training/validate_lora/FLUX.1-dev-EliGen.py)|
|
||||
|[DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev](https://www.modelscope.cn/models/DiffSynth-Studio/LoRA-Encoder-FLUX.1-Dev)|`lora_encoder_inputs`, `lora_encoder_scale`|[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_inference_low_vram/FLUX.1-dev-LoRA-Encoder.py)|[code](/examples/flux/model_training/full/FLUX.1-dev-LoRA-Encoder.sh)|[code](/examples/flux/model_training/validate_full/FLUX.1-dev-LoRA-Encoder.py)|-|-|
|
||||
|[DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev](https://modelscope.cn/models/DiffSynth-Studio/LoRAFusion-preview-FLUX.1-dev)||[code](/examples/flux/model_inference/FLUX.1-dev-LoRA-Fusion.py)|-|-|-|-|-|
|
||||
|[stepfun-ai/Step1X-Edit](https://www.modelscope.cn/models/stepfun-ai/Step1X-Edit)|`step1x_reference_image`|[code](/examples/flux/model_inference/Step1X-Edit.py)|[code](/examples/flux/model_inference_low_vram/Step1X-Edit.py)|[code](/examples/flux/model_training/full/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_full/Step1X-Edit.py)|[code](/examples/flux/model_training/lora/Step1X-Edit.sh)|[code](/examples/flux/model_training/validate_lora/Step1X-Edit.py)|
|
||||
|[ostris/Flex.2-preview](https://www.modelscope.cn/models/ostris/Flex.2-preview)|`flex_inpaint_image`, `flex_inpaint_mask`, `flex_control_image`, `flex_control_strength`, `flex_control_stop`|[code](/examples/flux/model_inference/FLEX.2-preview.py)|[code](/examples/flux/model_inference_low_vram/FLEX.2-preview.py)|[code](/examples/flux/model_training/full/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_full/FLEX.2-preview.py)|[code](/examples/flux/model_training/lora/FLEX.2-preview.sh)|[code](/examples/flux/model_training/validate_lora/FLEX.2-preview.py)|
|
||||
|[DiffSynth-Studio/Nexus-GenV2](https://www.modelscope.cn/models/DiffSynth-Studio/Nexus-GenV2)|`nexus_gen_reference_image`|[code](/examples/flux/model_inference/Nexus-Gen-Editing.py)|[code](/examples/flux/model_inference_low_vram/Nexus-Gen-Editing.py)|[code](/examples/flux/model_training/full/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_full/Nexus-Gen.py)|[code](/examples/flux/model_training/lora/Nexus-Gen.sh)|[code](/examples/flux/model_training/validate_lora/Nexus-Gen.py)|
|
||||
|
||||
## Wan 系列
|
||||
|
||||
文档:[./Wan.md](/docs/zh/Model_Details/Wan.md)
|
||||
|
||||
<details>
|
||||
|
||||
<summary>效果一览</summary>
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>快速开始</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth"),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth"),
|
||||
],
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video1.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型血缘</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Wan-Series-->Wan2.1-Series;
|
||||
Wan-Series-->Wan2.2-Series;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
|
||||
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
|
||||
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
|
||||
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
|
||||
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
|
||||
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
|
||||
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/examples/wanmodel_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
||||
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
||||
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||
191
docs/zh/Model_Details/Qwen-Image.md
Normal file
191
docs/zh/Model_Details/Qwen-Image.md
Normal file
@@ -0,0 +1,191 @@
|
||||
# Qwen-Image
|
||||
|
||||

|
||||
|
||||
Qwen-Image 是由阿里巴巴通义实验室通义千问团队训练并开源的图像生成模型。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型血缘</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Qwen/Qwen-Image-->Qwen/Qwen-Image-Edit;
|
||||
Qwen/Qwen-Image-Edit-->Qwen/Qwen-Image-Edit-2509;
|
||||
Qwen/Qwen-Image-->EliGen-Series;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen;
|
||||
DiffSynth-Studio/Qwen-Image-EliGen-->DiffSynth-Studio/Qwen-Image-EliGen-V2;
|
||||
EliGen-Series-->DiffSynth-Studio/Qwen-Image-EliGen-Poster;
|
||||
Qwen/Qwen-Image-->Distill-Series;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-Full;
|
||||
Distill-Series-->DiffSynth-Studio/Qwen-Image-Distill-LoRA;
|
||||
Qwen/Qwen-Image-->ControlNet-Series;
|
||||
ControlNet-Series-->Blockwise-ControlNet-Series;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth;
|
||||
Blockwise-ControlNet-Series-->DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint;
|
||||
ControlNet-Series-->DiffSynth-Studio/Qwen-Image-In-Context-Control-Union;
|
||||
Qwen/Qwen-Image-->DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](/examples/qwen_image/model_inference/Qwen-Image.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|
||||
|[Qwen/Qwen-Image-Edit](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit.py)|
|
||||
|[Qwen/Qwen-Image-Edit-2509](https://www.modelscope.cn/models/Qwen/Qwen-Image-Edit-2509)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Edit-2509.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Edit-2509.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Edit-2509.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-V2](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-V2)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-V2.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-V2.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-EliGen-Poster](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-EliGen-Poster)|[code](/examples/qwen_image/model_inference/Qwen-Image-EliGen-Poster.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-EliGen-Poster.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-EliGen-Poster.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-EliGen-Poster.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA)|[code](/examples/qwen_image/model_inference/Qwen-Image-Distill-LoRA.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-LoRA.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-LoRA.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Canny.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Canny.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Canny.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Depth)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Depth.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Depth.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Depth.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint)|[code](/examples/qwen_image/model_inference/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/full/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_full/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|[code](/examples/qwen_image/model_training/lora/Qwen-Image-Blockwise-ControlNet-Inpaint.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-Blockwise-ControlNet-Inpaint.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-In-Context-Control-Union](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-In-Context-Control-Union)|[code](/examples/qwen_image/model_inference/Qwen-Image-In-Context-Control-Union.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-In-Context-Control-Union.py)|-|-|[code](/examples/qwen_image/model_training/lora/Qwen-Image-In-Context-Control-Union.sh)|[code](/examples/qwen_image/model_training/validate_lora/Qwen-Image-In-Context-Control-Union.py)|
|
||||
|[DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Edit-Lowres-Fix)|[code](/examples/qwen_image/model_inference/Qwen-Image-Edit-Lowres-Fix.py)|[code](/examples/qwen_image/model_inference_low_vram/Qwen-Image-Edit-Lowres-Fix.py)|-|-|-|-|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/qwen_image/model_training/special/differential_training/)
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/qwen_image/model_training/special/fp8_training/)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/qwen_image/model_training/special/split_training/)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh)
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `QwenImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`QwenImagePipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 4,当设置为 1 时不再生效。
|
||||
* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。
|
||||
* `inpaint_mask`: 图像局部重绘的遮罩图像。
|
||||
* `inpaint_blur_size`: 图像局部重绘的边缘柔化宽度。
|
||||
* `inpaint_blur_sigma`: 图像局部重绘的边缘柔化强度。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理次数,默认值为 30。
|
||||
* `exponential_shift_mu`: 在采样时间步时采用的固定参数,留空则根据图像宽高进行采样。
|
||||
* `blockwise_controlnet_inputs`: Blockwise ControlNet 模型的输入。
|
||||
* `eligen_entity_prompts`: EliGen 分区控制的提示词。
|
||||
* `eligen_entity_masks`: EliGen 分区控制的区域遮罩图像。
|
||||
* `eligen_enable_on_negative`: 是否在 CFG 的负向一侧启用 EliGen 分区控制。
|
||||
* `edit_image`: 编辑模型的待编辑图像,支持多张图像。
|
||||
* `edit_image_auto_resize`: 是否自动缩放待编辑图像。
|
||||
* `edit_rope_interpolation`: 是否在低分辨率编辑图像上启用 ROPE 插值。
|
||||
* `context_image`: In-Context Control 的输入图像。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `False`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size`: VAE 编解码阶段的分块大小,默认为 128,仅在 `tiled=True` 时生效。
|
||||
* `tile_stride`: VAE 编解码阶段的分块步长,默认为 64,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文“模型总览”中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
Qwen-Image 系列模型统一通过 [`examples/qwen_image/model_training/train.py`](/examples/qwen_image/model_training/train.py) 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型 Qwen-Image-Edit 时需要额外参数 `edit_image`,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
|
||||
* `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
|
||||
* `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 图像宽高配置(适用于图像生成模型和视频生成模型)
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。
|
||||
* Qwen-Image 专有参数
|
||||
* `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。
|
||||
* `--processor_path`: processor 的路径,适用于图像编辑模型,留空则自动从远程下载。
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文“模型总览”中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
253
docs/zh/Model_Details/Wan.md
Normal file
253
docs/zh/Model_Details/Wan.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# Wan
|
||||
|
||||
https://github.com/user-attachments/assets/1d66ae74-3b02-40a9-acc3-ea95fc039314
|
||||
|
||||
Wan 是由阿里巴巴通义实验室通义万相团队开发的视频生成模型系列。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 模型并进行推理。显存管理已启动,框架会自动根据剩余显存控制模型参数的加载,最低 8G 显存即可运行。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.utils.data import save_video, VideoData
|
||||
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = WanVideoPipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", **vram_config),
|
||||
ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="Wan2.1_VAE.pth", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 2,
|
||||
)
|
||||
|
||||
video = pipe(
|
||||
prompt="纪实摄影风格画面,一只活泼的小狗在绿茵茵的草地上迅速奔跑。小狗毛色棕黄,两只耳朵立起,神情专注而欢快。阳光洒在它身上,使得毛发看上去格外柔软而闪亮。背景是一片开阔的草地,偶尔点缀着几朵野花,远处隐约可见蓝天和几片白云。透视感鲜明,捕捉小狗奔跑时的动感和四周草地的生机。中景侧面移动视角。",
|
||||
negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
seed=0, tiled=True,
|
||||
)
|
||||
save_video(video, "video.mp4", fps=15, quality=5)
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
<details>
|
||||
|
||||
<summary>模型血缘</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
Wan-Series-->Wan2.1-Series;
|
||||
Wan-Series-->Wan2.2-Series;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-1.3B;
|
||||
Wan2.1-Series-->Wan-AI/Wan2.1-T2V-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-I2V-14B-480P;
|
||||
Wan-AI/Wan2.1-I2V-14B-480P-->Wan-AI/Wan2.1-I2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-FLF2V-14B-720P;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->iic/VACE-Wan2.1-1.3B-Preview;
|
||||
iic/VACE-Wan2.1-1.3B-Preview-->Wan-AI/Wan2.1-VACE-1.3B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.1-VACE-14B;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-1.3B-Series;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-InP;
|
||||
Wan2.1-Fun-1.3B-Series-->PAI/Wan2.1-Fun-1.3B-Control;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-14B-Series;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-InP;
|
||||
Wan2.1-Fun-14B-Series-->PAI/Wan2.1-Fun-14B-Control;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->Wan2.1-Fun-V1.1-1.3B-Series;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-InP;
|
||||
Wan2.1-Fun-V1.1-1.3B-Series-->PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan2.1-Fun-V1.1-14B-Series;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-InP;
|
||||
Wan2.1-Fun-V1.1-14B-Series-->PAI/Wan2.1-Fun-V1.1-14B-Control-Camera;
|
||||
Wan-AI/Wan2.1-T2V-1.3B-->DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1;
|
||||
Wan-AI/Wan2.1-T2V-14B-->krea/krea-realtime-video;
|
||||
Wan-AI/Wan2.1-T2V-14B-->meituan-longcat/LongCat-Video;
|
||||
Wan-AI/Wan2.1-I2V-14B-720P-->ByteDance/Video-As-Prompt-Wan2.1-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-Animate-14B;
|
||||
Wan-AI/Wan2.1-T2V-14B-->Wan-AI/Wan2.2-S2V-14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-T2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-I2V-A14B;
|
||||
Wan2.2-Series-->Wan-AI/Wan2.2-TI2V-5B;
|
||||
Wan-AI/Wan2.2-T2V-A14B-->Wan2.2-Fun-Series;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-VACE-Fun-A14B;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-InP;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control;
|
||||
Wan2.2-Fun-Series-->PAI/Wan2.2-Fun-A14B-Control-Camera;
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|模型 ID|额外参数|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Wan-AI/Wan2.1-T2V-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-T2V-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-T2V-14B)||[code](/examples/wanvideo/model_inference/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-T2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-T2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-T2V-14B.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-480P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-480P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-480P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-480P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-480P.py)|
|
||||
|[Wan-AI/Wan2.1-I2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-I2V-14B-720P)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-I2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-I2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-I2V-14B-720P.py)|
|
||||
|[Wan-AI/Wan2.1-FLF2V-14B-720P](https://modelscope.cn/models/Wan-AI/Wan2.1-FLF2V-14B-720P)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-FLF2V-14B-720P.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-FLF2V-14B-720P.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-FLF2V-14B-720P.py)|
|
||||
|[iic/VACE-Wan2.1-1.3B-Preview](https://modelscope.cn/models/iic/VACE-Wan2.1-1.3B-Preview)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B-Preview.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B-Preview.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B-Preview.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-1.3B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-1.3B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-1.3B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-1.3B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-1.3B.py)|
|
||||
|[Wan-AI/Wan2.1-VACE-14B](https://modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-VACE-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-VACE-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-VACE-14B.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-1.3B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-14B-Control)|`control_video`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-InP](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-InP.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-1.3B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-1.3B-Control-Camera.py)|
|
||||
|[PAI/Wan2.1-Fun-V1.1-14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.1-Fun-V1.1-14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-Fun-V1.1-14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-Fun-V1.1-14B-Control-Camera.py)|
|
||||
|[DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1](https://modelscope.cn/models/DiffSynth-Studio/Wan2.1-1.3b-speedcontrol-v1)|`motion_bucket_id`|[code](/examples/wanvideo/model_inference/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/full/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.1-1.3b-speedcontrol-v1.py)|[code](/examples/wanvideo/model_training/lora/Wan2.1-1.3b-speedcontrol-v1.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.1-1.3b-speedcontrol-v1.py)|
|
||||
|[krea/krea-realtime-video](https://www.modelscope.cn/models/krea/krea-realtime-video)||[code](/examples/wanvideo/model_inference/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/full/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_full/krea-realtime-video.py)|[code](/examples/wanvideo/model_training/lora/krea-realtime-video.sh)|[code](/examples/wanvideo/model_training/validate_lora/krea-realtime-video.py)|
|
||||
|[meituan-longcat/LongCat-Video](https://www.modelscope.cn/models/meituan-longcat/LongCat-Video)|`longcat_video`|[code](/examples/wanvideo/model_inference/LongCat-Video.py)|[code](/examples/wanvideo/model_training/full/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_full/LongCat-Video.py)|[code](/examples/wanvideo/model_training/lora/LongCat-Video.sh)|[code](/examples/wanvideo/model_training/validate_lora/LongCat-Video.py)|
|
||||
|[ByteDance/Video-As-Prompt-Wan2.1-14B](https://modelscope.cn/models/ByteDance/Video-As-Prompt-Wan2.1-14B)|`vap_video`, `vap_prompt`|[code](/examples/wanvideo/model_inference/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/full/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Video-As-Prompt-Wan2.1-14B.py)|[code](/examples/wanvideo/model_training/lora/Video-As-Prompt-Wan2.1-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Video-As-Prompt-Wan2.1-14B.py)|
|
||||
|[Wan-AI/Wan2.2-T2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-T2V-A14B)||[code](/examples/wanvideo/model_inference/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-T2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-T2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-T2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-I2V-A14B](https://modelscope.cn/models/Wan-AI/Wan2.2-I2V-A14B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-I2V-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-I2V-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-I2V-A14B.py)|
|
||||
|[Wan-AI/Wan2.2-TI2V-5B](https://modelscope.cn/models/Wan-AI/Wan2.2-TI2V-5B)|`input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-TI2V-5B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-TI2V-5B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-TI2V-5B.py)|
|
||||
|[Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B)|`input_image`, `animate_pose_video`, `animate_face_video`, `animate_inpaint_video`, `animate_mask_video`|[code](/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Animate-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py)|
|
||||
|[Wan-AI/Wan2.2-S2V-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-S2V-14B)|`input_image`, `input_audio`, `audio_sample_rate`, `s2v_pose_video`|[code](/examples/wanvideo/model_inference/Wan2.2-S2V-14B_multi_clips.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-S2V-14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-S2V-14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-S2V-14B.py)|
|
||||
|[PAI/Wan2.2-VACE-Fun-A14B](https://www.modelscope.cn/models/PAI/Wan2.2-VACE-Fun-A14B)|`vace_control_video`, `vace_reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-VACE-Fun-A14B.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-VACE-Fun-A14B.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-VACE-Fun-A14B.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-InP](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-InP)|`input_image`, `end_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-InP.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-InP.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-InP.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control)|`control_video`, `reference_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control.py)|
|
||||
|[PAI/Wan2.2-Fun-A14B-Control-Camera](https://modelscope.cn/models/PAI/Wan2.2-Fun-A14B-Control-Camera)|`control_camera_video`, `input_image`|[code](/examples/wanvideo/model_inference/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/full/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_full/Wan2.2-Fun-A14B-Control-Camera.py)|[code](/examples/wanvideo/model_training/lora/Wan2.2-Fun-A14B-Control-Camera.sh)|[code](/examples/wanvideo/model_training/validate_lora/Wan2.2-Fun-A14B-Control-Camera.py)|
|
||||
|
||||
* FP8 精度训练:[doc](/docs/zh/Training/FP8_Precision.md)、[code](/examples/wanvideo/model_training/special/fp8_training/)
|
||||
* 两阶段拆分训练:[doc](/docs/zh/Training/Split_Training.md)、[code](/examples/wanvideo/model_training/special/split_training/)
|
||||
* 端到端直接蒸馏:[doc](/docs/zh/Training/Direct_Distill.md)、[code](/examples/wanvideo/model_training/special/direct_distill/)
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `WanVideoPipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`WanVideoPipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述视频中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述视频中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 5,当设置为 1 时不再生效。
|
||||
* `input_image`: 输入图像,用于图生视频,该参数与 `denoising_strength` 配合使用。
|
||||
* `end_image`: 结束图像,用于首尾帧生成视频。
|
||||
* `input_video`: 输入视频,用于视频到视频生成,该参数与 `denoising_strength` 配合使用。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成视频与输入视频相似;当数值接近 1 时,生成视频与输入视频相差更大。
|
||||
* `control_video`: 控制视频,用于控制视频生成过程。
|
||||
* `reference_image`: 参考图像,用于保持生成视频中某些特征的一致性。
|
||||
* `camera_control_direction`: 相机控制方向,可选值为 `"Left"`, `"Right"`, `"Up"`, `"Down"`, `"LeftUp"`, `"LeftDown"`, `"RightUp"`, `"RightDown"`。
|
||||
* `camera_control_speed`: 相机控制速度,默认值为 1/54。
|
||||
* `vace_video`: VACE 控制视频。
|
||||
* `vace_video_mask`: VACE 控制视频遮罩。
|
||||
* `vace_reference_image`: VACE 参考图像。
|
||||
* `vace_scale`: VACE 控制强度,默认值为 1.0。
|
||||
* `animate_pose_video`: `animate` 模型姿态视频。
|
||||
* `animate_face_video`: `animate` 模型面部视频。
|
||||
* `animate_inpaint_video`: `animate` 模型局部编辑视频。
|
||||
* `animate_mask_video`: `animate` 模型遮罩视频。
|
||||
* `vap_video`: `video-as-prompt` 的输入视频。
|
||||
* `vap_prompt`: `video-as-prompt` 的文本描述。
|
||||
* `negative_vap_prompt`: `video-as-prompt` 的负向文本描述。
|
||||
* `input_audio`: 输入音频,用于语音到视频生成。
|
||||
* `audio_embeds`: 音频嵌入向量。
|
||||
* `audio_sample_rate`: 音频采样率,默认值为 16000。
|
||||
* `s2v_pose_video`: S2V 模型的姿态视频。
|
||||
* `motion_video`: S2V 模型的运动视频。
|
||||
* `height`: 视频高度,需保证高度为 16 的倍数。
|
||||
* `width`: 视频宽度,需保证宽度为 16 的倍数。
|
||||
* `num_frames`: 视频帧数,默认值为 81,需保证为 4 的倍数 + 1。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理次数,默认值为 50。
|
||||
* `motion_bucket_id`: 运动控制参数,数值越大,运动幅度越大。
|
||||
* `longcat_video`: LongCat 输入视频。
|
||||
* `tiled`: 是否启用 VAE 分块推理,默认为 `True`。设置为 `True` 时可显著减少 VAE 编解码阶段的显存占用,会产生少许误差,以及少量推理时间延长。
|
||||
* `tile_size`: VAE 编解码阶段的分块大小,默认为 `(30, 52)`,仅在 `tiled=True` 时生效。
|
||||
* `tile_stride`: VAE 编解码阶段的分块步长,默认为 `(15, 26)`,仅在 `tiled=True` 时生效,需保证其数值小于或等于 `tile_size`。
|
||||
* `switch_DiT_boundary`: 切换DiT模型的时间边界,默认值为 0.875。
|
||||
* `sigma_shift`: 时间步偏移参数,默认值为 5.0。
|
||||
* `sliding_window_size`: 滑动窗口大小。
|
||||
* `sliding_window_stride`: 滑动窗口步长。
|
||||
* `tea_cache_l1_thresh`: TeaCache 的 L1 阈值。
|
||||
* `tea_cache_model_id`: TeaCache 使用的模型 ID。
|
||||
* `progress_bar_cmd`: 进度条,默认为 `tqdm.tqdm`。可通过设置为 `lambda x:x` 来屏蔽进度条。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
Wan 系列模型统一通过 [`examples/wanvideo/model_training/train.py`](/examples/wanvideo/model_training/train.py) 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型时需要额外参数,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
|
||||
* `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
|
||||
* `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 视频宽高配置
|
||||
* `--height`: 视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--max_pixels`: 视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的视频帧都会被缩小,分辨率小于这个数值的视频帧保持不变。
|
||||
* `--num_frames`: 视频的帧数。
|
||||
* Wan 系列专有参数
|
||||
* `--tokenizer_path`: tokenizer 的路径,适用于文生视频模型,留空则自动从远程下载。
|
||||
* `--audio_processor_path`: 音频处理器的路径,适用于语音到视频模型,留空则自动从远程下载。
|
||||
|
||||
我们构建了一个样例视频数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
141
docs/zh/Model_Details/Z-Image.md
Normal file
141
docs/zh/Model_Details/Z-Image.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# Z-Image
|
||||
|
||||
Z-Image 是由阿里巴巴通义实验室多模态交互团队训练并开源的图像生成模型。
|
||||
|
||||
## 安装
|
||||
|
||||
在使用本项目进行模型推理和训练前,请先安装 DiffSynth-Studio。
|
||||
|
||||
```shell
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
更多关于安装的信息,请参考[安装依赖](/docs/zh/Pipeline_Usage/Setup.md)。
|
||||
|
||||
## 快速开始
|
||||
|
||||
运行以下代码可以快速加载 [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 模型并进行推理。FP8 精度量化会导致明显的图像质量劣化,因此不建议在 Z-Image Turbo 模型上开启任何量化,仅建议开启 CPU Offload,最低 8G 显存即可运行。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.z_image import ZImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = ZImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="transformer/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="text_encoder/*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Tongyi-MAI/Z-Image-Turbo", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "Young Chinese woman in red Hanfu, intricate embroidery. Impeccable makeup, red floral forehead pattern. Elaborate high bun, golden phoenix headdress, red flowers, beads. Holds round folding fan with lady, trees, bird. Neon lightning-bolt lamp (⚡️), bright yellow glow, above extended left palm. Soft-lit outdoor night background, silhouetted tiered pagoda (西安大雁塔), blurred colorful distant lights."
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda")
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 模型总览
|
||||
|
||||
|模型 ID|推理|低显存推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|
||||
|-|-|-|-|-|-|-|
|
||||
|[Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo)|[code](/examples/z_image/model_inference/Z-Image-Turbo.py)|[code](/examples/z_image/model_inference_low_vram/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/full/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_full/Z-Image-Turbo.py)|[code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)|[code](/examples/z_image/model_training/validate_lora/Z-Image-Turbo.py)|
|
||||
|
||||
特殊训练脚本:
|
||||
|
||||
* 差分 LoRA 训练:[doc](/docs/zh/Training/Differential_LoRA.md)、[code](/examples/z_image/model_training/special/differential_training/)
|
||||
* 轨迹模仿蒸馏训练(实验性功能):[code](/examples/z_image/model_training/special/trajectory_imitation/)
|
||||
|
||||
## 模型推理
|
||||
|
||||
模型通过 `ZImagePipeline.from_pretrained` 加载,详见[加载模型](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型)。
|
||||
|
||||
`ZImagePipeline` 推理的输入参数包括:
|
||||
|
||||
* `prompt`: 提示词,描述画面中出现的内容。
|
||||
* `negative_prompt`: 负向提示词,描述画面中不应该出现的内容,默认值为 `""`。
|
||||
* `cfg_scale`: Classifier-free guidance 的参数,默认值为 1。
|
||||
* `input_image`: 输入图像,用于图生图,该参数与 `denoising_strength` 配合使用。
|
||||
* `denoising_strength`: 去噪强度,范围是 0~1,默认值为 1,当数值接近 0 时,生成图像与输入图像相似;当数值接近 1 时,生成图像与输入图像相差更大。在不输入 `input_image` 参数时,请不要将其设置为非 1 的数值。
|
||||
* `height`: 图像高度,需保证高度为 16 的倍数。
|
||||
* `width`: 图像宽度,需保证宽度为 16 的倍数。
|
||||
* `seed`: 随机种子。默认为 `None`,即完全随机。
|
||||
* `rand_device`: 生成随机高斯噪声矩阵的计算设备,默认为 `"cpu"`。当设置为 `cuda` 时,在不同 GPU 上会导致不同的生成结果。
|
||||
* `num_inference_steps`: 推理次数,默认值为 8。
|
||||
|
||||
如果显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),我们在示例代码中提供了每个模型推荐的低显存配置,详见前文"模型总览"中的表格。
|
||||
|
||||
## 模型训练
|
||||
|
||||
Z-Image 系列模型统一通过 [`examples/z_image/model_training/train.py`](/examples/z_image/model_training/train.py) 进行训练,脚本的参数包括:
|
||||
|
||||
* 通用训练参数
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Tongyi-MAI/Z-Image-Turbo:transformer/*.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型时需要额外参数,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
|
||||
* `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
|
||||
* `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 图像宽高配置(适用于图像生成模型和视频生成模型)
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。
|
||||
* Z-Image 专有参数
|
||||
* `--tokenizer_path`: tokenizer 的路径,适用于文生图模型,留空则自动从远程下载。
|
||||
|
||||
我们构建了一个样例图像数据集,以方便您进行测试,通过以下命令可以下载这个数据集:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
我们为每个模型编写了推荐的训练脚本,请参考前文"模型总览"中的表格。关于如何编写模型训练脚本,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md);更多高阶训练算法,请参考[训练框架详解](/docs/Training/)。
|
||||
|
||||
训练技巧:
|
||||
|
||||
* [Tongyi-MAI/Z-Image-Turbo](https://www.modelscope.cn/models/Tongyi-MAI/Z-Image-Turbo) 是一个蒸馏加速的模型,因此直接训练将会迅速让模型失去加速能力,以“加速配置”(`num_inference_steps=8`,`cfg_scale=1`)推理的效果变差,以“无加速配置”(`num_inference_steps=30`,`cfg_scale=2`)推理的效果变好。可采用以下方案训练和推理:
|
||||
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh)) + 无加速配置推理
|
||||
* 差分 LoRA 训练([code](/examples/z_image/model_training/special/differential_training/)) + 加速配置推理
|
||||
* 差分 LoRA 训练中需加载一个额外的 LoRA,例如 [ostris/zimage_turbo_training_adapter](https://www.modelscope.cn/models/ostris/zimage_turbo_training_adapter)
|
||||
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 轨迹模仿蒸馏训练([code](/examples/z_image/model_training/special/trajectory_imitation/))+ 加速配置推理
|
||||
* 标准 SFT 训练([code](/examples/z_image/model_training/lora/Z-Image-Turbo.sh))+ 推理时加载蒸馏加速 LoRA([model](https://www.modelscope.cn/models/DiffSynth-Studio/Z-Image-Turbo-DistillFix)) + 加速配置推理
|
||||
39
docs/zh/Pipeline_Usage/Environment_Variables.md
Normal file
39
docs/zh/Pipeline_Usage/Environment_Variables.md
Normal file
@@ -0,0 +1,39 @@
|
||||
# 环境变量
|
||||
|
||||
`DiffSynth-Studio` 可通过环境变量控制一些设置。
|
||||
|
||||
在 `Python` 代码中,可以使用 `os.environ` 设置环境变量。请注意,环境变量需在 `import diffsynth` 前设置。
|
||||
|
||||
```python
|
||||
import os
|
||||
os.environ["DIFFSYNTH_MODEL_BASE_PATH"] = "./path_to_my_models"
|
||||
import diffsynth
|
||||
```
|
||||
|
||||
在 Linux 操作系统上,也可在命令行临时设置环境变量:
|
||||
|
||||
```shell
|
||||
DIFFSYNTH_MODEL_BASE_PATH="./path_to_my_models" python xxx.py
|
||||
```
|
||||
|
||||
以下是 `DiffSynth-Studio` 所支持的环境变量。
|
||||
|
||||
## `DIFFSYNTH_SKIP_DOWNLOAD`
|
||||
|
||||
是否跳过模型下载。可设置为 `True`、`true`、`False`、`false`,若 `ModelConfig` 中没有设置 `skip_download`,则会根据这一环境变量决定是否跳过模型下载。
|
||||
|
||||
## `DIFFSYNTH_MODEL_BASE_PATH`
|
||||
|
||||
模型下载根目录。可设置为任意本地路径,若 `ModelConfig` 中没有设置 `local_model_path`,则会将模型文件下载到这一环境变量指向的路径。若两者都未设置,则会将模型文件下载到 `./models`。
|
||||
|
||||
## `DIFFSYNTH_ATTENTION_IMPLEMENTATION`
|
||||
|
||||
注意力机制实现的方式,可以设置为 `flash_attention_3`、`flash_attention_2`、`sage_attention`、`xformers`、`torch`。详见 [`./core/attention.md`](/docs/zh/API_Reference/core/attention.md).
|
||||
|
||||
## `DIFFSYNTH_DISK_MAP_BUFFER_SIZE`
|
||||
|
||||
硬盘直连中的 Buffer 大小,默认是 1B(1000000000),数值越大,占用内存越大,速度越快。
|
||||
|
||||
## `DIFFSYNTH_DOWNLOAD_SOURCE`
|
||||
|
||||
远程模型下载源,可设置为 `modelscope` 或 `huggingface`,控制模型下载的来源,默认值为 `modelscope`。
|
||||
105
docs/zh/Pipeline_Usage/Model_Inference.md
Normal file
105
docs/zh/Pipeline_Usage/Model_Inference.md
Normal file
@@ -0,0 +1,105 @@
|
||||
# 模型推理
|
||||
|
||||
本文档以 Qwen-Image 模型为例,介绍如何使用 `DiffSynth-Studio` 进行模型推理。
|
||||
|
||||
## 加载模型
|
||||
|
||||
模型通过 `from_pretrained` 加载:
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
```
|
||||
|
||||
其中 `torch_dtype` 和 `device` 是计算精度和计算设备(不是模型的精度和设备)。`model_configs` 可通过多种方式配置模型路径,关于本项目内部是如何加载模型的,请参考 [`diffsynth.core.loader`](/docs/zh/API_Reference/core/loader.md)。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>从远程下载模型并加载</summary>
|
||||
|
||||
> `DiffSynth-Studio` 默认从[魔搭社区](https://www.modelscope.cn/)下载并加载模型,需填写 `model_id` 和 `origin_file_pattern`,例如
|
||||
>
|
||||
> ```python
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
> ```
|
||||
>
|
||||
> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>从本地文件路径加载模型</summary>
|
||||
|
||||
> 填写 `path`,例如
|
||||
>
|
||||
> ```python
|
||||
> ModelConfig(path="models/xxx.safetensors")
|
||||
> ```
|
||||
>
|
||||
> 对于从多个文件加载的模型,使用列表即可,例如
|
||||
>
|
||||
> ```python
|
||||
> ModelConfig(path=[
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
> ])
|
||||
> ```
|
||||
|
||||
</details>
|
||||
|
||||
默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。
|
||||
|
||||
```shell
|
||||
import os
|
||||
os.environ["DIFFSYNTH_SKIP_DOWNLOAD"] = "True"
|
||||
import diffsynth
|
||||
```
|
||||
|
||||
如需从 [HuggingFace](https://huggingface.co/) 下载模型,请将[环境变量 DIFFSYNTH_DOWNLOAD_SOURCE](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_download_source) 设置为 `huggingface`。
|
||||
|
||||
```shell
|
||||
import os
|
||||
os.environ["DIFFSYNTH_DOWNLOAD_SOURCE"] = "huggingface"
|
||||
import diffsynth
|
||||
```
|
||||
|
||||
## 启动推理
|
||||
|
||||
输入提示词,即可启动推理过程,生成一张图片。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
每个模型 `Pipeline` 的输入参数不同,请参考各模型的文档。
|
||||
|
||||
如果模型参数量太大,导致显存不足,请开启[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。
|
||||
245
docs/zh/Pipeline_Usage/Model_Training.md
Normal file
245
docs/zh/Pipeline_Usage/Model_Training.md
Normal file
@@ -0,0 +1,245 @@
|
||||
# 模型训练
|
||||
|
||||
本文档介绍如何使用 `DiffSynth-Studio` 进行模型训练。
|
||||
|
||||
## 脚本参数
|
||||
|
||||
训练脚本通常包含以下参数:
|
||||
|
||||
* 数据集基础配置
|
||||
* `--dataset_base_path`: 数据集的根目录。
|
||||
* `--dataset_metadata_path`: 数据集的元数据文件路径。
|
||||
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
|
||||
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
|
||||
* `--data_file_keys`: 元数据中需要加载的字段名称,通常是图像或视频文件的路径,以 `,` 分隔。
|
||||
* 模型加载配置
|
||||
* `--model_paths`: 要加载的模型路径。JSON 格式。
|
||||
* `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 `"Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors"`。用逗号分隔。
|
||||
* `--extra_inputs`: 模型 Pipeline 所需的额外输入参数,例如训练图像编辑模型 Qwen-Image-Edit 时需要额外参数 `edit_image`,以 `,` 分隔。
|
||||
* `--fp8_models`:以 FP8 格式加载的模型,格式与 `--model_paths` 或 `--model_id_with_origin_paths` 一致,目前仅支持参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)。
|
||||
* 训练基础配置
|
||||
* `--learning_rate`: 学习率。
|
||||
* `--num_epochs`: 轮数(Epoch)。
|
||||
* `--trainable_models`: 可训练的模型,例如 `dit`、`vae`、`text_encoder`。
|
||||
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数,少数模型包含不参与梯度计算的冗余参数,需开启这一设置避免在多 GPU 训练中报错。
|
||||
* `--weight_decay`:权重衰减大小,详见 [torch.optim.AdamW](https://docs.pytorch.org/docs/stable/generated/torch.optim.AdamW.html)。
|
||||
* `--task`: 训练任务,默认为 `sft`,部分模型支持更多训练模式,请参考每个特定模型的文档。
|
||||
* 输出配置
|
||||
* `--output_path`: 模型保存路径。
|
||||
* `--remove_prefix_in_ckpt`: 在模型文件的 state dict 中移除前缀。
|
||||
* `--save_steps`: 保存模型的训练步数间隔,若此参数留空,则每个 epoch 保存一次。
|
||||
* LoRA 配置
|
||||
* `--lora_base_model`: LoRA 添加到哪个模型上。
|
||||
* `--lora_target_modules`: LoRA 添加到哪些层上。
|
||||
* `--lora_rank`: LoRA 的秩(Rank)。
|
||||
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径,LoRA 将从此检查点加载。
|
||||
* `--preset_lora_path`: 预置 LoRA 检查点路径,如果提供此路径,这一 LoRA 将会以融入基础模型的形式加载。此参数用于 LoRA 差分训练。
|
||||
* `--preset_lora_model`: 预置 LoRA 融入的模型,例如 `dit`。
|
||||
* 梯度配置
|
||||
* `--use_gradient_checkpointing`: 是否启用 gradient checkpointing。
|
||||
* `--use_gradient_checkpointing_offload`: 是否将 gradient checkpointing 卸载到内存中。
|
||||
* `--gradient_accumulation_steps`: 梯度累积步数。
|
||||
* 图像宽高配置(适用于图像生成模型和视频生成模型)
|
||||
* `--height`: 图像或视频的高度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。
|
||||
* `--max_pixels`: 图像或视频帧的最大像素面积,当启用动态分辨率时,分辨率大于这个数值的图片都会被缩小,分辨率小于这个数值的图片保持不变。
|
||||
|
||||
部分模型的训练脚本还包含额外的参数,详见各模型的文档。
|
||||
|
||||
## 准备数据集
|
||||
|
||||
`DiffSynth-Studio` 采用通用数据集格式,数据集包含一系列数据文件(图像、视频等),以及标注元数据的文件,我们建议您这样组织数据集文件:
|
||||
|
||||
```
|
||||
data/example_image_dataset/
|
||||
├── metadata.csv
|
||||
├── image_1.jpg
|
||||
└── image_2.jpg
|
||||
```
|
||||
|
||||
其中 `image_1.jpg`、`image_2.jpg` 为训练用图像数据,`metadata.csv` 为元数据列表,例如
|
||||
|
||||
```
|
||||
image,prompt
|
||||
image_1.jpg,"a dog"
|
||||
image_2.jpg,"a cat"
|
||||
```
|
||||
|
||||
我们构建了样例数据集,以方便您进行测试。了解通用数据集架构是如何实现的,请参考 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>样例图像数据集</summary>
|
||||
|
||||
> ```shell
|
||||
> modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
> ```
|
||||
>
|
||||
> 适用于 Qwen-Image、FLUX 等图像生成模型的训练。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>样例视频数据集</summary>
|
||||
|
||||
> ```shell
|
||||
> modelscope download --dataset DiffSynth-Studio/example_video_dataset --local_dir ./data/example_video_dataset
|
||||
> ```
|
||||
>
|
||||
> 适用于 Wan 等视频生成模型的训练。
|
||||
|
||||
</details>
|
||||
|
||||
## 加载模型
|
||||
|
||||
类似于[推理时的模型加载](/docs/zh/Pipeline_Usage/Model_Inference.md#加载模型),我们支持多种方式配置模型路径,两种方式是可以混用的。
|
||||
|
||||
<details>
|
||||
|
||||
<summary>从远程下载模型并加载</summary>
|
||||
|
||||
> 如果在推理时我们通过以下设置加载模型
|
||||
>
|
||||
> ```python
|
||||
> model_configs=[
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
> ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
> ]
|
||||
> ```
|
||||
>
|
||||
> 那么在训练时,填入以下参数即可加载对应的模型。
|
||||
>
|
||||
> ```shell
|
||||
> --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors"
|
||||
> ```
|
||||
>
|
||||
> 模型文件默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。
|
||||
>
|
||||
> 默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
<summary>从本地文件路径加载模型</summary>
|
||||
|
||||
> 如果从本地文件加载模型,例如推理时
|
||||
>
|
||||
> ```python
|
||||
> model_configs=[
|
||||
> ModelConfig([
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
||||
> ]),
|
||||
> ModelConfig([
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
> ]),
|
||||
> ModelConfig("models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors")
|
||||
> ]
|
||||
> ```
|
||||
>
|
||||
> 那么训练时需设置为
|
||||
>
|
||||
> ```shell
|
||||
> --model_paths '[
|
||||
> [
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00001-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00002-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00003-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00004-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00005-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00006-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00007-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00008-of-00009.safetensors",
|
||||
> "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model-00009-of-00009.safetensors"
|
||||
> ],
|
||||
> [
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
> "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
> ],
|
||||
> "models/Qwen/Qwen-Image/vae/diffusion_pytorch_model.safetensors"
|
||||
> ]' \
|
||||
> ```
|
||||
>
|
||||
> 请注意,`--model_paths` 是 json 格式,其中不能出现多余的 `,`,否则无法被正常解析。
|
||||
|
||||
</details>
|
||||
|
||||
## 设置可训练模块
|
||||
|
||||
训练框架支持任意模型的训练,以 Qwen-Image 为例,若全量训练其中的 DiT 模型,则需设置为
|
||||
|
||||
```shell
|
||||
--trainable_models "dit"
|
||||
```
|
||||
|
||||
若训练 DiT 模型的 LoRA,则需设置
|
||||
|
||||
```shell
|
||||
--lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32
|
||||
```
|
||||
|
||||
我们希望给技术探索留下足够的发挥空间,因此框架支持同时训练任意多个模块,例如同时训练 text encoder、controlnet,以及 DiT 的 LoRA:
|
||||
|
||||
```shell
|
||||
--trainable_models "text_encoder,controlnet" --lora_base_model dit --lora_target_modules "to_q,to_k,to_v" --lora_rank 32
|
||||
```
|
||||
|
||||
此外,由于训练脚本中加载了多个模块(text encoder、dit、vae 等),保存模型文件时需要移除前缀,例如在全量训练 DiT 部分或者训练 DiT 部分的 LoRA 模型时,请设置 `--remove_prefix_in_ckpt pipe.dit.`。如果多个模块同时训练,则需开发者在训练完成后自行编写代码拆分模型文件中的 state dict。
|
||||
|
||||
## 启动训练程序
|
||||
|
||||
训练框架基于 [`accelerate`](https://huggingface.co/docs/accelerate/index) 构建,训练命令按照如下格式编写:
|
||||
|
||||
```shell
|
||||
accelerate launch xxx/train.py \
|
||||
--xxx yyy \
|
||||
--xxxx yyyy
|
||||
```
|
||||
|
||||
我们为每个模型编写了预置的训练脚本,详见各模型的文档。
|
||||
|
||||
默认情况下,`accelerate` 会按照 `~/.cache/huggingface/accelerate/default_config.yaml` 的配置进行训练,使用 `accelerate config` 可在终端交互式地配置,包括多 GPU 训练、[`DeepSpeed`](https://www.deepspeed.ai/) 等。
|
||||
|
||||
我们为部分模型提供了推荐的 `accelerate` 配置文件,可通过 `--config_file` 设置,例如 Qwen-Image 模型的全量训练:
|
||||
|
||||
```shell
|
||||
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-5 \
|
||||
--num_epochs 2 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_full" \
|
||||
--trainable_models "dit" \
|
||||
--use_gradient_checkpointing \
|
||||
--find_unused_parameters
|
||||
```
|
||||
|
||||
## 训练注意事项
|
||||
|
||||
* 数据集的元数据除 `csv` 格式外,还支持 `json`、`jsonl` 格式,关于如何选择最佳的元数据格式,请参考[](/docs/zh/API_Reference/core/data.md#元数据)
|
||||
* 通常训练效果与训练步数强相关,与 epoch 数量弱相关,因此我们更推荐使用参数 `--save_steps` 按训练步数间隔来保存模型文件。
|
||||
* 当数据量 * `dataset_repeat` 超过 $10^9$ 时,我们观测到数据集的速度明显变慢,这似乎是 `PyTorch` 的 bug,我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。
|
||||
* 学习率 `--learning_rate` 在 LoRA 训练中建议设置为 `1e-4`,在全量训练中建议设置为 `1e-5`。
|
||||
* 训练框架不支持 batch size > 1,原因是复杂的,详见 [Q&A: 为什么训练框架不支持 batch size > 1?](/docs/zh/QA.md#为什么训练框架不支持-batch-size--1)
|
||||
* 少数模型包含冗余参数,例如 Qwen-Image 的 DiT 部分最后一层的文本编码部分,在训练这些模型时,需设置 `--find_unused_parameters` 避免在多 GPU 训练中报错。出于对开源社区模型兼容性的考虑,我们不打算删除这些冗余参数。
|
||||
* Diffusion 模型的损失函数值与实际效果的关系不大,因此我们在训练过程中不会记录损失函数值。我们建议把 `--num_epochs` 设置为足够大的数值,边训边测,直至效果收敛后手动关闭训练程序。
|
||||
* `--use_gradient_checkpointing` 通常是开启的,除非 GPU 显存足够;`--use_gradient_checkpointing_offload` 则按需开启,详见 [`diffsynth.core.gradient`](/docs/zh/API_Reference/core/gradient.md)。
|
||||
21
docs/zh/Pipeline_Usage/Setup.md
Normal file
21
docs/zh/Pipeline_Usage/Setup.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# 安装依赖
|
||||
|
||||
从源码安装(推荐):
|
||||
|
||||
```
|
||||
git clone https://github.com/modelscope/DiffSynth-Studio.git
|
||||
cd DiffSynth-Studio
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
从 pypi 安装(存在版本更新延迟,如需使用最新功能,请从源码安装)
|
||||
|
||||
```
|
||||
pip install diffsynth
|
||||
```
|
||||
|
||||
如果在安装过程中遇到问题,可能是由上游依赖包导致的,请参考这些包的文档:
|
||||
|
||||
* [torch](https://pytorch.org/get-started/locally/)
|
||||
* [sentencepiece](https://github.com/google/sentencepiece)
|
||||
* [cmake](https://cmake.org)
|
||||
206
docs/zh/Pipeline_Usage/VRAM_management.md
Normal file
206
docs/zh/Pipeline_Usage/VRAM_management.md
Normal file
@@ -0,0 +1,206 @@
|
||||
# 显存管理
|
||||
|
||||
显存管理是 `DiffSynth-Studio` 的特色功能,能够让低显存的 GPU 能够运行参数量巨大的模型推理。本文档以 Qwen-Image 为例,介绍显存管理方案的使用。
|
||||
|
||||
## 基础推理
|
||||
|
||||
以下代码中没有启用任何显存管理,显存占用 56G,作为参考。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## CPU Offload
|
||||
|
||||
由于模型 `Pipeline` 包括多个组件,这些组件并非同时调用的,因此我们可以在某些组件不需要参与计算时将其移至内存,减少显存占用,以下代码可以实现这一逻辑,显存占用 40G。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## FP8 量化
|
||||
|
||||
在 CPU Offload 的基础上,我们进一步启用 FP8 量化来减少显存需求,以下代码可以令模型参数以 FP8 精度存储在显存中,并在推理时临时转为 BF16 精度计算,显存占用 21G。但这种量化方案有微小的图像质量下降问题。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cuda",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
> Q: 为什么要在推理时临时转为 BF16 精度,而不是以 FP8 精度计算?
|
||||
>
|
||||
> A: FP8 的原生计算仅在 Hopper 架构的 GPU(例如 H20)支持,且计算误差很大,我们目前暂不开放 FP8 精度计算。目前的 FP8 量化仅能减少显存占用,不会提高计算速度。
|
||||
|
||||
## 动态显存管理
|
||||
|
||||
在 CPU Offload 中,我们对模型组件进行控制,事实上,我们支持做到 Layer 级别的 Offload,将一个模型拆分为多个 Layer,令一部分常驻显存,令一部分存储在内存中按需移至显存计算。这一功能需要模型开发者针对每个模型提供详细的显存管理方案,相关配置在 `diffsynth/configs/vram_management_module_maps.py` 中。
|
||||
|
||||
通过在 `Pipeline` 中增加 `vram_limit` 参数,框架可以自动感知设备的剩余显存并决定如何拆分模型到显存和内存中。`vram_limit` 越小,占用显存越少,速度越慢。
|
||||
* `vram_limit=None` 时,即默认状态,框架认为显存无限,动态显存管理是不启用的
|
||||
* `vram_limit=10` 时,框架会在显存占用超过 10G 之后限制模型,将超出的部分移至内存中存储。
|
||||
* `vram_limit=0` 时,框架会尽全力减少显存占用,所有模型参数都存储在内存中,仅在必要时移至显存计算
|
||||
|
||||
在显存不足以运行模型推理的情况下,框架会试图超出 `vram_limit` 的限制从而让模型推理运行下去,因此显存管理框架并不能总是保证占用的显存小于 `vram_limit`,我们建议将其设置为略小于实际可用显存的数值,例如 GPU 显存为 16G 时,设置为 `vram_limit=15.5`。`PyTorch` 中可用 `torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)` 获取 GPU 的显存。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": torch.float8_e4m3fn,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.float8_e4m3fn,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5,
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## Disk Offload
|
||||
|
||||
在更为极端的情况下,当内存也不足以存储整个模型时,Disk Offload 功能可以让模型参数惰性加载,即,模型中的每个 Layer 仅在调用 forward 时才会从硬盘中读取相应的参数。启用这一功能时,我们建议使用高速的 SSD 硬盘。
|
||||
|
||||
Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.float8_e4m3fn,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
vram_limit=10,
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
## 更多使用方式
|
||||
|
||||
`vram_config` 中的信息可自行填写,例如不开 FP8 量化的 Disk Offload:
|
||||
|
||||
```python
|
||||
vram_config = {
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
}
|
||||
```
|
||||
|
||||
具体地,显存管理模块会将模型的 Layer 分为以下四种状态:
|
||||
|
||||
* Offload:短期内不调用这个模型,这个状态由 `Pipeline` 控制切换
|
||||
* Onload:接下来随时要调用这个模型,这个状态由 `Pipeline` 控制切换
|
||||
* Preparing:Onload 和 Computation 的中间状态,在显存允许的前提下的暂存状态,这个状态由显存管理机制控制切换,当且仅当【vram_limit 设置为无限制】或【vram_limit 已设置且有空余显存】时会进入这一状态
|
||||
* Computation:模型正在计算过程中,这个状态由显存管理机制控制切换,仅在 `forward` 中临时进入
|
||||
|
||||
如果你是模型开发者,希望自行控制某个模型的显存管理粒度,请参考[../Developer_Guide/Enabling_VRAM_management.md](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)。
|
||||
|
||||
## 最佳实践
|
||||
|
||||
* 显存足够 -> 使用[基础推理](#基础推理)
|
||||
* 显存不足
|
||||
* 内存足够 -> 使用[动态显存管理](#动态显存管理)
|
||||
* 内存不足 -> 使用[Disk Offload](#disk-offload)
|
||||
28
docs/zh/QA.md
Normal file
28
docs/zh/QA.md
Normal file
@@ -0,0 +1,28 @@
|
||||
# 常见问题
|
||||
|
||||
## 为什么训练框架不支持 batch size > 1?
|
||||
|
||||
* **更大的 batch size 已无法实现显著加速**:由于 flash attention 等加速技术已经充分提高了 GPU 的利用率,因此更大的 batch size 只会带来更大的显存占用,无法带来显著加速。在 Stable Diffusion 1.5 这类小模型上的经验已不再适用于最新的大模型。
|
||||
* **更大的 batch size 可以用其他方案实现**:多 GPU 训练和 Gradient Accumulation 都可以在数学意义上等价地实现更大的 batch size。
|
||||
* **更大的 batch size 与框架的通用性设计相悖**:我们希望构建通用的训练框架,大量模型无法适配更大的 batch size,例如不同长度的文本编码、不同分辨率的图像等,都是无法合并为更大的 batch 的。
|
||||
|
||||
## 为什么不删除某些模型中的冗余参数?
|
||||
|
||||
在部分模型中,模型存在冗余参数,例如 Qwen-Image 的 DiT 模型最后一层的文本部分,这部分参数不会参与任何计算,这是模型开发者留下的小 bug。直接将其设置为可训练时还会在多 GPU 训练中出现报错。
|
||||
|
||||
为了与开源社区中其他模型保持兼容性,我们决定保留这些参数。这些冗余参数在多 GPU 训练中可以通过 `--find_unused_parameters` 参数避免报错。
|
||||
|
||||
## 为什么 FP8 量化没有任何加速效果?
|
||||
|
||||
原生 FP8 计算需要依赖 Hopper 架构的 GPU,同时在计算精度上有较大误差,目前仍然是不成熟的技术,因此本项目不支持原生 FP8 计算。
|
||||
|
||||
显存管理中的 FP8 计算是指将模型参数以 FP8 精度存储在内存或显存中,在需要计算时临时转换为其他精度,因此仅能减少显存占用,没有加速效果。
|
||||
|
||||
## 为什么训练框架不支持原生 FP8 精度训练?
|
||||
|
||||
即使硬件条件允许,我们目前也没有任何支持原生 FP8 精度训练的规划。
|
||||
|
||||
* 目前原生 FP8 精度训练的主要挑战是梯度爆炸导致的精度溢出,为了保证训练的稳定性,需针对性地重新设计模型结构,然而目前还没有任何模型开发者愿意这么做。
|
||||
* 此外,使用原生 FP8 精度训练的模型,在推理时若没有 Hopper 架构 GPU,则只能以 BF16 精度进行计算,理论上其生成效果反而不如 FP8。
|
||||
|
||||
因此,原生 FP8 精度训练技术是极不成熟的,我们静观开源社区的技术发展。
|
||||
88
docs/zh/README.md
Normal file
88
docs/zh/README.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# DiffSynth-Studio 文档
|
||||
|
||||
欢迎来到 Diffusion 模型的魔法世界!`DiffSynth-Studio` 是由[魔搭社区](https://www.modelscope.cn/)团队开发和维护的开源 Diffusion 模型引擎。我们期望构建一个通用的 Diffusion 模型框架,以框架建设孵化技术创新,凝聚开源社区的力量,探索生成式模型技术的边界!
|
||||
|
||||
<details>
|
||||
|
||||
<summary>文档阅读导引</summary>
|
||||
|
||||
```mermaid
|
||||
graph LR;
|
||||
我想要使用模型进行推理和训练-->sec1[Section 1: 上手使用];
|
||||
我想要使用模型进行推理和训练-->sec2[Section 2: 模型详解];
|
||||
我想要使用模型进行推理和训练-->sec3[Section 3: 训练框架];
|
||||
我想要基于此框架进行二次开发-->sec3[Section 3: 训练框架];
|
||||
我想要基于此框架进行二次开发-->sec4[Section 4: 模型接入];
|
||||
我想要基于此框架进行二次开发-->sec5[Section 5: API 参考];
|
||||
我想要基于本项目探索新的技术-->sec4[Section 4: 模型接入];
|
||||
我想要基于本项目探索新的技术-->sec5[Section 5: API 参考];
|
||||
我想要基于本项目探索新的技术-->sec6[Section 6: 学术导引];
|
||||
我遇到了问题-->sec7[Section 7: 常见问题];
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Section 1: 上手使用
|
||||
|
||||
本节介绍 `DiffSynth-Studio` 的基本使用方式,包括如何启用显存管理从而在极低显存的 GPU 上进行推理,以及如何训练任意基础模型、LoRA、ControlNet 等模型。
|
||||
|
||||
* [安装依赖](/docs/zh/Pipeline_Usage/Setup.md)
|
||||
* [模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md)
|
||||
* [显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)
|
||||
* [模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)
|
||||
* [环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md)
|
||||
|
||||
## Section 2: 模型详解
|
||||
|
||||
本节介绍 `DiffSynth-Studio` 所支持的 Diffusion 模型,部分模型 Pipeline 具备可控生成、并行加速等特色功能。
|
||||
|
||||
* [FLUX.1](/docs/zh/Model_Details/FLUX.md)
|
||||
* [Wan](/docs/zh/Model_Details/Wan.md)
|
||||
* [Qwen-Image](/docs/zh/Model_Details/Qwen-Image.md)
|
||||
* [FLUX.2](/docs/zh/Model_Details/FLUX2.md)
|
||||
* [Z-Image](/docs/zh/Model_Details/Z-Image.md)
|
||||
|
||||
## Section 3: 训练框架
|
||||
|
||||
本节介绍 `DiffSynth-Studio` 中训练框架的设计思路,帮助开发者理解 Diffusion 模型训练算法的原理。
|
||||
|
||||
* [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)
|
||||
* [标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)
|
||||
* [在训练中启用 FP8 精度](/docs/zh/Training/FP8_Precision.md)
|
||||
* [端到端的蒸馏加速训练](/docs/zh/Training/Direct_Distill.md)
|
||||
* [两阶段拆分训练](/docs/zh/Training/Split_Training.md)
|
||||
* [差分 LoRA 训练](/docs/zh/Training/Differential_LoRA.md)
|
||||
|
||||
## Section 4: 模型接入
|
||||
|
||||
本节介绍如何将模型接入 `DiffSynth-Studio` 从而使用框架基础功能,帮助开发者为本项目提供新模型的支持,或进行私有化模型的推理和训练。
|
||||
|
||||
* [接入模型结构](/docs/zh/Developer_Guide/Integrating_Your_Model.md)
|
||||
* [接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md)
|
||||
* [接入细粒度显存管理](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)
|
||||
* [接入模型训练](/docs/zh/Developer_Guide/Training_Diffusion_Models.md)
|
||||
|
||||
## Section 5: API 参考
|
||||
|
||||
本节介绍 `DiffSynth-Studio` 中的独立核心模块 `diffsynth.core`,介绍内部的功能是如何设计和运作的,开发者如有需要,可将其中的功能模块用于其他代码库的开发中。
|
||||
|
||||
* [`diffsynth.core.attention`](/docs/zh/API_Reference/core/attention.md): 注意力机制实现
|
||||
* [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md): 数据处理算子与通用数据集
|
||||
* [`diffsynth.core.gradient`](/docs/zh/API_Reference/core/gradient.md): 梯度检查点
|
||||
* [`diffsynth.core.loader`](/docs/zh/API_Reference/core/loader.md): 模型下载与加载
|
||||
* [`diffsynth.core.vram`](/docs/zh/API_Reference/core/vram.md): 显存管理
|
||||
|
||||
## Section 6: 学术导引
|
||||
|
||||
本节介绍如何利用 `DiffSynth-Studio` 训练新的模型,帮助科研工作者探索新的模型技术。
|
||||
|
||||
* 从零开始训练模型【coming soon】
|
||||
* 推理改进优化技术【coming soon】
|
||||
* 设计可控生成模型【coming soon】
|
||||
* 创建新的训练范式【coming soon】
|
||||
|
||||
## Section 7: 常见问题
|
||||
|
||||
本节总结了开发者常见的问题,如果你在使用和开发中遇到了问题,请参考本节内容,如果仍无法解决,请到 GitHub 上给我们提 issue。
|
||||
|
||||
* [常见问题](/docs/zh/QA.md)
|
||||
38
docs/zh/Training/Differential_LoRA.md
Normal file
38
docs/zh/Training/Differential_LoRA.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# 差分 LoRA 训练
|
||||
|
||||
差分 LoRA 训练是一种特殊的 LoRA 训练方式,旨在让模型学习图像之间的差异。
|
||||
|
||||
## 训练方案
|
||||
|
||||
我们未能找到差分 LoRA 训练最早由谁提出,这一技术已经在开源社区中流传甚久。
|
||||
|
||||
假设我们有两张内容相似的图像:图 1 和图 2。例如两张图中分别有一辆车,但图 1 中画面细节更少,图 2 中画面细节更多。在差分 LoRA 训练中,我们进行两步训练:
|
||||
|
||||
* 以图 1 为训练数据,以[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 1
|
||||
* 以图 2 为训练数据,将 LoRA 1 融入基础模型后,以[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)的方式,训练 LoRA 2
|
||||
|
||||
在第一步训练中,由于训练数据仅有一张图,LoRA 模型很容易过拟合,因此训练完成后,LoRA 1 会让模型毫不犹豫地生成图 1,无论随机种子是什么。在第二步训练中,LoRA 模型再次过拟合,因此训练完成后,在 LoRA 1 和 LoRA 2 的共同作用下,模型会毫不犹豫地生成图 2。简言之:
|
||||
|
||||
* LoRA 1 = 生成图 1
|
||||
* LoRA 1 + LoRA 2 = 生成图 2
|
||||
|
||||
此时丢弃 LoRA 1,只使用 LoRA 2,模型将会理解图 1 和图 2 的差异,使生成的内容倾向于“更不像图1,更像图 2”。
|
||||
|
||||
单一训练数据可以保证模型能够过拟合到训练数据上,但稳定性不足。为了提高稳定性,我们可以用多个图像对(image pairs)进行训练,并将训练出的 LoRA 2 进行平均,得到效果更稳定的 LoRA。
|
||||
|
||||
用这一训练方案,可以训练出一些功能奇特的 LoRA 模型。例如,使用丑陋的和漂亮的图像对,训练提升图像美感的 LoRA;使用细节少的和细节丰富的图像对,训练增加图像细节的 LoRA。
|
||||
|
||||
## 模型效果
|
||||
|
||||
我们用差分 LoRA 训练技术训练了几个美学提升 LoRA,可前往对应的模型页面查看生成效果。
|
||||
|
||||
* [DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-LoRA-ArtAug-v1)
|
||||
* [DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1](https://modelscope.cn/models/DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1)
|
||||
|
||||
## 在训练框架中使用差分 LoRA 训练
|
||||
|
||||
第一步的训练与普通 LoRA 训练没有任何差异,在第二步的训练命令中,通过 `--preset_lora_path` 参数填入第一步的 LoRA 模型文件路径,并将 `--preset_lora_model` 设置为与 `lora_base_model` 相同的参数,即可将 LoRA 1 加载到基础模型中。
|
||||
|
||||
## 框架设计思路
|
||||
|
||||
在训练框架中,`--preset_lora_path` 指向的模型在 `DiffusionTrainingModule` 的 `switch_pipe_to_training_mode` 中完成加载。
|
||||
97
docs/zh/Training/Direct_Distill.md
Normal file
97
docs/zh/Training/Direct_Distill.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# 端到端的蒸馏加速训练
|
||||
|
||||
## 蒸馏加速训练
|
||||
|
||||
Diffusion 模型的推理过程通常需要多步迭代,在提升生成效果的同时也让生成过程变得缓慢。通过蒸馏加速训练,可以减少生成清晰内容所需的步数。蒸馏加速训练技术的本质训练目标是让少量步数的生成效果与大量步数的生成效果对齐。
|
||||
|
||||
蒸馏加速训练的方法是多样的,例如
|
||||
|
||||
* 对抗式训练 ADD(Adversarial Diffusion Distillation)
|
||||
* 论文:https://arxiv.org/abs/2311.17042
|
||||
* 模型:[stabilityai/sdxl-turbo](https://modelscope.cn/models/stabilityai/sdxl-turbo)
|
||||
* 渐进式训练 Hyper-SD
|
||||
* 论文:https://arxiv.org/abs/2404.13686
|
||||
* 模型:[ByteDance/Hyper-SD](https://www.modelscope.cn/models/ByteDance/Hyper-SD)
|
||||
|
||||
## 直接蒸馏
|
||||
|
||||
在训练框架层面,支持这类蒸馏加速训练方案是极其困难的。在训练框架的设计中,我们需要保证训练方案满足以下条件:
|
||||
|
||||
* 通用性:训练方案适用于大多数框架内支持的 Diffusion 模型,而非只能对某个特定模型生效,这是代码框架建设的基本要求。
|
||||
* 稳定性:训练方案需保证训练效果稳定,不需要人工进行精细的参数调整,ADD 中的对抗式训练则无法保证稳定性。
|
||||
* 简洁性:训练方案不会引入额外的复杂模块,根据奥卡姆剃刀([Occam's Razor](https://en.wikipedia.org/wiki/Occam%27s_razor))原理,复杂解决方案可能引入潜在风险,Hyper-SD 中的 Human Feedback Learning 让训练过程变得过于复杂。
|
||||
|
||||
因此,在 `DiffSynth-Studio` 的训练框架中,我们设计了一个端到端的蒸馏加速训练方案,我们称为直接蒸馏(Direct Distill),其训练过程的伪代码如下:
|
||||
|
||||
```
|
||||
seed = xxx
|
||||
with torch.no_grad():
|
||||
image_1 = pipe(prompt, steps=50, seed=seed, cfg=4)
|
||||
image_2 = pipe(prompt, steps=4, seed=seed, cfg=1)
|
||||
loss = torch.nn.functional.mse_loss(image_1, image_2)
|
||||
```
|
||||
|
||||
是的,非常端到端的训练方案,稍加训练就可以有立竿见影的效果。
|
||||
|
||||
## 直接蒸馏训练的模型
|
||||
|
||||
我们用这个方案基于 Qwen-Image 训练了两个模型:
|
||||
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-Full](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full): 全量蒸馏训练
|
||||
* [DiffSynth-Studio/Qwen-Image-Distill-LoRA](https://modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-LoRA): LoRA 蒸馏训练
|
||||
|
||||
点击模型链接即可前往模型页面查看模型效果。
|
||||
|
||||
## 在训练框架中使用蒸馏加速训练
|
||||
|
||||
首先,需要生成训练数据,请参考[模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md)部分编写推理代码,以足够多的推理步数生成训练数据。
|
||||
|
||||
以 Qwen-Image 为例,以下代码可以生成一张图片:
|
||||
|
||||
```python
|
||||
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
|
||||
image = pipe(prompt, seed=0, num_inference_steps=40)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
然后,我们把必要的信息编写成[元数据文件](/docs/zh/API_Reference/core/data.md#元数据):
|
||||
|
||||
```csv
|
||||
image,prompt,seed,rand_device,num_inference_steps,cfg_scale
|
||||
distill_qwen/image.jpg,"精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。",0,cpu,4,1
|
||||
```
|
||||
|
||||
这个样例数据集可以直接下载:
|
||||
|
||||
```shell
|
||||
modelscope download --dataset DiffSynth-Studio/example_image_dataset --local_dir ./data/example_image_dataset
|
||||
```
|
||||
|
||||
然后开始 LoRA 蒸馏加速训练:
|
||||
|
||||
```shell
|
||||
bash examples/qwen_image/model_training/lora/Qwen-Image-Distill-LoRA.sh
|
||||
```
|
||||
|
||||
请注意,在[训练脚本参数](/docs/zh/Pipeline_Usage/Model_Training.md#脚本参数)中,数据集的图像分辨率设置要避免触发缩放处理。当设定 `--height` 和 `--width` 以启用固定分辨率时,所有训练数据必须是以完全一致的宽高生成的;当设定 `--max_pixels` 以启用动态分辨率时,`--max_pixels` 的数值必须大于或等于任一训练图像的像素面积。
|
||||
|
||||
## 训练框架设计思路
|
||||
|
||||
直接蒸馏与[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)相比,仅训练的损失函数不同,直接蒸馏的损失函数是 `diffsynth.diffusion.loss` 中的 `DirectDistillLoss`。
|
||||
|
||||
## 未来工作
|
||||
|
||||
直接蒸馏是通用性很强的加速方案,但未必是效果最好的方案,所以我们暂未把这一技术以论文的形式发布。我们希望把这个问题交给学术界和开源社区共同解决,期待开发者能够给出更完善的通用训练方案。
|
||||
20
docs/zh/Training/FP8_Precision.md
Normal file
20
docs/zh/Training/FP8_Precision.md
Normal file
@@ -0,0 +1,20 @@
|
||||
# 在训练中启用 FP8 精度
|
||||
|
||||
尽管 `DiffSynth-Studio` 在模型推理中支持[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md),但其中的大部分减少显存占用的技术不适合用于训练中,Offload 会导致极为缓慢的训练过程。
|
||||
|
||||
FP8 精度是唯一可在训练过程中启用的显存管理策略,但本框架目前不支持原生 FP8 精度训练,原因详见 [Q&A: 为什么训练框架不支持原生 FP8 精度训练?](/docs/zh/QA.md#为什么训练框架不支持原生-fp8-精度训练),仅支持将参数不被梯度更新的模型(不需要梯度回传,或梯度仅更新其 LoRA)以 FP8 精度进行存储。
|
||||
|
||||
## 启用 FP8
|
||||
|
||||
在我们提供的训练脚本中,通过参数 `--fp8_models` 即可快速设置以 FP8 精度存储的模型。以 Qwen-Image 的 LoRA 训练为例,我们提供了启用 FP8 训练的脚本,位于 [`/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh`](/examples/qwen_image/model_training/special/fp8_training/Qwen-Image-LoRA.sh)。训练完成后,可通过脚本 [`/examples/qwen_image/model_training/special/fp8_training/validate.py`](/examples/qwen_image/model_training/special/fp8_training/validate.py) 验证训练效果。
|
||||
|
||||
请注意,这种 FP8 显存管理策略不支持梯度更新,当某个模型被设置为可训练时,不能为这个模型开启 FP8 精度,支持开启 FP8 的模型包括两类:
|
||||
|
||||
* 参数不可训练,例如 VAE 模型
|
||||
* 梯度不更新其参数,例如 LoRA 训练中的 DiT 模型
|
||||
|
||||
经实验验证,开启 FP8 后的 LoRA 训练效果没有明显的图像质量下降,但理论上误差是确实存在的,如果在使用本功能时遇到训练效果不如 BF16 精度训练的问题,请通过 GitHub issue 给我们提供反馈。
|
||||
|
||||
## 训练框架设计思路
|
||||
|
||||
训练框架完全沿用推理的显存管理,在训练中仅通过 `DiffusionTrainingModule` 中的 `parse_model_configs` 解析显存管理配置。
|
||||
97
docs/zh/Training/Split_Training.md
Normal file
97
docs/zh/Training/Split_Training.md
Normal file
@@ -0,0 +1,97 @@
|
||||
# 两阶段拆分训练
|
||||
|
||||
本文档介绍拆分训练,能够自动将训练过程拆分为两阶段进行,减少显存占用,同时加快训练速度。
|
||||
|
||||
(拆分训练是实验性特性,尚未进行大规模验证,如果在使用中出现问题,请在 GitHub 上给我们提 issue。)
|
||||
|
||||
## 拆分训练
|
||||
|
||||
在大部分模型的训练过程中,大量计算发生在“前处理”中,即“与去噪模型无关的计算”,包括 VAE 编码、文本编码等。当对应的模型参数固定时,这部分计算的结果是重复的,在多个 epoch 中每个数据样本的计算结果完全相同,因此我们提供了“拆分训练”功能,该功能可以自动分析并拆分训练过程。
|
||||
|
||||
对于普通文生图模型的标准监督训练,拆分过程是非常简单的,只需要把所有 [`Pipeline Units`](/docs/zh/Developer_Guide/Building_a_Pipeline.md#units) 的计算拆分到第一阶段,将计算结果存储到硬盘中,然后在第二阶段从硬盘中读取这些结果并进行后续计算即可。但如果前处理过程中需要梯度回传,情况就变得极其复杂,为此,我们引入了一个计算图拆分算法用于分析如何拆分计算。
|
||||
|
||||
## 计算图拆分算法
|
||||
|
||||
> (我们会在后续的文档更新中补充计算图拆分算法的详细细节)
|
||||
|
||||
## 使用拆分训练
|
||||
|
||||
拆分训练已支持[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)和[直接蒸馏训练](/docs/zh/Training/Direct_Distill.md),在训练命令中通过 `--task` 参数控制,以 Qwen-Image 模型的 LoRA 训练为例,拆分前的训练命令为:
|
||||
|
||||
```shell
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image_lora" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters
|
||||
```
|
||||
|
||||
拆分后,在第一阶段中,做如下修改:
|
||||
|
||||
* 将 `--dataset_repeat` 改为 1,避免重复计算
|
||||
* 将 `--output_path` 改为第一阶段计算结果保存的路径
|
||||
* 添加额外参数 `--task "sft:data_process"`
|
||||
* 删除 `--model_id_with_origin_paths` 中的 DiT 模型
|
||||
|
||||
```shell
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path data/example_image_dataset \
|
||||
--dataset_metadata_path data/example_image_dataset/metadata.csv \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 1 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-LoRA-splited-cache" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--task "sft:data_process"
|
||||
```
|
||||
|
||||
在第二阶段,做如下修改:
|
||||
|
||||
* 将 `--dataset_base_path` 改为第一阶段的 `--output_path`
|
||||
* 删除 `--dataset_metadata_path`
|
||||
* 添加额外参数 `--task "sft:train"`
|
||||
* 删除 `--model_id_with_origin_paths` 中的 Text Encoder 和 VAE 模型
|
||||
|
||||
```shell
|
||||
accelerate launch examples/qwen_image/model_training/train.py \
|
||||
--dataset_base_path "./models/train/Qwen-Image-LoRA-splited-cache" \
|
||||
--max_pixels 1048576 \
|
||||
--dataset_repeat 50 \
|
||||
--model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
|
||||
--learning_rate 1e-4 \
|
||||
--num_epochs 5 \
|
||||
--remove_prefix_in_ckpt "pipe.dit." \
|
||||
--output_path "./models/train/Qwen-Image-LoRA-splited" \
|
||||
--lora_base_model "dit" \
|
||||
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing \
|
||||
--dataset_num_workers 8 \
|
||||
--find_unused_parameters \
|
||||
--task "sft:train"
|
||||
```
|
||||
|
||||
我们提供了样例训练脚本和验证脚本,位于 `examples/qwen_image/model_training/special/split_training`。
|
||||
|
||||
## 训练框架设计思路
|
||||
|
||||
训练框架通过 `DiffusionTrainingModule` 的 `split_pipeline_units` 方法拆分 `Pipeline` 中的计算单元。
|
||||
129
docs/zh/Training/Supervised_Fine_Tuning.md
Normal file
129
docs/zh/Training/Supervised_Fine_Tuning.md
Normal file
@@ -0,0 +1,129 @@
|
||||
# 标准监督训练
|
||||
|
||||
在理解 [Diffusion 模型基本原理](/docs/zh/Training/Understanding_Diffusion_models.md)之后,本文档介绍框架如何实现 Diffusion 模型的训练。本文档介绍框架的原理,帮助开发者编写新的训练代码,如需使用我们提供的默认训练功能,请参考[模型训练](/docs/zh/Pipeline_Usage/Model_Training.md)。
|
||||
|
||||
回顾前文中的模型训练伪代码,当我们实际编写代码时,情况会变得极为复杂。部分模型需要输入额外的引导条件并进行预处理,例如 ControlNet;部分模型需要与去噪模型进行交叉式的计算,例如 VACE;部分模型因显存需求过大,需要开启 Gradient Checkpointing,例如 Qwen-Image 的 DiT。
|
||||
|
||||
为了实现严格的推理和训练一致性,我们对 `Pipeline` 等组件进行了抽象封装,在训练过程中大量复用推理代码。请参考[接入 Pipeline](/docs/zh/Developer_Guide/Building_a_Pipeline.md) 了解 `Pipeline` 组件的设计。接下来我们介绍训练框架如何利用 `Pipeline` 组件构建训练算法。
|
||||
|
||||
## 框架设计思路
|
||||
|
||||
训练模块在 `Pipeline` 上层进行封装,继承 `diffsynth.diffusion.training_module` 中的 `DiffusionTrainingModule`,我们需为训练模块提供必要的 `__init__` 和 `forward` 方法。我们以 Qwen-Image 的 LoRA 训练为例,在 `examples/qwen_image/model_training/special/simple/train.py` 中提供了仅包含基础训练功能的简易脚本,帮助开发者理解训练模块的设计思路。
|
||||
|
||||
```python
|
||||
class QwenImageTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(self, device):
|
||||
# Initialize models here.
|
||||
pass
|
||||
|
||||
def forward(self, data):
|
||||
# Compute loss here.
|
||||
return loss
|
||||
```
|
||||
|
||||
### `__init__`
|
||||
|
||||
在 `__init__` 中需进行模型的初始化,先加载模型,然后将其切换到训练模式。
|
||||
|
||||
```python
|
||||
def __init__(self, device):
|
||||
super().__init__()
|
||||
# Load the pipeline
|
||||
self.pipe = QwenImagePipeline.from_pretrained(
|
||||
torch_dtype=torch.bfloat16,
|
||||
device=device,
|
||||
model_configs=[
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
|
||||
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
# Switch to training mode
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe,
|
||||
lora_base_model="dit",
|
||||
lora_target_modules="to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj",
|
||||
lora_rank=32,
|
||||
)
|
||||
```
|
||||
|
||||
加载模型的逻辑与推理时基本一致,支持从远程和本地路径加载模型,详见[模型推理](/docs/zh/Pipeline_Usage/Model_Inference.md),但请注意不要启用[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md)。
|
||||
|
||||
`switch_pipe_to_training_mode` 可以将模型切换到训练模式,详见 `switch_pipe_to_training_mode`。
|
||||
|
||||
### `forward`
|
||||
|
||||
在 `forward` 中需计算损失函数值,先进行前处理,然后经过 `Pipeline` 的 [`model_fn`](/docs/zh/Developer_Guide/Building_a_Pipeline.md#model_fn) 计算损失函数。
|
||||
|
||||
```python
|
||||
def forward(self, data):
|
||||
# Preprocess
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
inputs_shared = {
|
||||
# Assume you are using this pipeline for inference,
|
||||
# please fill in the input parameters.
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
# Please do not modify the following parameters
|
||||
# unless you clearly know what this will cause.
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": True,
|
||||
"use_gradient_checkpointing_offload": False,
|
||||
}
|
||||
for unit in self.pipe.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.pipe.unit_runner(unit, self.pipe, inputs_shared, inputs_posi, inputs_nega)
|
||||
# Loss
|
||||
loss = FlowMatchSFTLoss(self.pipe, **inputs_shared, **inputs_posi)
|
||||
return loss
|
||||
```
|
||||
|
||||
前处理过程与推理阶段一致,开发者只需假定在使用 `Pipeline` 进行推理,将输入参数填入即可。
|
||||
|
||||
损失函数的计算沿用 `diffsynth.diffusion.loss` 中的 `FlowMatchSFTLoss`。
|
||||
|
||||
### 开始训练
|
||||
|
||||
训练框架还需其他模块,包括:
|
||||
|
||||
* accelerator: `accelerate` 提供的训练启动器,详见 [`accelerate`](https://huggingface.co/docs/accelerate/index)
|
||||
* dataset: 通用数据集,详见 [`diffsynth.core.data`](/docs/zh/API_Reference/core/data.md)
|
||||
* model_logger: 模型记录器,详见 `diffsynth.diffusion.logger`
|
||||
|
||||
```python
|
||||
if __name__ == "__main__":
|
||||
accelerator = accelerate.Accelerator(
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=True)],
|
||||
)
|
||||
dataset = UnifiedDataset(
|
||||
base_path="data/example_image_dataset",
|
||||
metadata_path="data/example_image_dataset/metadata.csv",
|
||||
repeat=50,
|
||||
data_file_keys="image",
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path="data/example_image_dataset",
|
||||
height=512,
|
||||
width=512,
|
||||
height_division_factor=16,
|
||||
width_division_factor=16,
|
||||
)
|
||||
)
|
||||
model = QwenImageTrainingModule(accelerator.device)
|
||||
model_logger = ModelLogger(
|
||||
output_path="models/toy_model",
|
||||
remove_prefix_in_ckpt="pipe.dit.",
|
||||
)
|
||||
launch_training_task(
|
||||
accelerator, dataset, model, model_logger,
|
||||
learning_rate=1e-5, num_epochs=1,
|
||||
)
|
||||
```
|
||||
|
||||
将以上所有代码组装,得到 `examples/qwen_image/model_training/special/simple/train.py`。使用以下命令即可启动训练:
|
||||
|
||||
```
|
||||
accelerate launch examples/qwen_image/model_training/special/simple/train.py
|
||||
```
|
||||
143
docs/zh/Training/Understanding_Diffusion_models.md
Normal file
143
docs/zh/Training/Understanding_Diffusion_models.md
Normal file
@@ -0,0 +1,143 @@
|
||||
# Diffusion 模型基本原理
|
||||
|
||||
本文介绍 Diffusion 模型的基本原理,帮助你理解训练框架是如何构建的。为了让读者更轻松地理解这些复杂的数学理论,我们重构了 Diffusion 模型的理论框架,抛弃了复杂的随机微分方程,用一种更简洁易懂的形式进行介绍。
|
||||
|
||||
## 引言
|
||||
|
||||
Diffusion 模型通过多步迭代式地去噪(denoise)生成清晰的图像或视频内容,我们从一个数据样本 $x_0$ 的生成过程开始讲起。直观地,在完整的一轮 denoise 过程中,我们从随机高斯噪声 $x_T$ 开始,通过迭代依次得到 $x_{T-1}$、$x_{T-2}$、$x_{T-3}$、$\cdots$,在每一步中逐渐减少噪声含量,最终得到不含噪声的数据样本 $x_0$。
|
||||
|
||||
(图)
|
||||
|
||||
这个过程是很直观的,但如果要理解其中的细节,我们就需要回答这几个问题:
|
||||
|
||||
* 每一步的噪声含量是如何定义的?
|
||||
* 迭代去噪的计算是如何进行的?
|
||||
* 如何训练这样的 Diffusion 模型?
|
||||
* 现代 Diffusion 模型的架构是什么样的?
|
||||
* 本项目如何封装和实现模型训练?
|
||||
|
||||
## 每一步的噪声含量是如何定义的?
|
||||
|
||||
在 Diffusion 模型的理论体系中,噪声的含量是由一系列参数 $\sigma_T$、$\sigma_{T-1}$、$\sigma_{T-2}$、$\cdots$、$\sigma_0$ 决定的。其中
|
||||
|
||||
* $\sigma_T=1$,对应的 $x_T$ 为纯粹的高斯噪声
|
||||
* $\sigma_T>\sigma_{T-1}>\sigma_{T-2}>\cdots>x_0$,在迭代过程中噪声含量逐渐减小
|
||||
* $\sigma_0=0$,对应的 $x_0$ 为不含任何噪声的数据样本
|
||||
|
||||
至于中间 $\sigma_{T-1}$、$\sigma_{T-2}$、$\cdots$、$\sigma_1$ 的数值,则不是固定的,满足递减的条件即可。
|
||||
|
||||
那么在中间的某一步,我们可以直接合成含噪声的数据样本 $x_t=(1-\sigma_t)x_0+\sigma_t x_T$。
|
||||
|
||||
(图)
|
||||
|
||||
## 迭代去噪的计算是如何进行的?
|
||||
|
||||
在理解迭代去噪的计算前,我们要先搞清楚,去噪模型的输入和输出是什么。我们把模型抽象成一个符号 $\hat \epsilon$,它的输入通常包含三部分
|
||||
|
||||
* 时间步 $t$,模型需要理解当前处于去噪过程的哪个阶段
|
||||
* 含噪声的数据样本 $x_t$,模型需要理解要对什么数据进行去噪
|
||||
* 引导条件 $c$,模型需要理解要通过去噪生成什么样的数据样本
|
||||
|
||||
其中,引导条件 $c$ 是新引入的参数,它是由用户输入的,可以是用于描述图像内容的文本,也可以是用于勾勒图像结构的线稿图。
|
||||
|
||||
(图)
|
||||
|
||||
而模型的输出 $\hat \epsilon(x_t,c,t)$,则近似地等于 $x_T-x_0$,也就是整个扩散过程(去噪过程的反向过程)的方向。
|
||||
|
||||
接下来我们分析一步迭代中发生的计算,在时间步 $t$,模型通过计算得到近似的 $x_T-x_0$ 后,我们计算下一步的 $x_{t-1}$:
|
||||
$$
|
||||
\begin{aligned}
|
||||
x_{t-1}&=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)\\
|
||||
&\approx x_t + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\
|
||||
&=(1-\sigma_t)x_0+\sigma_t x_T + (\sigma_{t-1} - \sigma_t) \cdot (x_T-x_0)\\
|
||||
&=(1-\sigma_{t-1})x_0+\sigma_{t-1}x_T
|
||||
\end{aligned}
|
||||
$$
|
||||
完美!与时间步 $t-1$ 时的噪声含量定义完美契合。
|
||||
|
||||
> (这部分可能有点难懂,请不必担心,首次阅读本文时建议跳过这部分,不影响后文的阅读。)
|
||||
>
|
||||
> 完成了这段有点复杂的公式推导后,我们思考一个问题,为什么模型的输出要近似地等于 $x_T-x_0$ 呢?可以设定成其他值吗?
|
||||
>
|
||||
> 实际上,Diffusion 模型依赖两个定义形成完备的理论。在以上的公式中,我们可以提炼出这两个定义,并导出迭代公式:
|
||||
>
|
||||
> * 数据定义:$x_t=(1-\sigma_t)x_0+\sigma_t x_T$
|
||||
> * 模型定义:$\hat \epsilon(x_t,c,t)=x_T-x_0$
|
||||
> * 导出迭代公式:$x_{t-1}=x_t + (\sigma_{t-1} - \sigma_t) \cdot \hat \epsilon(x_t,c,t)$
|
||||
>
|
||||
> 这三个数学公式是完备的,例如在刚才的推导中,我们把数据定义和模型定义代入迭代公式,可以得到与数据定义吻合的 $x_{t-1}$。
|
||||
>
|
||||
> 这是基于 Flow Matching 理论构建的两个定义,但 Diffusion 模型也可用其他的两个定义来实现,例如早期基于 DDPM(Denoising Diffusion Probabilistic Models)的模型,其两个定义及导出的迭代公式为:
|
||||
>
|
||||
> * 数据定义:$x_t=\sqrt{\alpha_t}x_0+\sqrt{1-\alpha_t}x_T$
|
||||
> * 模型定义:$\hat \epsilon(x_t,c,t)=x_T$
|
||||
> * 导出迭代公式:$x_{t-1}=\sqrt{\alpha_{t-1}}\left(\frac{x_t-\sqrt{1-\alpha_t}\hat \epsilon(x_t,c,t)}{\sqrt{\sigma_t}}\right)+\sqrt{1-\alpha_{t-1}}\hat \epsilon(x_t,c,t)$
|
||||
>
|
||||
> 更一般地,我们用矩阵描述迭代公式的导出过程,对于任意数据定义和模型定义,有:
|
||||
>
|
||||
> * 数据定义:$x_t=C_T(x_0,x_T)^T$
|
||||
> * 模型定义:$\hat \epsilon(x_t,c,t)=C_T^{[\epsilon]}(x_0,x_T)^T$
|
||||
> * 导出迭代公式:$x_{t-1}=C_{t-1}(C_t,C_t^{[\epsilon]})^{-T}(x_t,\hat \epsilon(x_t,c,t))^T$
|
||||
>
|
||||
> 其中,$C_t$、$C_t^{[\epsilon]}$ 是 $1\times 2$ 的系数矩阵,不难发现,在构造两个定义时,需保证矩阵 $(C_t,C_t^{[\epsilon]})^T$ 是可逆的。
|
||||
>
|
||||
> 尽管 Flow Matching 与 DDPM 已被大量预训练模型广泛验证过,但这并不代表这是最优的方案,我们鼓励开发者设计新的 Diffusion 模型理论实现更好的训练效果。
|
||||
|
||||
## 如何训练这样的 Diffusion 模型?
|
||||
|
||||
搞清楚迭代去噪的过程之后,接下来我们考虑如何训练这样的 Diffusion 模型。
|
||||
|
||||
训练过程不同于生成过程,如果我们在训练过程中保留多步迭代,那么梯度需经过多步回传,带来的时间和空间复杂度是灾难性的。为了提高计算效率,我们在训练中随机选择某一时间步 $t$ 进行训练。
|
||||
|
||||
(图)
|
||||
|
||||
以下是训练过程的伪代码
|
||||
|
||||
> 从数据集获取数据样本 $x_0$ 和引导条件 $c$
|
||||
>
|
||||
> 随机采样时间步 $t\in(0,T]$
|
||||
>
|
||||
> 随机采样高斯噪声 $x_T\in \mathcal N(O,I)$
|
||||
>
|
||||
> $x_t=(1-\sigma_t)x_0+\sigma_t x_T$
|
||||
>
|
||||
> $\hat \epsilon(x_t,c,t)$
|
||||
>
|
||||
> 损失函数 $\mathcal L=||\hat \epsilon(x_t,c,t)-(x_T-x_0)||_2^2$
|
||||
>
|
||||
> 梯度回传并更新模型参数
|
||||
|
||||
## 现代 Diffusion 模型的架构是什么样的?
|
||||
|
||||
从理论到实践,还需要填充更多细节。现代 Diffusion 模型架构已经发展成熟,主流的架构沿用了 Latent Diffusion 所提出的“三段式”架构,包括数据编解码器、引导条件编码器、去噪模型三部分。
|
||||
|
||||
(图)
|
||||
|
||||
### 数据编解码器
|
||||
|
||||
在前文中,我们一直将 $x_0$ 称为“数据样本”,而不是图像或视频,这是因为现代 Diffusion 模型通常不会直接在图像或视频上进行处理,而是用编码器(Encoder)-解码器(Decoder)架构的模型,通常是 VAE(Variational Auto-Encoders)模型,将图像或视频编码为 Embedding 张量,得到 $x_0$。
|
||||
|
||||
数据经过编码器编码后,再经过解码器解码,重建后的内容与原来近似地一致,会有少量误差。那么,为什么要在编码后的 Embedding 张量上处理,而不是在图像或视频上直接处理呢?主要原因有亮点:
|
||||
|
||||
* 编码的同时对数据进行了压缩,编码后处理的计算量更小。
|
||||
* 编码后的数据分布与高斯分布更相似,更容易用去噪模型对数据进行建模。
|
||||
|
||||
在生成过程中,编码器部分不参与计算,迭代完成后,用解码器部分解码 $x_0$ 即可得到清晰的图像或视频。在训练过程中,解码器部分不参与计算,仅编码器用于计算 $x_0$。
|
||||
|
||||
### 引导条件编码器
|
||||
|
||||
用户输入的引导条件 $c$ 可能是复杂多样的,需要由专门的编码器模型将其处理成 Embedding 张量。按照引导条件的类型,我们把引导条件编码器分为以下几类:
|
||||
|
||||
* 文本类型,例如 CLIP、Qwen-VL
|
||||
* 图像类型,例如 ControlNet、IP-Adapter
|
||||
* 视频类型,例如 VAE
|
||||
|
||||
> 前文中的模型 $\hat \epsilon$ 指代此处的所有引导条件编码器和去噪模型这一整体,我们把引导条件编码器单独拆分列出,因为这类模型在 Diffusion 训练中通常是冻结的,且输出值与时间步 $t$ 无关,因此引导条件编码器的计算可以离线进行。
|
||||
|
||||
### 去噪模型
|
||||
|
||||
去噪模型是 Diffusion 模型真正的本体,其模型结构多种多样,例如 UNet、DiT,模型开发者可在此结构上自由发挥。
|
||||
|
||||
## 本项目如何封装和实现模型训练?
|
||||
|
||||
请阅读下一文档:[标准监督训练](/docs/zh/Training/Supervised_Fine_Tuning.md)
|
||||
Reference in New Issue
Block a user