From 6a6eca7baf0f360c2ff861dea1a132dc9f69cb64 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Wed, 5 Nov 2025 20:37:11 +0800 Subject: [PATCH] update doc and code --- diffsynth/core/loader/file.py | 21 +- .../Enabling_VRAM_management.md | 228 ++++++++++++++++++ docs/Pipeline_Usage/VRAM_management.md | 2 + 3 files changed, 247 insertions(+), 4 deletions(-) diff --git a/diffsynth/core/loader/file.py b/diffsynth/core/loader/file.py index 8817cd1..5c5e13a 100644 --- a/diffsynth/core/loader/file.py +++ b/diffsynth/core/loader/file.py @@ -26,6 +26,11 @@ def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"): def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"): state_dict = torch.load(file_path, map_location=device, weights_only=True) + if len(state_dict) == 1: + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + elif "module" in state_dict: + state_dict = state_dict["module"] if torch_dtype is not None: for i in state_dict: if isinstance(state_dict[i], torch.Tensor): @@ -75,11 +80,19 @@ def load_keys_dict_from_safetensors(file_path): return keys_dict -def load_keys_dict_from_bin(file_path): - state_dict = load_state_dict_from_bin(file_path) +def convert_state_dict_to_keys_dict(state_dict): keys_dict = {} for k, v in state_dict.items(): - keys_dict[k] = list(v.shape) + if isinstance(v, torch.Tensor): + keys_dict[k] = list(v.shape) + else: + keys_dict[k] = convert_state_dict_to_keys_dict(v) + return keys_dict + + +def load_keys_dict_from_bin(file_path): + state_dict = load_state_dict_from_bin(file_path) + keys_dict = convert_state_dict_to_keys_dict(state_dict) return keys_dict @@ -88,7 +101,7 @@ def convert_keys_dict_to_single_str(state_dict, with_shape=True): for key, value in state_dict.items(): if isinstance(key, str): if isinstance(value, dict): - keys.append(key + "|" + convert_state_dict_keys_to_single_str(value, with_shape=with_shape)) + keys.append(key + "|" + convert_keys_dict_to_single_str(value, with_shape=with_shape)) else: if with_shape: shape = "_".join(map(str, list(value))) diff --git a/docs/Developer_Guide/Enabling_VRAM_management.md b/docs/Developer_Guide/Enabling_VRAM_management.md index e69de29..4389961 100644 --- a/docs/Developer_Guide/Enabling_VRAM_management.md +++ b/docs/Developer_Guide/Enabling_VRAM_management.md @@ -0,0 +1,228 @@ +# 启用显存管理 + +本文档介绍如何为模型编写合理的细粒度显存管理方案,以及如何将 `DiffSynth-Studio` 中的显存管理功能用于外部的其他代码库,在阅读本文档前,请先阅读文档[显存管理](../Pipeline_Usage/VRAM_management.md)。 + +## 20B 模型需要多少显存? + +以 Qwen-Image 的 DiT 模型为例,这一模型的参数量达到了 20B,以下代码会加载这一模型并进行推理,需要约 40G 显存,这个模型在显存较小的消费级 GPU 上显然是无法运行的。 + +```python +from diffsynth.core import load_model +from diffsynth.models.qwen_image_dit import QwenImageDiT +from modelscope import snapshot_download +import torch + +snapshot_download( + model_id="Qwen/Qwen-Image", + local_dir="models/Qwen/Qwen-Image", + allow_file_pattern="transformer/*" +) +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda") +with torch.no_grad(): + output = model(**inputs) +``` + +## 细粒度显存管理方案 + +为了编写细粒度的显存管理方案,我们需用 `print(model)` 观察和分析模型结构: + +``` +QwenImageDiT( + (pos_embed): QwenEmbedRope() + (time_text_embed): TimestepEmbeddings( + (time_proj): TemporalTimesteps() + (timestep_embedder): DiffusersCompatibleTimestepProj( + (linear_1): Linear(in_features=256, out_features=3072, bias=True) + (act): SiLU() + (linear_2): Linear(in_features=3072, out_features=3072, bias=True) + ) + ) + (txt_norm): RMSNorm() + (img_in): Linear(in_features=64, out_features=3072, bias=True) + (txt_in): Linear(in_features=3584, out_features=3072, bias=True) + (transformer_blocks): ModuleList( + (0-59): 60 x QwenImageTransformerBlock( + (img_mod): Sequential( + (0): SiLU() + (1): Linear(in_features=3072, out_features=18432, bias=True) + ) + (img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (attn): QwenDoubleStreamAttention( + (to_q): Linear(in_features=3072, out_features=3072, bias=True) + (to_k): Linear(in_features=3072, out_features=3072, bias=True) + (to_v): Linear(in_features=3072, out_features=3072, bias=True) + (norm_q): RMSNorm() + (norm_k): RMSNorm() + (add_q_proj): Linear(in_features=3072, out_features=3072, bias=True) + (add_k_proj): Linear(in_features=3072, out_features=3072, bias=True) + (add_v_proj): Linear(in_features=3072, out_features=3072, bias=True) + (norm_added_q): RMSNorm() + (norm_added_k): RMSNorm() + (to_out): Sequential( + (0): Linear(in_features=3072, out_features=3072, bias=True) + ) + (to_add_out): Linear(in_features=3072, out_features=3072, bias=True) + ) + (img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (img_mlp): QwenFeedForward( + (net): ModuleList( + (0): ApproximateGELU( + (proj): Linear(in_features=3072, out_features=12288, bias=True) + ) + (1): Dropout(p=0.0, inplace=False) + (2): Linear(in_features=12288, out_features=3072, bias=True) + ) + ) + (txt_mod): Sequential( + (0): SiLU() + (1): Linear(in_features=3072, out_features=18432, bias=True) + ) + (txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + (txt_mlp): QwenFeedForward( + (net): ModuleList( + (0): ApproximateGELU( + (proj): Linear(in_features=3072, out_features=12288, bias=True) + ) + (1): Dropout(p=0.0, inplace=False) + (2): Linear(in_features=12288, out_features=3072, bias=True) + ) + ) + ) + ) + (norm_out): AdaLayerNorm( + (linear): Linear(in_features=3072, out_features=6144, bias=True) + (norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False) + ) + (proj_out): Linear(in_features=3072, out_features=64, bias=True) +) +``` + +在显存管理中,我们只关心包含参数的 Layer。在这个模型结构中,`QwenEmbedRope`、`TemporalTimesteps`、`SiLU` 等 Layer 都是不包含参数的,`LayerNorm` 也因为设置了 `elementwise_affine=False` 不包含参数。包含参数的 Layer 只有 `Linear` 和 `RMSNorm`。 + +`diffsynth.core.vram` 中提供了两个用于替换的模块用于显存管理: +* `AutoWrappedLinear`: 用于替换 `Linear` 层 +* `AutoWrappedModule`: 用于替换其他任意层 + +编写一个 `module_map`,将模型中的 `Linear` 和 `RMSNorm` 映射到对应的模块上: + +```python +module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, +} +``` + +此外,还需要提供 `vram_config` 与 `vram_limit`,这两个参数在[显存管理](../Pipeline_Usage/VRAM_management.md#更多使用方式)中已有介绍。 + +调用 `enable_vram_management` 即可启用显存管理,注意此时模型加载时的 `device` 为 `cpu`,与 `offload_device` 一致: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu") +enable_vram_management( + model, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +with torch.no_grad(): + output = model(**inputs) +``` + +以上代码只需要 2G 显存就可以运行 20B 模型的 `forward`。 + +## Disk Offload + +[Disk Offload](../Pipeline_Usage/VRAM_management.md#disk-offload) 是特殊的显存管理方案,需在模型加载过程中启用,而非模型加载完毕后。通常,在以上代码能够顺利运行的前提下,Disk Offload 可以直接启用: + +```python +from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule +from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm +import torch + +prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model" +model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)] +inputs = { + "latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"), + "timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"), + "prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"), + "prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"), + "height": 1024, + "width": 1024, +} + +model = load_model( + QwenImageDiT, + model_path, + module_map={ + torch.nn.Linear: AutoWrappedLinear, + RMSNorm: AutoWrappedModule, + }, + vram_config={ + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", + }, + vram_limit=0, +) +with torch.no_grad(): + output = model(**inputs) +``` + +Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 + +如果出现非 Disk Offload 能正常运行但 Disk Offload 不能正常运行的情况,请在 GitHub 上给我们提 issue。 + +## 写入默认配置 + +为了让用户能够更方便地使用显存管理功能,我们将细粒度显存管理的配置写在 `diffsynth/configs/vram_management_module_maps.py` 中,上述模型的配置信息为: + +```python +"diffsynth.models.qwen_image_dit.QwenImageDiT": { + "diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule", + "torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear", +} +``` diff --git a/docs/Pipeline_Usage/VRAM_management.md b/docs/Pipeline_Usage/VRAM_management.md index 15dcc48..cc7ddf9 100644 --- a/docs/Pipeline_Usage/VRAM_management.md +++ b/docs/Pipeline_Usage/VRAM_management.md @@ -140,6 +140,8 @@ image.save("image.jpg") 在更为极端的情况下,当内存也不足以存储整个模型时,Disk Offload 功能可以让模型参数惰性加载,即,模型中的每个 Layer 仅在调用 forward 时才会从硬盘中读取相应的参数。启用这一功能时,我们建议使用高速的 SSD 硬盘。 +Disk Offload 是极为特殊的显存管理方案,只支持 `.safetensors` 格式文件,不支持 `.bin`、`.pth`、`.ckpt` 等二进制文件,不支持带 Tensor reshape 的 [state dict converter](../Developer_Guide/Integrating_Your_Model.md#step-2-模型文件格式转换)。 + ```python from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig import torch