DiffSynth-Studio/docs/Developer_Guide/Enabling_VRAM_management.md
2025-11-05 20:37:11 +08:00

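# Enabling VRAM Management
This document describes how to write a sensible fine-grained VRAM management scheme for a model, and how to use the VRAM management features of `DiffSynth-Studio` in other, external codebases. Before reading it, please read [VRAM Management](../Pipeline_Usage/VRAM_management.md).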
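## How Much VRAM Does a 20B Model Need?
Take the DiT model of Qwen-Image as an example. This model has 20B parameters. The following code loads the model and runs inference, which requires about 40G of VRAM, so the model clearly cannot run on consumer GPUs with limited VRAM.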
```python
from diffsynth.core import load_model
from diffsynth.models.qwen_image_dit import QwenImageDiT
from modelscope import snapshot_download
import torch

# Download only the DiT (transformer) weights.
snapshot_download(
    model_id="Qwen/Qwen-Image",
    local_dir="models/Qwen/Qwen-Image",
    allow_file_pattern="transformer/*"
)
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
# Dummy inputs for a single forward pass.
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}
# Load all weights directly onto the GPU.
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    output = model(**inputs)
```
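The figure of about 40G is consistent with a back-of-the-envelope estimate of weight storage alone: parameter count times bytes per parameter. The sketch below is only that arithmetic; the helper name `weight_memory_gb` and the dtype table are illustrative and not part of `DiffSynth-Studio`, and activations and framework overhead are ignored.

```python
# Bytes needed to store one parameter in common dtypes (illustrative table).
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "float8": 1}


def weight_memory_gb(num_params, dtype="bfloat16"):
    """Estimate weight storage in GB (1 GB = 1e9 bytes), weights only."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


# 20B parameters stored in bfloat16, as in the example above.
print(weight_memory_gb(20e9))             # 40.0
print(weight_memory_gb(20e9, "float32"))  # 80.0
```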
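## Fine-Grained VRAM Management Scheme
To write a fine-grained VRAM management scheme, we first need to inspect and analyze the model structure with `print(model)`: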
```
QwenImageDiT(
  (pos_embed): QwenEmbedRope()
  (time_text_embed): TimestepEmbeddings(
    (time_proj): TemporalTimesteps()
    (timestep_embedder): DiffusersCompatibleTimestepProj(
      (linear_1): Linear(in_features=256, out_features=3072, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)
    )
  )
  (txt_norm): RMSNorm()
  (img_in): Linear(in_features=64, out_features=3072, bias=True)
  (txt_in): Linear(in_features=3584, out_features=3072, bias=True)
  (transformer_blocks): ModuleList(
    (0-59): 60 x QwenImageTransformerBlock(
      (img_mod): Sequential(
        (0): SiLU()
        (1): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (attn): QwenDoubleStreamAttention(
        (to_q): Linear(in_features=3072, out_features=3072, bias=True)
        (to_k): Linear(in_features=3072, out_features=3072, bias=True)
        (to_v): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_added_q): RMSNorm()
        (norm_added_k): RMSNorm()
        (to_out): Sequential(
          (0): Linear(in_features=3072, out_features=3072, bias=True)
        )
        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (img_mlp): QwenFeedForward(
        (net): ModuleList(
          (0): ApproximateGELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
      (txt_mod): Sequential(
        (0): SiLU()
        (1): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_mlp): QwenFeedForward(
        (net): ModuleList(
          (0): ApproximateGELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
    )
  )
  (norm_out): AdaLayerNorm(
    (linear): Linear(in_features=3072, out_features=6144, bias=True)
    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
  )
  (proj_out): Linear(in_features=3072, out_features=64, bias=True)
)
```
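In VRAM management, we only care about layers that contain parameters. In this model structure, layers such as `QwenEmbedRope`, `TemporalTimesteps`, and `SiLU` contain no parameters, and `LayerNorm` contains none either because it is created with `elementwise_affine=False`. The only layers that contain parameters are `Linear` and `RMSNorm`.
`diffsynth.core.vram` provides two replacement modules for VRAM management:
* `AutoWrappedLinear`: replaces `Linear`
* `AutoWrappedModule`: replaces any other layer
Write a `module_map` that maps the model's `Linear` and `RMSNorm` layers to the corresponding modules: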
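As a sanity check on the printed structure, the `Linear` layers inside the 60 transformer blocks can be tallied by hand. The sketch below uses only the `in_features`/`out_features` values shown in the printout. The `RMSNorm` layers and the few top-level layers are left out because their sizes are either not shown or comparatively small, so the result is a lower bound rather than an exact parameter count.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of one Linear layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


# Linear layers in one QwenImageTransformerBlock, read off the printout.
per_block = (
    2 * linear_params(3072, 18432)    # img_mod.1, txt_mod.1
    + 8 * linear_params(3072, 3072)   # attn: to_q/k/v, add_q/k/v_proj, to_out.0, to_add_out
    + 2 * linear_params(3072, 12288)  # img_mlp.net.0.proj, txt_mlp.net.0.proj
    + 2 * linear_params(12288, 3072)  # img_mlp.net.2, txt_mlp.net.2
)
total = 60 * per_block  # 60 identical blocks

print(f"{per_block / 1e6:.1f}M per block, {total / 1e9:.2f}B in 60 blocks")
```

Under these assumptions the blocks alone account for roughly 20B parameters, which matches the model size quoted earlier and shows that almost all of the weights sit in `Linear` layers.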
```python
module_map = {
    torch.nn.Linear: AutoWrappedLinear,
    RMSNorm: AutoWrappedModule,
}
```
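To make the idea of a type-based `module_map` concrete, here is a toy sketch of how such a map could be applied: walk the module tree and swap every child whose type appears in the map for a wrapper. `ToyOffloadWrapper` and `wrap_by_type` are hypothetical names invented for this illustration. They are not the implementation of `AutoWrappedLinear`, `AutoWrappedModule`, or `enable_vram_management`, whose internals are not shown in this document.

```python
import torch
from torch import nn


class ToyOffloadWrapper(nn.Module):
    """Hypothetical wrapper: parks weights on `offload_device` and moves them
    to `computation_device` only for the duration of each forward call."""

    def __init__(self, module, offload_device="cpu", computation_device="cpu"):
        super().__init__()
        self.module = module.to(offload_device)
        self.offload_device = offload_device
        self.computation_device = computation_device

    def forward(self, *args, **kwargs):
        self.module.to(self.computation_device)
        try:
            return self.module(*args, **kwargs)
        finally:
            # Always move the weights back, even if forward raises.
            self.module.to(self.offload_device)


def wrap_by_type(model, module_map, **wrapper_kwargs):
    """Recursively replace children whose exact type is a key of `module_map`."""
    for name, child in list(model.named_children()):
        wrapper_cls = module_map.get(type(child))
        if wrapper_cls is not None:
            setattr(model, name, wrapper_cls(child, **wrapper_kwargs))
        else:
            wrap_by_type(child, module_map, **wrapper_kwargs)
    return model


# Tiny stand-in model; SiLU has no parameters and is left untouched.
toy = nn.Sequential(nn.Linear(4, 8), nn.SiLU(), nn.Linear(8, 2))
wrap_by_type(toy, {nn.Linear: ToyOffloadWrapper})
with torch.no_grad():
    out = toy(torch.randn(3, 4))
```

On a machine with a GPU, passing `computation_device="cuda"` to `wrap_by_type` (and moving the input to `"cuda"`) would keep each layer's weights on the GPU only while that layer is running, which is the general effect a fine-grained scheme aims for.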
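In addition, `vram_config` and `vram_limit` must be provided. Both parameters are described in [VRAM Management](../Pipeline_Usage/VRAM_management.md#更多使用方式).
Call `enable_vram_management` to enable VRAM management. Note that the model is now loaded with `device` set to `cpu`, matching `offload_device`: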
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch

prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}
# Load onto the CPU first; this matches `offload_device` below.
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
enable_vram_management(
    model,
    module_map={
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
    },
    vram_config={
        "offload_dtype": torch.bfloat16,
        "offload_device": "cpu",
        "onload_dtype": torch.bfloat16,
        "onload_device": "cpu",
        "preparing_dtype": torch.bfloat16,
        "preparing_device": "cuda",
        "computation_dtype": torch.bfloat16,
        "computation_device": "cuda",
    },
    vram_limit=0,
)
with torch.no_grad():
    output = model(**inputs)
```
The code above needs only 2G of VRAM to run the `forward` of a 20B model.
## Disk Offload
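[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that must be enabled while the model is being loaded, not after loading has finished. In general, as long as the code above runs successfully, Disk Offload can be enabled directly: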
```python
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch

prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}
# Disk Offload is configured inside `load_model`, not after loading.
model = load_model(
    QwenImageDiT,
    model_path,
    module_map={
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
    },
    vram_config={
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.bfloat16,
        "preparing_device": "cuda",
        "computation_dtype": torch.bfloat16,
        "computation_device": "cuda",
    },
    vram_limit=0,
)
with torch.no_grad():
    output = model(**inputs)
```
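Disk Offload is a highly specialized VRAM management scheme. It supports only `.safetensors` files; it does not support binary files such as `.bin`, `.pth`, or `.ckpt`, nor a [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换) that reshapes tensors.
If a model runs correctly without Disk Offload but fails with Disk Offload, please open an issue on GitHub.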
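Because of the file-format restriction above, it can help to check the checkpoint paths before turning on Disk Offload. The helper below is a hypothetical pre-flight check written for this guide, not a `DiffSynth-Studio` function. It only inspects file extensions and does not open the files.

```python
from pathlib import Path


def unsupported_for_disk_offload(paths):
    """Return the paths whose extension is not `.safetensors`."""
    return [p for p in paths if Path(p).suffix != ".safetensors"]


prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]

bad = unsupported_for_disk_offload(model_path)
if bad:
    raise ValueError(f"Disk Offload needs .safetensors files, got: {bad}")
```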
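## Writing the Default Configuration
To make VRAM management more convenient for users, we store the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration for the model above is: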
```python
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
    "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
}
```
```
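The configuration stores classes as dotted import paths rather than class objects. As a rough illustration of how such strings can be turned back into a class-to-class `module_map`, the sketch below resolves them with `importlib`. The helpers `resolve_class` and `resolve_module_map` are assumptions made for this example. How `DiffSynth-Studio` resolves these entries internally is not described here.

```python
import importlib


def resolve_class(dotted_path):
    """Import `package.module.ClassName` and return the class object."""
    module_name, _, attr = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def resolve_module_map(config):
    """Convert {source path: wrapper path} into {source class: wrapper class}."""
    return {resolve_class(src): resolve_class(dst) for src, dst in config.items()}


# Demonstrated with standard-library classes so the sketch runs anywhere;
# the same call would accept an entry such as "torch.nn.Linear".
demo = resolve_module_map({"collections.OrderedDict": "collections.defaultdict"})
```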