Files
DiffSynth-Studio/docs/zh/Developer_Guide/Enabling_VRAM_management.md
Artiprocher d5a0aab2b2 update doc
2025-12-03 16:17:03 +08:00

9.1 KiB
Raw Blame History

细粒度显存管理方案

本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 DiffSynth-Studio 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档显存管理

20B 模型需要多少显存?

以 Qwen-Image 的 DiT 模型为例,这一模型的参数量达到了 20B以下代码会加载这一模型并进行推理需要约 40G 显存,这个模型在显存较小的消费级 GPU 上显然是无法运行的。

from diffsynth.core import load_model
from diffsynth.models.qwen_image_dit import QwenImageDiT
from modelscope import snapshot_download
import torch

snapshot_download(
    model_id="Qwen/Qwen-Image",
    local_dir="models/Qwen/Qwen-Image",
    allow_file_pattern="transformer/*"
)
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}

model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    output = model(**inputs)

编写细粒度显存管理方案

为了编写细粒度的显存管理方案,我们需用 print(model) 观察和分析模型结构:

QwenImageDiT(
  (pos_embed): QwenEmbedRope()
  (time_text_embed): TimestepEmbeddings(
    (time_proj): TemporalTimesteps()
    (timestep_embedder): DiffusersCompatibleTimestepProj(
      (linear_1): Linear(in_features=256, out_features=3072, bias=True)
      (act): SiLU()
      (linear_2): Linear(in_features=3072, out_features=3072, bias=True)
    )
  )
  (txt_norm): RMSNorm()
  (img_in): Linear(in_features=64, out_features=3072, bias=True)
  (txt_in): Linear(in_features=3584, out_features=3072, bias=True)
  (transformer_blocks): ModuleList(
    (0-59): 60 x QwenImageTransformerBlock(
      (img_mod): Sequential(
        (0): SiLU()
        (1): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (attn): QwenDoubleStreamAttention(
        (to_q): Linear(in_features=3072, out_features=3072, bias=True)
        (to_k): Linear(in_features=3072, out_features=3072, bias=True)
        (to_v): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_q): RMSNorm()
        (norm_k): RMSNorm()
        (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
        (norm_added_q): RMSNorm()
        (norm_added_k): RMSNorm()
        (to_out): Sequential(
          (0): Linear(in_features=3072, out_features=3072, bias=True)
        )
        (to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (img_mlp): QwenFeedForward(
        (net): ModuleList(
          (0): ApproximateGELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
      (txt_mod): Sequential(
        (0): SiLU()
        (1): Linear(in_features=3072, out_features=18432, bias=True)
      )
      (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
      (txt_mlp): QwenFeedForward(
        (net): ModuleList(
          (0): ApproximateGELU(
            (proj): Linear(in_features=3072, out_features=12288, bias=True)
          )
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=12288, out_features=3072, bias=True)
        )
      )
    )
  )
  (norm_out): AdaLayerNorm(
    (linear): Linear(in_features=3072, out_features=6144, bias=True)
    (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
  )
  (proj_out): Linear(in_features=3072, out_features=64, bias=True)
)

在显存管理中,我们只关心包含参数的 Layer。在这个模型结构中QwenEmbedRopeTemporalTimestepsSiLU 等 Layer 都是不包含参数的,LayerNorm 也因为设置了 elementwise_affine=False 不包含参数。包含参数的 Layer 只有 LinearRMSNorm

diffsynth.core.vram 中提供了两个用于替换的模块用于显存管理:

  • AutoWrappedLinear: 用于替换 Linear
  • AutoWrappedModule: 用于替换其他任意层

编写一个 module_map,将模型中的 LinearRMSNorm 映射到对应的模块上:

module_map={
    torch.nn.Linear: AutoWrappedLinear,
    RMSNorm: AutoWrappedModule,
}

此外,还需要提供 vram_configvram_limit,这两个参数在显存管理中已有介绍。

调用 enable_vram_management 即可启用显存管理,注意此时模型加载时的 devicecpu,与 offload_device 一致:

from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch

prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}

model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
enable_vram_management(
    model,
    module_map={
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
    },
    vram_config = {
        "offload_dtype": torch.bfloat16,
        "offload_device": "cpu",
        "onload_dtype": torch.bfloat16,
        "onload_device": "cpu",
        "preparing_dtype": torch.bfloat16,
        "preparing_device": "cuda",
        "computation_dtype": torch.bfloat16,
        "computation_device": "cuda",
    },
    vram_limit=0,
)
with torch.no_grad():
    output = model(**inputs)

以上代码只需要 2G 显存就可以运行 20B 模型的 forward

Disk Offload

Disk Offload 是特殊的显存管理方案需在模型加载过程中启用而非模型加载完毕后。通常在以上代码能够顺利运行的前提下Disk Offload 可以直接启用:

from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
import torch

prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
inputs = {
    "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
    "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
    "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
    "height": 1024,
    "width": 1024,
}

model = load_model(
    QwenImageDiT,
    model_path,
    module_map={
        torch.nn.Linear: AutoWrappedLinear,
        RMSNorm: AutoWrappedModule,
    },
    vram_config={
        "offload_dtype": "disk",
        "offload_device": "disk",
        "onload_dtype": "disk",
        "onload_device": "disk",
        "preparing_dtype": torch.bfloat16,
        "preparing_device": "cuda",
        "computation_dtype": torch.bfloat16,
        "computation_device": "cuda",
    },
    vram_limit=0,
)
with torch.no_grad():
    output = model(**inputs)

Disk Offload 是极为特殊的显存管理方案,只支持 .safetensors 格式文件,不支持 .bin.pth.ckpt 等二进制文件,不支持带 Tensor reshape 的 state dict converter

如果出现非 Disk Offload 能正常运行但 Disk Offload 不能正常运行的情况,请在 GitHub 上给我们提 issue。

写入默认配置

为了让用户能够更方便地使用显存管理功能,我们将细粒度显存管理的配置写在 diffsynth/configs/vram_management_module_maps.py 中,上述模型的配置信息为:

"diffsynth.models.qwen_image_dit.QwenImageDiT": {
    "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
    "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
}