mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
diffsynth 2.0 prototype
This commit is contained in:
243
docs/Developer_Guide/Building_a_Pipeline.md
Normal file
243
docs/Developer_Guide/Building_a_Pipeline.md
Normal file
@@ -0,0 +1,243 @@
|
||||
# Pipeline 构建
|
||||
|
||||
在[将 Pipeline 所需的模型接入](./Integrating_Your_Model.md)之后,还需构建 `Pipeline` 用于模型推理,本文档提供 `Pipeline` 构建的标准化流程,开发者也可参考现有的 `Pipeline` 进行构建。
|
||||
|
||||
`Pipeline` 的实现位于 `diffsynth/pipelines`,每个 `Pipeline` 包含以下必要的关键组件:
|
||||
|
||||
* `__init__`
|
||||
* `from_pretrained`
|
||||
* `__call__`
|
||||
* `units`
|
||||
* `model_fn`
|
||||
|
||||
## `__init__`
|
||||
|
||||
在 `__init__` 中,`Pipeline` 进行初始化,以下是一个简易的实现:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model
|
||||
|
||||
class NewDiffSynthPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.text_encoder: XXX_Model = None
|
||||
self.dit: YYY_Model = None
|
||||
self.vae: ZZZ_Model = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
NewDiffSynthPipelineUnit_xxx(),
|
||||
...
|
||||
]
|
||||
self.model_fn = model_fn_new
|
||||
```
|
||||
|
||||
其中包括以下几部分
|
||||
|
||||
* `scheduler`: 调度器,用于控制推理的迭代公式中的系数,控制每一步的噪声含量。
|
||||
* `text_encoder`、`dit`、`vae`: 模型,自 [Latent Diffusion](https://arxiv.org/abs/2112.10752) 被提出以来,这种三段式模型架构已成为主流的 Diffusion 模型架构,但这并不是一成不变的,`Pipeline` 中可添加任意多个模型。
|
||||
* `in_iteration_models`: 迭代中模型,这个元组标注了在迭代中会调用哪些模型。
|
||||
* `units`: 模型迭代的前处理单元,详见[`units`](#units)。
|
||||
* `model_fn`: 迭代中去噪模型的 `forward` 函数,详见[`model_fn`](#model_fn)。
|
||||
|
||||
> Q: 模型加载并不发生在 `__init__`,为什么这里仍要将每个模型初始化为 `None`?
|
||||
>
|
||||
> A: 在这里标注每个模型的类型后,代码编辑器就可以根据每个模型提供代码补全提示,便于后续的开发。
|
||||
|
||||
## `from_pretrained`
|
||||
|
||||
`from_pretrained` 负责加载所需的模型,让 `Pipeline` 变成可调用的状态。以下是一个简易的实现:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("xxx_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("yyy_dit")
|
||||
pipe.vae = model_pool.fetch_model("zzz_vae")
|
||||
# If necessary, load tokenizers here.
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
```
|
||||
|
||||
开发者需要实现其中获取模型的逻辑,对应的模型名称即为[模型接入时填写的模型 Config](Integrating_Your_Model.md#step-3-编写模型-config) 中的 `"model_name"`。
|
||||
|
||||
部分模型还需要加载 `tokenizer`,可根据需要在 `from_pretrained` 上添加额外的 `tokenizer_config` 参数并在获取模型后实现这部分。
|
||||
|
||||
## `__call__`
|
||||
|
||||
`__call__` 实现了整个 Pipeline 的生成过程,以下是常见的生成过程模板,开发者可根据需要在此基础上修改。
|
||||
|
||||
```python
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
num_inference_steps: int = 30,
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps,
|
||||
denoising_strength=denoising_strength
|
||||
)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image,
|
||||
"denoising_strength": denoising_strength,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"seed": seed,
|
||||
"rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"], device=self.device)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
```
|
||||
|
||||
## `units`
|
||||
|
||||
`units` 包含了所有的前处理过程,例如:宽高检查、提示词编码、初始噪声生成等。在整个模型前处理过程中,数据被抽象为了互斥的三部分,分别存储在对应的字典中:
|
||||
|
||||
* `inputs_shard`: 共享输入,与 [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598)(简称 CFG)无关的参数。
|
||||
* `inputs_posi`: Classifier-Free Guidance 的 Positive 侧输入,包含与正向提示词相关的内容。
|
||||
* `inputs_nega`: Classifier-Free Guidance 的 Negative 侧输入,包含与负向提示词相关的内容。
|
||||
|
||||
Pipeline Unit 的实现包括三种:直接模式、CFG 分离模式、接管模式。
|
||||
|
||||
如果某些计算与 CFG 无关,可采用直接模式,例如 Qwen-Image 的随机噪声初始化:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
```
|
||||
|
||||
如果某些计算与 CFG 有关,需分别处理正向和负向提示词,但两侧的输入参数是相同的,可采用 CFG 分离模式,例如 Qwen-image 的提示词编码:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
# Do something
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
```
|
||||
|
||||
如果某些计算需要全局的信息,则需要接管模式,例如 Qwen-Image 的实体分区控制:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
|
||||
output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
# Do something
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
```
|
||||
|
||||
以下是 Pipeline Unit 所需的参数配置:
|
||||
|
||||
* `seperate_cfg`: 是否启用 CFG 分离模式
|
||||
* `take_over`: 是否启用接管模式
|
||||
* `input_params`: 共享输入参数
|
||||
* `output_params`: 输出参数
|
||||
* `input_params_posi`: Positive 侧输入参数
|
||||
* `input_params_nega`: Negative 侧输入参数
|
||||
* `onload_model_names`: 需调用的模型组件名
|
||||
|
||||
> Q: 部分参数并未在推理过程中调用,例如 `output_params`,是否仍有必要配置?
|
||||
>
|
||||
> A: 这些参数不会影响推理过程,但会影响一些实验性功能,因此我们建议将其配置好。例如“拆分训练”,我们可以将训练中的前处理离线完成,但部分需要梯度回传的模型计算无法拆分,这些参数用于构建计算图从而推断哪些计算是可以拆分的。
|
||||
|
||||
## `model_fn`
|
||||
|
||||
`model_fn` 是迭代中的统一 `forward` 接口,对于开源模型生态尚未形成的模型,直接沿用去噪模型的 `forward` 即可,例如:
|
||||
|
||||
```python
|
||||
def model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):
|
||||
return dit(latents, prompt_emb, timestep)
|
||||
```
|
||||
|
||||
对于开源生态丰富的模型,`model_fn` 通常包含复杂且混乱的跨模型推理,以 `diffsynth/pipelines/qwen_image.py` 为例,这个函数中实现的额外计算包括:实体分区控制、三种 ControlNet、Gradient Checkpointing 等,开发者在实现这一部分时要格外小心,避免模块功能之间的冲突。
|
||||
152
docs/Developer_Guide/Integrating_Your_Model.md
Normal file
152
docs/Developer_Guide/Integrating_Your_Model.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# 模型接入
|
||||
|
||||
本文档介绍如何将模型接入到 `DiffSynth-Studio` 框架中,供 `Pipeline` 等模块调用。
|
||||
|
||||
## Step 1: 集成模型结构代码
|
||||
|
||||
`DiffSynth-Studio` 中的所有模型结构实现统一在 `diffsynth/models` 中,每个 `.py` 代码文件分别实现一个模型结构,所有模型通过 `diffsynth/models/model_loader.py` 中的 `ModelPool` 来加载。在接入新的模型结构时,请在这个路径下建立新的 `.py` 文件。
|
||||
|
||||
```shell
|
||||
diffsynth/models/
|
||||
├── general_modules.py
|
||||
├── model_loader.py
|
||||
├── qwen_image_controlnet.py
|
||||
├── qwen_image_dit.py
|
||||
├── qwen_image_text_encoder.py
|
||||
├── qwen_image_vae.py
|
||||
└── ...
|
||||
```
|
||||
|
||||
在大多数情况下,我们建议用 `PyTorch` 原生代码的形式集成模型,让模型结构类直接继承 `torch.nn.Module`,例如:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class NewDiffSynthModel(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(dim, dim)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
```
|
||||
|
||||
如果模型结构的实现中包含额外的依赖,我们强烈建议将其删除,否则这会导致沉重的包依赖问题。在我们现有的模型中,Qwen-Image 的 Blockwise ControlNet 是以这种方式集成的,其代码很轻量,请参考 `diffsynth/models/qwen_image_controlnet.py`。
|
||||
|
||||
如果模型已被 Huggingface Library ([`transformers`](https://huggingface.co/docs/transformers/main/index)、[`diffusers`](https://huggingface.co/docs/diffusers/main/index) 等)集成,我们能够以更简单的方式集成模型:
|
||||
|
||||
<details>
|
||||
<summary>集成 Huggingface Library 风格模型结构代码</summary>
|
||||
|
||||
这类模型在 Huggingface Library 中的加载方式为:
|
||||
|
||||
```python
|
||||
from transformers import XXX_Model
|
||||
|
||||
model = XXX_Model.from_pretrained("path_to_your_model")
|
||||
```
|
||||
|
||||
`DiffSynth-Studio` 不支持通过 `from_pretrained` 加载模型,因为这与显存管理等功能是冲突的,请将模型结构改写成以下格式:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class DiffSynth_XXX_Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import XXX_Config, XXX_Model
|
||||
config = XXX_Config(**{
|
||||
"architectures": ["XXX_Model"],
|
||||
"other_configs": "Please copy and paste the other configs here.",
|
||||
})
|
||||
self.model = XXX_Model(config)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self.model(x)
|
||||
return outputs
|
||||
```
|
||||
|
||||
其中 `XXX_Config` 为模型对应的 Config 类,例如 `Qwen2_5_VLModel` 的 Config 类为 `Qwen2_5_VLConfig`,可通过查阅其源代码找到。Config 内部的内容通常可以在模型库中的 `config.json` 中找到,`DiffSynth-Studio` 不会读取 `config.json` 文件,因此需要将其中的内容复制粘贴到代码中。
|
||||
|
||||
在少数情况下,`transformers` 和 `diffusers` 的版本更新会导致部分的模型无法导入,因此如果可能的话,我们仍建议使用 Step 1.1 中的模型集成方式。
|
||||
|
||||
在我们现有的模型中,Qwen-Image 的 Text Encoder 是以这种方式集成的,其代码很轻量,请参考 `diffsynth/models/qwen_image_text_encoder.py`。
|
||||
|
||||
</details>
|
||||
|
||||
## Step 2: 模型文件格式转换
|
||||
|
||||
由于开源社区中开发者提供的模型文件格式多种多样,因此我们有时需对模型文件格式进行转换,从而形成格式正确的 [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html),常见于以下几种情况:
|
||||
|
||||
* 模型文件由不同代码库构建,例如 [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) 和 [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers)。
|
||||
* 模型在接入中做了修改,例如 [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) 的 Text Encoder 在 `diffsynth/models/qwen_image_text_encoder.py` 中增加了 `model.` 前缀。
|
||||
* 模型文件包含多个模型,例如 [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) 的 VACE Adapter 和基础 DiT 模型混合存储在同一组模型文件中。
|
||||
|
||||
在我们的开发理念中,我们希望尽可能尊重模型原作者的意愿。如果对模型文件进行重新封装,例如 [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI),虽然我们可以更方便地调用模型,但流量(模型页面浏览量和下载量等)会被引向他处,模型的原作者也会失去删除模型的权力。因此,我们在框架中增加了 `diffsynth/utils/state_dict_converters` 这一模块,用于在模型加载过程中进行文件格式转换。
|
||||
|
||||
这部分逻辑是非常简单的,以 Qwen-Image 的 Text Encoder 为例,只需要 10 行代码即可:
|
||||
|
||||
```python
|
||||
def QwenImageTextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for k in state_dict:
|
||||
v = state_dict[k]
|
||||
if k.startswith("visual."):
|
||||
k = "model." + k
|
||||
elif k.startswith("model."):
|
||||
k = k.replace("model.", "model.language_model.")
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
```
|
||||
|
||||
## Step 3: 编写模型 Config
|
||||
|
||||
模型 Config 位于 `diffsynth/configs/model_configs.py`,用于识别模型类型并进行加载。需填入以下字段:
|
||||
|
||||
* `model_hash`:模型文件哈希值,可通过 `hash_model_file` 函数获取,此哈希值仅与模型文件中 state dict 的 keys 和张量 shape 有关,与文件中的其他信息无关。
|
||||
* `model_name`: 模型名称,用于给 `Pipeline` 识别所需模型。如果不同结构的模型在 `Pipeline` 中发挥的作用相同,则可以使用相同的 `model_name`。在接入新模型时,只需保证 `model_name` 与现有的其他功能模型不同即可。在 `Pipeline` 的 `from_pretrained` 中通过 `model_name` 获取对应的模型。
|
||||
* `model_class`: 模型结构导入路径,指向在 Step 1 中实现的模型结构类,例如 `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`。
|
||||
* `state_dict_converter`: 可选参数,如需进行模型文件格式转换,则需填入模型转换逻辑的导入路径,例如 `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`。
|
||||
* `extra_kwargs`: 可选参数,如果模型初始化时需传入额外参数,则需要填入这些参数,例如模型 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) 与 [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) 都采用了 `diffsynth/models/qwen_image_controlnet.py` 中的 `QwenImageBlockWiseControlNet` 结构,但后者还需额外的配置 `additional_in_dim=4`,因此这部分配置信息需填入 `extra_kwargs` 字段。
|
||||
|
||||
我们提供了一份代码,以便快速理解模型是如何通过这些配置信息加载的:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter
|
||||
import torch
|
||||
|
||||
model_hash = "8004730443f55db63092006dd9f7110e"
|
||||
model_name = "qwen_image_text_encoder"
|
||||
model_class = QwenImageTextEncoder
|
||||
state_dict_converter = QwenImageTextEncoderStateDictConverter
|
||||
extra_kwargs = {}
|
||||
|
||||
model_path = [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
]
|
||||
if hash_model_file(model_path) == model_hash:
|
||||
with skip_model_initialization():
|
||||
model = model_class(**extra_kwargs)
|
||||
state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
print("Done!")
|
||||
```
|
||||
|
||||
> Q: 上述代码的逻辑看起来很简单,为什么 `DiffSynth-Studio` 中的这部分代码极为复杂?
|
||||
>
|
||||
> A: 因为我们提供了激进的显存管理功能,与模型加载逻辑耦合,这导致框架结构的复杂性,我们已尽可能简化暴露给开发者的接口。
|
||||
|
||||
`diffsynth/configs/model_configs.py` 中的 `model_hash` 不是唯一存在的,同一模型文件中可能存在多个模型。对于这种情况,请使用多个模型 Config 分别加载每个模型,编写相应的 `state_dict_converter` 分离每个模型所需的参数。
|
||||
|
||||
## Step 4: 编写模型显存管理方案
|
||||
|
||||
`DiffSynth-Studio` 支持复杂的显存管理,详见[启用显存管理](./Enabling_VRAM_management.py)。
|
||||
66
docs/Developer_Guide/Training_Diffusion_Models.md
Normal file
66
docs/Developer_Guide/Training_Diffusion_Models.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# 模型训练
|
||||
|
||||
在[接入模型](./Integrating_Your_Model.md)并[实现 Pipeline](./Building_a_Pipeline.md)后,接下来接入模型训练功能。
|
||||
|
||||
## 训推一致的 Pipeline 改造
|
||||
|
||||
为了保证训练和推理过程严格的一致性,我们会在训练过程中沿用大部分推理代码,但仍需作出少量改造。
|
||||
|
||||
首先,在推理过程中添加额外的逻辑,让图生图/视频生视频逻辑根据 `scheduler` 状态进行切换。以 Qwen-Image 为例:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
```
|
||||
|
||||
然后,在 `model_fn` 中启用 Gradient Checkpointing,这将以计算速度为代价,大幅度减少训练所需的显存。这并不是必需的,但我们强烈建议这么做。
|
||||
|
||||
以 Qwen-Image 为例,修改前:
|
||||
|
||||
```python
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
修改后:
|
||||
|
||||
```python
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
## 编写训练脚本
|
||||
|
||||
`DiffSynth-Studio` 没有对训练框架做严格的封装,而是将脚本内容暴露给开发者,这种方式可以更方便地对训练脚本进行修改,实现额外的功能。开发者可参考现有的训练脚本,例如 `examples/qwen_image/model_training/train.py` 进行修改,从而适配新的模型训练。
|
||||
Reference in New Issue
Block a user