diff --git a/diffsynth/core/vram/layers.py b/diffsynth/core/vram/layers.py index f9d2bf2..cdfff0b 100644 --- a/diffsynth/core/vram/layers.py +++ b/diffsynth/core/vram/layers.py @@ -158,7 +158,7 @@ class AutoWrappedModule(AutoTorchModule): if self.state < 1: if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": self.load_from_disk(self.onload_dtype, self.onload_device) - else: + elif self.onload_device != "disk": self.to(dtype=self.onload_dtype, device=self.onload_device) self.state = 1 @@ -167,7 +167,7 @@ class AutoWrappedModule(AutoTorchModule): if self.state != 2: if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": self.load_from_disk(self.preparing_dtype, self.preparing_device) - else: + elif self.preparing_device != "disk": self.to(dtype=self.preparing_dtype, device=self.preparing_device) self.state = 2 @@ -308,7 +308,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): if self.state < 1: if self.disk_offload and self.onload_device != "disk" and self.offload_device == "disk": self.load_from_disk(self.onload_dtype, self.onload_device) - else: + elif self.onload_device != "disk": self.to(dtype=self.onload_dtype, device=self.onload_device) self.state = 1 @@ -317,7 +317,7 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule): if self.state != 2: if self.disk_offload and self.preparing_device != "disk" and self.onload_device == "disk": self.load_from_disk(self.preparing_dtype, self.preparing_device) - else: + elif self.preparing_device != "disk": self.to(dtype=self.preparing_dtype, device=self.preparing_device) self.state = 2 diff --git a/diffsynth/models/model_loader.py b/diffsynth/models/model_loader.py index 76d69c3..0c5a019 100644 --- a/diffsynth/models/model_loader.py +++ b/diffsynth/models/model_loader.py @@ -1,7 +1,7 @@ from ..core.loader import load_model, hash_model_file from ..core.vram import AutoWrappedModule from ..configs import MODEL_CONFIGS, VRAM_MANAGEMENT_MODULE_MAPS -import importlib, json +import importlib, json, torch class ModelPool: @@ -46,8 +46,23 @@ class ModelPool: ) return model - def auto_load_model(self, path, vram_config, vram_limit=None): + def default_vram_config(self): + vram_config = { + "offload_dtype": None, + "offload_device": None, + "onload_dtype": torch.bfloat16, + "onload_device": "cpu", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cpu", + "computation_dtype": torch.bfloat16, + "computation_device": "cpu", + } + return vram_config + + def auto_load_model(self, path, vram_config=None, vram_limit=None): print(f"Loading models from: {json.dumps(path, indent=4)}") + if vram_config is None: + vram_config = self.default_vram_config() model_hash = hash_model_file(path) loaded = False for config in MODEL_CONFIGS: diff --git a/docs/API_Reference/core/attention.md b/docs/API_Reference/core/attention.md index a51f98d..1aed20a 100644 --- a/docs/API_Reference/core/attention.md +++ b/docs/API_Reference/core/attention.md @@ -5,12 +5,14 @@ ## 注意力机制 注意力机制是在论文[《Attention Is All You Need》](https://arxiv.org/abs/1706.03762)中提出的模型结构,在原论文中,注意力机制按照如下公式实现: + $$ \text{Attention}(Q, K, V) = \text{Softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) V. $$ + 在 `PyTorch` 中,可以用如下代码实现: ```python import torch @@ -66,6 +68,10 @@ print((output_1 - output_2).abs().mean()) 请注意,加速的同时会引入误差,但在大多数情况下误差是可以忽略不计的。 +## 开发者导引 + +在为 `DiffSynth-Studio` 接入新模型时,开发者可自行决定是否调用 `diffsynth.core.attention` 中的 `attention_forward`,但我们期望模型能够尽可能优先调用这一模块,以便让新的注意力机制实现能够在这些模型上直接生效。 + ## 最佳实践 **在大多数情况下,我们建议直接使用 `PyTorch` 原生的实现,无需安装任何额外的包。** 虽然其他注意力机制实现可以加速,但加速效果是较为有限的,在少数情况下会出现兼容性和精度不足的问题。 diff --git a/docs/API_Reference/core/gradient.md b/docs/API_Reference/core/gradient.md index e69de29..59dd29a 100644 --- a/docs/API_Reference/core/gradient.md +++ b/docs/API_Reference/core/gradient.md @@ -0,0 +1,69 @@ +# `diffsynth.core.gradient`: 梯度检查点 + +`diffsynth.core.gradient` 中提供了封装好的梯度检查点及其 Offload 版本,用于模型训练。 + +## 梯度检查点 + +梯度检查点是用于减少训练时显存占用的技术。我们提供一个例子来帮助你理解这一技术,以下是一个简单的模型结构 + +```python +import torch + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + return self.activation(x) + +model = ToyModel() +x = torch.randn((2, 3)) +y = model(x) +``` + +在这个模型结构中,输入的参数 $x$ 经过 Sigmoid 激活函数得到输出值 $y=\frac{1}{1+e^{-x}}$。 + +在训练过程中,假定我们的损失函数值为 $\mathcal L$,在梯度反响传播时,我们得到 $\frac{\partial \mathcal L}{\partial y}$,此时我们需计算 $\frac{\partial \mathcal L}{\partial x}$,不难发现 $\frac{\partial y}{\partial x}=y(1-y)$,进而有 $\frac{\partial \mathcal L}{\partial x}=\frac{\partial \mathcal L}{\partial y}\frac{\partial y}{\partial x}=\frac{\partial \mathcal L}{\partial y}y(1-y)$。如果在模型前向传播时保存 $y$ 的数值,并在梯度反向传播时直接计算 $y(1-y)$,这将避免复杂的 exp 计算,加快计算速度,但这会导致我们需要额外的显存来存储中间变量 $y$。 + +不启用梯度检查点时,训练框架会默认存储所有辅助梯度计算的中间变量,从而达到最佳的计算速度。开启梯度检查点时,中间变量则不会存储,但输入参数 $x$ 仍会存储,减少显存占用,在梯度反向传播时需重新计算这些变量,减慢计算速度。 + +## 启用梯度检查点及其 Offload + +`diffsynth.core.gradient` 中的 `gradient_checkpoint_forward` 实现了梯度检查点及其 Offload,可参考以下代码调用: + +```python +import torch +from diffsynth.core.gradient import gradient_checkpoint_forward + +class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.activation = torch.nn.Sigmoid() + + def forward(self, x): + return self.activation(x) + +model = ToyModel() +x = torch.randn((2, 3)) +y = gradient_checkpoint_forward( + model, + use_gradient_checkpointing=True, + use_gradient_checkpointing_offload=False, + x=x, +) +``` + +* 当 `use_gradient_checkpointing=False` 且 `use_gradient_checkpointing_offload=False` 时,计算过程与原始计算完全相同,不影响模型的推理和训练,你可以直接将其集成到代码中。 +* 当 `use_gradient_checkpointing=True` 且 `use_gradient_checkpointing_offload=False` 时,启用梯度检查点。 +* 当 `use_gradient_checkpointing_offload=True` 时,启用梯度检查点,所有梯度检查点的输入参数存储在内存中,进一步降低显存占用和减慢计算速度。 + +## 最佳实践 + +> Q: 应当在何处启用梯度检查点? +> +> A: 对整个模型启用梯度检查点时,计算效率和显存占用并不是最优的,我们需要设置细粒度的梯度检查点,但同时不希望为框架增加过多繁杂的代码。因此我们建议在 `Pipeline` 的 `model_fn` 中实现,例如 `diffsynth/pipelines/qwen_image.py` 中的 `model_fn_qwen_image`,在 Block 层级启用梯度检查点,不需要修改模型结构的任何代码。 + +> Q: 什么情况下需要启用梯度检查点? +> +> A: 随着模型参数量越来越大,梯度检查点已成为必要的训练技术,梯度检查点通常是需要启用的。梯度检查点的 Offload 则仅需在激活值占用显存过大的模型(例如视频生成模型)中启用。 diff --git a/docs/Developer_Guide/Enabling_VRAM_management.md b/docs/Developer_Guide/Enabling_VRAM_management.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/Developer_Guide/Integrating_Your_Model.md b/docs/Developer_Guide/Integrating_Your_Model.md index 05be548..f5d5a3c 100644 --- a/docs/Developer_Guide/Integrating_Your_Model.md +++ b/docs/Developer_Guide/Integrating_Your_Model.md @@ -147,6 +147,40 @@ if hash_model_file(model_path) == model_hash: `diffsynth/configs/model_configs.py` 中的 `model_hash` 不是唯一存在的,同一模型文件中可能存在多个模型。对于这种情况,请使用多个模型 Config 分别加载每个模型,编写相应的 `state_dict_converter` 分离每个模型所需的参数。 -## Step 4: 编写模型显存管理方案 +## Step 4: 检验模型是否能被识别和加载 -`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](./Enabling_VRAM_management.py)。 +模型接入之后,可通过以下代码验证模型是否能够被正确识别和加载,以下代码会试图将模型加载到内存中: + +```python +from diffsynth.models.model_loader import ModelPool + +model_pool = ModelPool() +model_pool.auto_load_model( + [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors", + ], +) +``` + +如果模型能够被识别和加载,则会看到以下输出内容: + +``` +Loading models from: [ + "models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors", + "models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors" +] +Loaded model: { + "model_name": "qwen_image_text_encoder", + "model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder", + "extra_kwargs": null +} +``` + +## Step 5: 编写模型显存管理方案 + +`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](./Enabling_VRAM_management.md)。 diff --git a/docs/Pipeline_Usage/VRAM_management.md b/docs/Pipeline_Usage/VRAM_management.md new file mode 100644 index 0000000..15dcc48 --- /dev/null +++ b/docs/Pipeline_Usage/VRAM_management.md @@ -0,0 +1,204 @@ +# 显存管理 + +显存管理是 `DiffSynth-Studio` 的特色功能,能够让低显存的 GPU 能够运行参数量巨大的模型推理。本文档以 Qwen-Image 为例,介绍显存管理方案的使用。 + +## 基础推理 + +以下代码中没有启用任何显存管理,显存占用 56G,作为参考。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## CPU Offload + +由于模型 `Pipeline` 包括多个组件,这些组件并非同时调用的,因此我们可以在某些组件不需要参与计算时将其移至内存,减少显存占用,以下代码可以实现这一逻辑,显存占用 40G。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.bfloat16, + "offload_device": "cpu", + "onload_dtype": torch.bfloat16, + "onload_device": "cuda", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## FP8 量化 + +在 CPU Offload 的基础上,我们进一步启用 FP8 量化来减少显存需求,以下代码可以令模型参数以 FP8 精度存储在显存中,并在推理时临时转为 BF16 精度计算,显存占用 21G。但这种量化方案有微小的图像质量下降问题。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cuda", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +> Q: 为什么要在推理时临时转为 BF16 精度,而不是以 FP8 精度计算? +> +> A: FP8 的原生计算仅在 Hopper 架构的 GPU(例如 H20)支持,且计算误差很大,我们目前暂不开放 FP8 精度计算。目前的 FP8 量化仅能减少显存占用,不会提高计算速度。 + +## 动态显存管理 + +在 CPU Offload 中,我们对模型组件进行控制,事实上,我们支持做到 Layer 级别的 Offload,将一个模型拆分为多个 Layer,令一部分常驻显存,令一部分存储在内存中按需移至显存计算。这一功能需要模型开发者针对每个模型提供详细的显存管理方案,相关配置在 `diffsynth/configs/vram_management_module_maps.py` 中。 + +通过在 `Pipeline` 中增加 `vram_limit` 参数,框架可以自动感知设备的剩余显存并决定如何拆分模型到显存和内存中。`vram_limit` 越小,占用显存越少,速度越慢。 +* `vram_limit=None` 时,即默认状态,框架认为显存无限,动态显存管理是不启用的 +* `vram_limit=10` 时,框架会在显存占用超过 10G 之后限制模型,将超出的部分移至内存中存储。 +* `vram_limit=0` 时,框架会尽全力减少显存占用,所有模型参数都存储在内存中,仅在必要时移至显存计算 + +在显存不足以运行模型推理的情况下,框架会试图超出 `vram_limit` 的限制从而让模型推理运行下去,因此显存管理框架并不能总是保证占用的显存小于 `vram_limit`,我们建议将其设置为略小于实际可用显存的数值,例如 GPU 显存为 16G 时,设置为 `vram_limit=15.5`。`PyTorch` 中可用 `torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3)` 获取 GPU 的显存。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": torch.float8_e4m3fn, + "offload_device": "cpu", + "onload_dtype": torch.float8_e4m3fn, + "onload_device": "cpu", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024 ** 3) - 0.5, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## Disk Offload + +在更为极端的情况下,当内存也不足以存储整个模型时,Disk Offload 功能可以让模型参数惰性加载,即,模型中的每个 Layer 仅在调用 forward 时才会从硬盘中读取相应的参数。启用这一功能时,我们建议使用高速的 SSD 硬盘。 + +```python +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.float8_e4m3fn, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", **vram_config), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", **vram_config), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), + vram_limit=10, +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=40) +image.save("image.jpg") +``` + +## 更多使用方式 + +`vram_config` 中的信息可自行填写,例如不开 FP8 量化的 Disk Offload: + +```python +vram_config = { + "offload_dtype": "disk", + "offload_device": "disk", + "onload_dtype": "disk", + "onload_device": "disk", + "preparing_dtype": torch.bfloat16, + "preparing_device": "cuda", + "computation_dtype": torch.bfloat16, + "computation_device": "cuda", +} +``` + +具体地,显存管理模块会将模型的 Layer 分为以下四种状态: + +* Offload:短期内不调用这个模型,这个状态由 `Pipeline` 控制切换 +* Onload:接下来随时要调用这个模型,这个状态由 `Pipeline` 控制切换 +* Preparing:Onload 和 Computation 的中间状态,在显存允许的前提下的暂存状态,这个状态由显存管理机制控制切换,当且仅当【vram_limit 设置为无限制】或【vram_limit 已设置且有空余显存】时会进入这一状态 +* Computation:模型正在计算过程中,这个状态由显存管理机制控制切换,仅在 `forward` 中临时进入 + +如果你是模型开发者,希望自行控制某个模型的显存管理粒度,请参考[../Developer_Guide/Enabling_VRAM_management.md](../Developer_Guide/Enabling_VRAM_management.md)。 + +## 最佳实践 + +* 显存足够 -> 使用[基础推理](#基础推理) +* 显存不足 + * 内存足够 -> 使用[动态显存管理](#动态显存管理) + * 内存不足 -> 使用[Disk Offload](#disk-offload)