mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
79
docs/zh/API_Reference/core/attention.md
Normal file
79
docs/zh/API_Reference/core/attention.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# `diffsynth.core.attention`: 注意力机制实现
|
||||
|
||||
`diffsynth.core.attention` 提供了注意力机制实现的路由机制,根据 `Python` 环境中的可用包和[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)自动选择高效的注意力机制实现。
|
||||
|
||||
## 注意力机制
|
||||
|
||||
注意力机制是在论文[《Attention Is All You Need》](https://arxiv.org/abs/1706.03762)中提出的模型结构,在原论文中,注意力机制按照如下公式实现:
|
||||
|
||||
$$
|
||||
\text{Attention}(Q, K, V) = \text{Softmax}\left(
|
||||
\frac{QK^T}{\sqrt{d_k}}
|
||||
\right)
|
||||
V.
|
||||
$$
|
||||
|
||||
在 `PyTorch` 中,可以用如下代码实现:
|
||||
```python
|
||||
import torch
|
||||
|
||||
def attention(query, key, value):
|
||||
scale_factor = 1 / query.size(-1)**0.5
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
return attn_weight @ value
|
||||
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
output_1 = attention(query, key, value)
|
||||
```
|
||||
|
||||
其中 `query`、`key`、`value` 的维度是 $(b, n, s, d)$:
|
||||
* $b$:Batch size
|
||||
* $n$: Attention head 的数量
|
||||
* $s$: 序列长度
|
||||
* $d$: 每个 Attention head 的维数
|
||||
|
||||
这部分计算是不包含任何可训练参数的,现代 transformer 架构的模型会在进行这一计算前后经过 Linear 层,本文讨论的“注意力机制”不包含这些计算,仅包含以上代码的计算。
|
||||
|
||||
## 更高效的实现
|
||||
|
||||
注意到,注意力机制中 Attention Score(公式中的 $\text{Softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$,代码中的 `attn_weight`)的维度为 $(b, n, s, s)$,其中序列长度 $s$ 通常非常大,这导致计算的时间和空间复杂度达到平方级。以图像生成模型为例,图像的宽度和高度每增加到 2 倍,序列长度增加到 4 倍,计算量和显存需求增加到 16 倍。为了避免高昂的计算成本,需采用更高效的注意力机制实现,包括
|
||||
* Flash Attention 3:[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2407.08608)
|
||||
* Flash Attention 2:[GitHub](https://github.com/Dao-AILab/flash-attention)、[论文](https://arxiv.org/abs/2307.08691)
|
||||
* Sage Attention:[GitHub](https://github.com/thu-ml/SageAttention)、[论文](https://arxiv.org/abs/2505.11594)
|
||||
* xFormers:[GitHub](https://github.com/facebookresearch/xformers)、[文档](https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops)
|
||||
* PyTorch:[GitHub](https://github.com/pytorch/pytorch)、[文档](https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
|
||||
|
||||
如需调用除 `PyTorch` 外的其他注意力实现,请按照其 GitHub 页面的指引安装对应的包。`DiffSynth-Studio` 会自动根据 Python 环境中的可用包路由到对应的实现上,也可通过[环境变量](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_attention_implementation)控制。
|
||||
|
||||
```python
|
||||
from diffsynth.core.attention import attention_forward
|
||||
import torch
|
||||
|
||||
def attention(query, key, value):
|
||||
scale_factor = 1 / query.size(-1)**0.5
|
||||
attn_weight = query @ key.transpose(-2, -1) * scale_factor
|
||||
attn_weight = torch.softmax(attn_weight, dim=-1)
|
||||
return attn_weight @ value
|
||||
|
||||
query = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
key = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
value = torch.rand(32, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
|
||||
output_1 = attention(query, key, value)
|
||||
output_2 = attention_forward(query, key, value)
|
||||
print((output_1 - output_2).abs().mean())
|
||||
```
|
||||
|
||||
请注意,加速的同时会引入误差,但在大多数情况下误差是可以忽略不计的。
|
||||
|
||||
## 开发者导引
|
||||
|
||||
在为 `DiffSynth-Studio` 接入新模型时,开发者可自行决定是否调用 `diffsynth.core.attention` 中的 `attention_forward`,但我们期望模型能够尽可能优先调用这一模块,以便让新的注意力机制实现能够在这些模型上直接生效。
|
||||
|
||||
## 最佳实践
|
||||
|
||||
**在大多数情况下,我们建议直接使用 `PyTorch` 原生的实现,无需安装任何额外的包。** 虽然其他注意力机制实现可以加速,但加速效果是较为有限的,在少数情况下会出现兼容性和精度不足的问题。
|
||||
|
||||
此外,高效的注意力机制实现会逐步集成到 `PyTorch` 中,`PyTorch` 的 `2.9.0` 版本中的 `scaled_dot_product_attention` 已经集成了 Flash Attention 2。我们仍在 `DiffSynth-Studio` 提供这一接口,是为了让一些激进的加速方案能够快速走向应用,尽管它们在稳定性上还需要时间验证。
|
||||
151
docs/zh/API_Reference/core/data.md
Normal file
151
docs/zh/API_Reference/core/data.md
Normal file
@@ -0,0 +1,151 @@
|
||||
# `diffsynth.core.data`: 数据处理算子与通用数据集
|
||||
|
||||
## 数据处理算子
|
||||
|
||||
### 可用数据处理算子
|
||||
|
||||
`diffsynth.core.data` 提供了一系列数据处理算子,用于进行数据处理,包括:
|
||||
|
||||
* 数据格式转换算子
|
||||
* `ToInt`: 转换为 int 格式
|
||||
* `ToFloat`: 转换为 float 格式
|
||||
* `ToStr`: 转换为 str 格式
|
||||
* `ToList`: 转换为列表格式,以列表包裹此数据
|
||||
* `ToAbsolutePath`: 将相对路径转换为绝对路径
|
||||
* 文件加载算子
|
||||
* `LoadImage`: 读取图片文件
|
||||
* `LoadVideo`: 读取视频文件
|
||||
* `LoadAudio`: 读取音频文件
|
||||
* `LoadGIF`: 读取 GIF 文件
|
||||
* `LoadTorchPickle`: 读取由 [`torch.save`](https://docs.pytorch.org/docs/stable/generated/torch.save.html) 保存的二进制文件【该算子可能导致二进制文件中的代码注入攻击,请谨慎使用!】
|
||||
* 媒体文件处理算子
|
||||
* `ImageCropAndResize`: 对图像进行裁剪和拉伸
|
||||
* Meta 算子
|
||||
* `SequencialProcess`: 将序列中的每个数据路由到一个算子
|
||||
* `RouteByExtensionName`: 按照文件扩展名路由到特定算子
|
||||
* `RouteByType`: 按照数据类型路由到特定算子
|
||||
|
||||
### 算子使用
|
||||
|
||||
数据算子之间以 `>>` 符号连接形成数据处理流水线,例如:
|
||||
|
||||
```python
|
||||
from diffsynth.core.data.operators import *
|
||||
|
||||
data = "image.jpg"
|
||||
data_pipeline = ToAbsolutePath(base_path="/data") >> LoadImage() >> ImageCropAndResize(max_pixels=512*512)
|
||||
data = data_pipeline(data)
|
||||
```
|
||||
|
||||
在经过每个算子后,数据被依次处理
|
||||
|
||||
* `ToAbsolutePath(base_path="/data")`: `"/data/image.jpg"`
|
||||
* `LoadImage()`: `<PIL.Image.Image image mode=RGB size=1024x1024 at 0x7F8E7AAEFC10>`
|
||||
* `ImageCropAndResize(max_pixels=512*512)`: `<PIL.Image.Image image mode=RGB size=512x512 at 0x7F8E7A936F20>`
|
||||
|
||||
我们可以组合出功能完备的数据流水线,例如通用数据集的默认视频数据算子为
|
||||
|
||||
```python
|
||||
RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
|
||||
(("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
|
||||
(("gif",), LoadGIF(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
(("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
|
||||
num_frames, time_division_factor, time_division_remainder,
|
||||
frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
|
||||
)),
|
||||
])),
|
||||
])
|
||||
```
|
||||
|
||||
它包含如下逻辑:
|
||||
|
||||
* 如果是 `str` 类型的数据
|
||||
* 如果是 `"jpg", "jpeg", "png", "webp"` 类型文件
|
||||
* 加载这张图片
|
||||
* 裁剪并缩放到特定分辨率
|
||||
* 打包进列表,视为单帧视频
|
||||
* 如果是 `"gif"` 类型文件
|
||||
* 加载 gif 文件内容
|
||||
* 将每一帧裁剪和缩放到特定分辨率
|
||||
* 如果是 `"mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"` 类型文件
|
||||
* 加载 gif 文件内容
|
||||
* 将每一帧裁剪和缩放到特定分辨率
|
||||
* 如果不是 `str` 类型的数据,报错
|
||||
|
||||
## 通用数据集
|
||||
|
||||
`diffsynth.core.data` 提供了统一的数据集实现,数据集需输入以下参数:
|
||||
|
||||
* `base_path`: 根目录,若数据集中包含图片文件的相对路径,则需填入此字段用于加载这些路径指向的文件
|
||||
* `metadata_path`: 元数据目录,记录所有元数据的文件路径,支持 `csv`、`json`、`jsonl` 格式
|
||||
* `repeat`: 数据重复次数,默认为 1,该参数影响一个 epoch 的训练步数
|
||||
* `data_file_keys`: 需进行加载的数据字段名,例如 `(image, edit_image)`
|
||||
* `main_data_operator`: 主加载算子,需通过数据处理算子组装好数据处理流水线
|
||||
* `special_operator_map`: 特殊算子映射,对需要特殊处理的字段构建的算子映射
|
||||
|
||||
### 元数据
|
||||
|
||||
数据集的 `metadata_path` 指向元数据文件,支持 `csv`、`json`、`jsonl` 格式,以下提供了样例
|
||||
|
||||
* `csv` 格式:可读性高、不支持列表数据、内存占用小
|
||||
|
||||
```csv
|
||||
image,prompt
|
||||
image_1.jpg,"a dog"
|
||||
image_2.jpg,"a cat"
|
||||
```
|
||||
|
||||
* `json` 格式:可读性高、支持列表数据、内存占用大
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"image": "image_1.jpg",
|
||||
"prompt": "a dog"
|
||||
},
|
||||
{
|
||||
"image": "image_2.jpg",
|
||||
"prompt": "a cat"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
* `jsonl` 格式:可读性低、支持列表数据、内存占用小
|
||||
|
||||
```json
|
||||
{"image": "image_1.jpg", "prompt": "a dog"}
|
||||
{"image": "image_2.jpg", "prompt": "a cat"}
|
||||
```
|
||||
|
||||
如何选择最佳的元数据格式?
|
||||
|
||||
* 如果数据量大,达到千万级的数据量,由于 `json` 文件解析时需要额外内存,此时不可用,请使用 `csv` 或 `jsonl` 格式
|
||||
* 如果数据集中包含列表数据,例如编辑模型需输入多张图,由于 `csv` 格式无法存储列表格式数据,此时不可用,请使用 `json` 或 `jsonl` 格式
|
||||
|
||||
### 数据加载逻辑
|
||||
|
||||
在没有进行额外设置时,数据集默认输出元数据集中的数据,图片和视频文件的路径会以字符串的格式输出,若要加载这些文件,则需要设置 `data_file_keys`、`main_data_operator`、`special_operator_map`。
|
||||
|
||||
在数据处理流程中,按如下逻辑进行处理:
|
||||
* 如果字段位于 `special_operator_map`,则调用 `special_operator_map` 中的对应算子进行处理
|
||||
* 如果字段不位于 `special_operator_map`
|
||||
* 如果字段位于 `data_file_keys`,则调用 `main_data_operator` 算子进行处理
|
||||
* 如果字段不位于 `data_file_keys`,则不进行处理
|
||||
|
||||
`special_operator_map` 可用于实现特殊的数据处理,例如模型 [Wan-AI/Wan2.2-Animate-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.2-Animate-14B) 中输入的人物面部视频 `animate_face_video` 是以固定分辨率处理的,与输出视频不一致,因此这一字段由专门的算子处理:
|
||||
|
||||
```python
|
||||
special_operator_map={
|
||||
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
|
||||
}
|
||||
```
|
||||
|
||||
### 其他注意事项
|
||||
|
||||
当数据量过少时,可适当增加 `repeat`,延长单个 epoch 的训练时间,避免频繁保存模型产生较多耗时。
|
||||
|
||||
当数据量 * `repeat` 超过 $10^9$ 时,我们观测到数据集的速度明显变慢,这似乎是 `PyTorch` 的 bug,我们尚不确定新版本的 `PyTorch` 是否已经修复了这一问题。
|
||||
69
docs/zh/API_Reference/core/gradient.md
Normal file
69
docs/zh/API_Reference/core/gradient.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# `diffsynth.core.gradient`: 梯度检查点及其 Offload
|
||||
|
||||
`diffsynth.core.gradient` 中提供了封装好的梯度检查点及其 Offload 版本,用于模型训练。
|
||||
|
||||
## 梯度检查点
|
||||
|
||||
梯度检查点是用于减少训练时显存占用的技术。我们提供一个例子来帮助你理解这一技术,以下是一个简单的模型结构
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(x)
|
||||
|
||||
model = ToyModel()
|
||||
x = torch.randn((2, 3))
|
||||
y = model(x)
|
||||
```
|
||||
|
||||
在这个模型结构中,输入的参数 $x$ 经过 Sigmoid 激活函数得到输出值 $y=\frac{1}{1+e^{-x}}$。
|
||||
|
||||
在训练过程中,假定我们的损失函数值为 $\mathcal L$,在梯度反响传播时,我们得到 $\frac{\partial \mathcal L}{\partial y}$,此时我们需计算 $\frac{\partial \mathcal L}{\partial x}$,不难发现 $\frac{\partial y}{\partial x}=y(1-y)$,进而有 $\frac{\partial \mathcal L}{\partial x}=\frac{\partial \mathcal L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial \mathcal L}{\partial y}y(1-y)$。如果在模型前向传播时保存 $y$ 的数值,并在梯度反向传播时直接计算 $y(1-y)$,这将避免复杂的 exp 计算,加快计算速度,但这会导致我们需要额外的显存来存储中间变量 $y$。
|
||||
|
||||
不启用梯度检查点时,训练框架会默认存储所有辅助梯度计算的中间变量,从而达到最佳的计算速度。开启梯度检查点时,中间变量则不会存储,但输入参数 $x$ 仍会存储,减少显存占用,在梯度反向传播时需重新计算这些变量,减慢计算速度。
|
||||
|
||||
## 启用梯度检查点及其 Offload
|
||||
|
||||
`diffsynth.core.gradient` 中的 `gradient_checkpoint_forward` 实现了梯度检查点及其 Offload,可参考以下代码调用:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffsynth.core.gradient import gradient_checkpoint_forward
|
||||
|
||||
class ToyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
return self.activation(x)
|
||||
|
||||
model = ToyModel()
|
||||
x = torch.randn((2, 3))
|
||||
y = gradient_checkpoint_forward(
|
||||
model,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
x=x,
|
||||
)
|
||||
```
|
||||
|
||||
* 当 `use_gradient_checkpointing=False` 且 `use_gradient_checkpointing_offload=False` 时,计算过程与原始计算完全相同,不影响模型的推理和训练,你可以直接将其集成到代码中。
|
||||
* 当 `use_gradient_checkpointing=True` 且 `use_gradient_checkpointing_offload=False` 时,启用梯度检查点。
|
||||
* 当 `use_gradient_checkpointing_offload=True` 时,启用梯度检查点,所有梯度检查点的输入参数存储在内存中,进一步降低显存占用和减慢计算速度。
|
||||
|
||||
## 最佳实践
|
||||
|
||||
> Q: 应当在何处启用梯度检查点?
|
||||
>
|
||||
> A: 对整个模型启用梯度检查点时,计算效率和显存占用并不是最优的,我们需要设置细粒度的梯度检查点,但同时不希望为框架增加过多繁杂的代码。因此我们建议在 `Pipeline` 的 `model_fn` 中实现,例如 `diffsynth/pipelines/qwen_image.py` 中的 `model_fn_qwen_image`,在 Block 层级启用梯度检查点,不需要修改模型结构的任何代码。
|
||||
|
||||
> Q: 什么情况下需要启用梯度检查点?
|
||||
>
|
||||
> A: 随着模型参数量越来越大,梯度检查点已成为必要的训练技术,梯度检查点通常是需要启用的。梯度检查点的 Offload 则仅需在激活值占用显存过大的模型(例如视频生成模型)中启用。
|
||||
141
docs/zh/API_Reference/core/loader.md
Normal file
141
docs/zh/API_Reference/core/loader.md
Normal file
@@ -0,0 +1,141 @@
|
||||
# `diffsynth.core.loader`: 模型下载与加载
|
||||
|
||||
本文档介绍 `diffsynth.core.loader` 中模型下载与加载相关的功能。
|
||||
|
||||
## ModelConfig
|
||||
|
||||
`diffsynth.core.loader` 中的 `ModelConfig` 用于标注模型下载来源、本地路径、显存管理配置等信息。
|
||||
|
||||
### 从远程下载并加载模型
|
||||
|
||||
以模型[DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 为例,在 `ModelConfig` 中填写 `model_id` 和 `origin_file_pattern` 后即可自动下载模型。默认下载到 `./models` 路径,该路径可通过[环境变量 DIFFSYNTH_MODEL_BASE_PATH](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_model_base_path) 修改。
|
||||
|
||||
默认情况下,即使模型已经下载完毕,程序仍会向远程查询是否有遗漏文件,如果要完全关闭远程请求,请将[环境变量 DIFFSYNTH_SKIP_DOWNLOAD](/docs/zh/Pipeline_Usage/Environment_Variables.md#diffsynth_skip_download) 设置为 `True`。
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(
|
||||
model_id="DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny",
|
||||
origin_file_pattern="model.safetensors",
|
||||
)
|
||||
# Download models
|
||||
config.download_if_necessary()
|
||||
print(config.path)
|
||||
```
|
||||
|
||||
调用 `download_if_necessary` 后,模型会自动下载,并将路径返回到 `config.path` 中。
|
||||
|
||||
### 从本地路径加载模型
|
||||
|
||||
如果从本地路径加载模型,则需要填入 `path`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(path="models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")
|
||||
```
|
||||
|
||||
如果模型包含多个分片文件,以列表的形式输入即可:
|
||||
|
||||
```python
|
||||
from diffsynth.core import ModelConfig
|
||||
|
||||
config = ModelConfig(path=[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
])
|
||||
```
|
||||
|
||||
### 显存管理配置
|
||||
|
||||
`ModelConfig` 也包含了显存管理配置信息,详见[显存管理](/docs/zh/Pipeline_Usage/VRAM_management.md#更多使用方式)。
|
||||
|
||||
## 模型文件加载
|
||||
|
||||
`diffsynth.core.loader` 提供了统一的 `load_state_dict`,用于加载模型文件中的 state dict。
|
||||
|
||||
加载单个模型文件:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
state_dict = load_state_dict("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors")
|
||||
```
|
||||
|
||||
加载多个模型文件(合并为一个 state dict):
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
state_dict = load_state_dict([
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
])
|
||||
```
|
||||
|
||||
## 模型哈希
|
||||
|
||||
模型哈希是用于判断模型类型的,哈希值可通过 `hash_model_file` 获取:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file
|
||||
|
||||
print(hash_model_file("models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"))
|
||||
```
|
||||
|
||||
也可计算多个模型文件的哈希值,等价于合并 state dict 后计算模型哈希值:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file
|
||||
|
||||
print(hash_model_file([
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
]))
|
||||
```
|
||||
|
||||
模型哈希值只与模型文件中 state dict 的 keys 和 tensor shape 有关,与模型参数的数值、文件保存时间等信息无关。在计算 `.safetensors` 格式文件的模型哈希值时,`hash_model_file` 是几乎瞬间完成的,无需读取模型的参数;但在计算 `.bin`、`.pth`、`.ckpt` 等二进制文件的模型哈希值时,则需要读取全部模型参数,因此**我们不建议开发者继续使用这些格式的文件。**
|
||||
|
||||
通过[编写模型 Config](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-3-编写模型-config)并将模型哈希值等信息填入 `diffsynth/configs/model_configs.py`,开发者可以让 `DiffSynth-Studio` 自动识别模型类型并加载。
|
||||
|
||||
## 模型加载
|
||||
|
||||
`load_model` 是 `diffsynth.core.loader` 中加载模型的外部入口,它会调用 [skip_model_initialization](/docs/zh/API_Reference/core/vram.md#跳过模型参数初始化) 跳过模型参数初始化。如果启用了 [Disk Offload](/docs/zh/Pipeline_Usage/VRAM_management.md#disk-offload),则调用 [DiskMap](/docs/zh/API_Reference/core/vram.md#state-dict-硬盘映射) 进行惰性加载;如果没有启用 Disk Offload,则调用 [load_state_dict](#模型文件加载) 加载模型参数。如果需要的话,还会调用 [state dict converter](/docs/zh/Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) 进行模型格式转换。最后调用 `model.eval()` 将其切换到推理模式。
|
||||
|
||||
以下是一个启用了 Disk Offload 的使用案例:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
```
|
||||
66
docs/zh/API_Reference/core/vram.md
Normal file
66
docs/zh/API_Reference/core/vram.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# `diffsynth.core.vram`: 显存管理
|
||||
|
||||
本文档介绍 `diffsynth.core.vram` 中的显存管理底层功能,如果你希望将这些功能用于其他的代码库中,可参考本文档。
|
||||
|
||||
## 跳过模型参数初始化
|
||||
|
||||
在 `PyTorch` 中加载模型时,模型的参数默认会占用显存或内存并进行参数初始化,而这些参数会在加载预训练权重后被覆盖掉,这导致了冗余的计算。`PyTorch` 中没有提供接口来跳过这些冗余的计算,我们在 `diffsynth.core.vram` 中提供了 `skip_model_initialization` 用于跳过模型参数初始化。
|
||||
|
||||
默认的模型加载方式:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
model = QwenImageBlockWiseControlNet() # Slow
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu")
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
```
|
||||
|
||||
跳过参数初始化的模型加载方式:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_controlnet import QwenImageBlockWiseControlNet
|
||||
|
||||
with skip_model_initialization():
|
||||
model = QwenImageBlockWiseControlNet() # Fast
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu")
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
```
|
||||
|
||||
在 `DiffSynth-Studio` 中,所有预训练模型都遵循这一加载逻辑。开发者在[接入模型](/docs/zh/Developer_Guide/Integrating_Your_Model.md)完毕后即可直接以这种方式快速加载模型。
|
||||
|
||||
## State Dict 硬盘映射
|
||||
|
||||
对于某个模型的预训练权重文件,如果我们只需要读取其中的一组参数,而非全部参数,State Dict 硬盘映射可以加速这一过程。我们在 `diffsynth.core.vram` 中提供了 `DiskMap` 用于按需加载模型参数。
|
||||
|
||||
默认的权重加载方式:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_state_dict
|
||||
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = load_state_dict(path, device="cpu") # Slow
|
||||
print(state_dict["img_in.weight"])
|
||||
```
|
||||
|
||||
使用 `DiskMap` 只加载特定参数:
|
||||
|
||||
```python
|
||||
from diffsynth.core import DiskMap
|
||||
|
||||
path = "models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny/model.safetensors"
|
||||
state_dict = DiskMap(path, device="cpu") # Fast
|
||||
print(state_dict["img_in.weight"])
|
||||
```
|
||||
|
||||
`DiskMap` 是 `DiffSynth-Studio` 中 Disk Offload 的基本组件,开发者在[配置细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md)后即可直接启用 Disk Offload。
|
||||
|
||||
`DiskMap` 是利用 `.safetensors` 文件的特性实现的功能,因此在使用 `.bin`、`.pth`、`.ckpt` 等二进制文件时,模型的参数是全量加载的,这也导致 Disk Offload 不支持这些格式的文件。**我们不建议开发者继续使用这些格式的文件。**
|
||||
|
||||
## 显存管理可替换模块
|
||||
|
||||
在启用 `DiffSynth-Studio` 的显存管理后,模型内部的模块会被替换为 `diffsynth.core.vram.layers` 中的可替换模块,其使用方式详见[细粒度显存管理方案](/docs/zh/Developer_Guide/Enabling_VRAM_management.md#编写细粒度显存管理方案)。
|
||||
Reference in New Issue
Block a user