mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 14:58:12 +00:00
DiffSynth-Studio 2.0 major update
This commit is contained in:
250
docs/en/Developer_Guide/Building_a_Pipeline.md
Normal file
250
docs/en/Developer_Guide/Building_a_Pipeline.md
Normal file
@@ -0,0 +1,250 @@
|
||||
# Building a Pipeline
|
||||
|
||||
After [integrating the required models for the Pipeline](/docs/en/Developer_Guide/Integrating_Your_Model.md), you also need to build a `Pipeline` for model inference. This document provides a standardized process for building a `Pipeline`. Developers can also refer to existing `Pipeline` implementations for construction.
|
||||
|
||||
The `Pipeline` implementation is located in `diffsynth/pipelines`. Each `Pipeline` contains the following essential key components:
|
||||
|
||||
* `__init__`
|
||||
* `from_pretrained`
|
||||
* `__call__`
|
||||
* `units`
|
||||
* `model_fn`
|
||||
|
||||
## `__init__`
|
||||
|
||||
In `__init__`, the `Pipeline` is initialized. Here is a simple implementation:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from PIL import Image
|
||||
from typing import Union
|
||||
from tqdm import tqdm
|
||||
from ..diffusion import FlowMatchScheduler
|
||||
from ..core import ModelConfig
|
||||
from ..diffusion.base_pipeline import BasePipeline, PipelineUnit
|
||||
from ..models.new_models import XXX_Model, YYY_Model, ZZZ_Model
|
||||
|
||||
class NewDiffSynthPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.bfloat16):
|
||||
super().__init__(device=device, torch_dtype=torch_dtype)
|
||||
self.scheduler = FlowMatchScheduler()
|
||||
self.text_encoder: XXX_Model = None
|
||||
self.dit: YYY_Model = None
|
||||
self.vae: ZZZ_Model = None
|
||||
self.in_iteration_models = ("dit",)
|
||||
self.units = [
|
||||
NewDiffSynthPipelineUnit_xxx(),
|
||||
...
|
||||
]
|
||||
self.model_fn = model_fn_new
|
||||
```
|
||||
|
||||
This includes the following parts:
|
||||
|
||||
* `scheduler`: Scheduler, used to control the coefficients in the iterative formula during inference, controlling the noise content at each step.
|
||||
* `text_encoder`, `dit`, `vae`: Models. Since [Latent Diffusion](https://arxiv.org/abs/2112.10752) was proposed, this three-stage model architecture has become the mainstream Diffusion model architecture. However, this is not immutable, and any number of models can be added to the `Pipeline`.
|
||||
* `in_iteration_models`: Iteration models. This tuple marks which models will be called during iteration.
|
||||
* `units`: Pre-processing units for model iteration. See [`units`](#units) for details.
|
||||
* `model_fn`: The `forward` function of the denoising model during iteration. See [`model_fn`](#model_fn) for details.
|
||||
|
||||
> Q: Model loading does not occur in `__init__`, why initialize each model as `None` here?
|
||||
>
|
||||
> A: By annotating the type of each model here, the code editor can provide code completion prompts based on each model, facilitating subsequent development.
|
||||
|
||||
## `from_pretrained`
|
||||
|
||||
`from_pretrained` is responsible for loading the required models to make the `Pipeline` callable. Here is a simple implementation:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
model_configs: list[ModelConfig] = [],
|
||||
vram_limit: float = None,
|
||||
):
|
||||
# Initialize pipeline
|
||||
pipe = NewDiffSynthPipeline(device=device, torch_dtype=torch_dtype)
|
||||
model_pool = pipe.download_and_load_models(model_configs, vram_limit)
|
||||
|
||||
# Fetch models
|
||||
pipe.text_encoder = model_pool.fetch_model("xxx_text_encoder")
|
||||
pipe.dit = model_pool.fetch_model("yyy_dit")
|
||||
pipe.vae = model_pool.fetch_model("zzz_vae")
|
||||
# If necessary, load tokenizers here.
|
||||
|
||||
# VRAM Management
|
||||
pipe.vram_management_enabled = pipe.check_vram_management_state()
|
||||
return pipe
|
||||
```
|
||||
|
||||
Developers need to implement the logic for fetching models. The corresponding model names are the `"model_name"` in the [model Config filled in during model integration](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-3-writing-model-config).
|
||||
|
||||
Some models also need to load `tokenizer`. Extra `tokenizer_config` parameters can be added to `from_pretrained` as needed, and this part can be implemented after fetching the models.
|
||||
|
||||
## `__call__`
|
||||
|
||||
`__call__` implements the entire generation process of the Pipeline. Below is a common generation process template. Developers can modify it based on their needs.
|
||||
|
||||
```python
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
cfg_scale: float = 4.0,
|
||||
input_image: Image.Image = None,
|
||||
denoising_strength: float = 1.0,
|
||||
height: int = 1328,
|
||||
width: int = 1328,
|
||||
seed: int = None,
|
||||
rand_device: str = "cpu",
|
||||
num_inference_steps: int = 30,
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(
|
||||
num_inference_steps,
|
||||
denoising_strength=denoising_strength
|
||||
)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
}
|
||||
inputs_nega = {
|
||||
"negative_prompt": negative_prompt,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
"input_image": input_image,
|
||||
"denoising_strength": denoising_strength,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"seed": seed,
|
||||
"rand_device": rand_device,
|
||||
"num_inference_steps": num_inference_steps,
|
||||
}
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
# Inference
|
||||
noise_pred_posi = self.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep, progress_id=progress_id)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = self.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep, progress_id=progress_id)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
# Scheduler
|
||||
inputs_shared["latents"] = self.step(self.scheduler, progress_id=progress_id, noise_pred=noise_pred, **inputs_shared)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae'])
|
||||
image = self.vae.decode(inputs_shared["latents"], device=self.device)
|
||||
image = self.vae_output_to_image(image)
|
||||
self.load_models_to_device([])
|
||||
|
||||
return image
|
||||
```
|
||||
|
||||
## `units`
|
||||
|
||||
`units` contains all the preprocessing processes, such as: width/height checking, prompt encoding, initial noise generation, etc. In the entire model preprocessing process, data is abstracted into three mutually exclusive parts, stored in corresponding dictionaries:
|
||||
|
||||
* `inputs_shared`: Shared inputs, parameters unrelated to [Classifier-Free Guidance](https://arxiv.org/abs/2207.12598) (CFG for short).
|
||||
* `inputs_posi`: Positive side inputs for Classifier-Free Guidance, containing content related to positive prompts.
|
||||
* `inputs_nega`: Negative side inputs for Classifier-Free Guidance, containing content related to negative prompts.
|
||||
|
||||
Pipeline Unit implementations include three types: direct mode, CFG separation mode, and takeover mode.
|
||||
|
||||
If some calculations are unrelated to CFG, direct mode can be used, for example, Qwen-Image's random noise initialization:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_NoiseInitializer(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("height", "width", "seed", "rand_device"),
|
||||
output_params=("noise",),
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, height, width, seed, rand_device):
|
||||
noise = pipe.generate_noise((1, 16, height//8, width//8), seed=seed, rand_device=rand_device, rand_torch_dtype=pipe.torch_dtype)
|
||||
return {"noise": noise}
|
||||
```
|
||||
|
||||
If some calculations are related to CFG and need to separately process positive and negative prompts, but the input parameters on both sides are the same, CFG separation mode can be used, for example, Qwen-image's prompt encoding:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "negative_prompt"},
|
||||
input_params=("edit_image",),
|
||||
output_params=("prompt_emb", "prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, prompt, edit_image=None) -> dict:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
# Do something
|
||||
return {"prompt_emb": prompt_embeds, "prompt_emb_mask": encoder_attention_mask}
|
||||
```
|
||||
|
||||
If some calculations need global information, takeover mode is required, for example, Qwen-Image's entity partition control:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_EntityControl(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
take_over=True,
|
||||
input_params=("eligen_entity_prompts", "width", "height", "eligen_enable_on_negative", "cfg_scale"),
|
||||
output_params=("entity_prompt_emb", "entity_masks", "entity_prompt_emb_mask"),
|
||||
onload_model_names=("text_encoder",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, inputs_shared, inputs_posi, inputs_nega):
|
||||
# Do something
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
```
|
||||
|
||||
The following are the parameter configurations required for Pipeline Unit:
|
||||
|
||||
* `seperate_cfg`: Whether to enable CFG separation mode
|
||||
* `take_over`: Whether to enable takeover mode
|
||||
* `input_params`: Shared input parameters
|
||||
* `output_params`: Output parameters
|
||||
* `input_params_posi`: Positive side input parameters
|
||||
* `input_params_nega`: Negative side input parameters
|
||||
* `onload_model_names`: Names of model components to be called
|
||||
|
||||
When designing `unit`, please try to follow these principles:
|
||||
|
||||
* Default fallback: For optional function `unit` input parameters, the default is `None` rather than `False` or other values. Please provide fallback processing for this default value.
|
||||
* Parameter triggering: Some Adapter models may not be loaded, such as ControlNet. The corresponding `unit` should control triggering based on whether the parameter input is `None` rather than whether the model is loaded. For example, when the user inputs `controlnet_image` but does not load the ControlNet model, the code should give an error rather than ignore these input parameters and continue execution.
|
||||
* Simplicity first: Use direct mode as much as possible, only use takeover mode when the function cannot be implemented.
|
||||
* VRAM efficiency: When calling models in `unit`, please use `pipe.load_models_to_device(self.onload_model_names)` to activate the corresponding models. Do not call other models outside `onload_model_names`. After `unit` calculation is completed, do not manually release VRAM with `pipe.load_models_to_device([])`.
|
||||
|
||||
> Q: Some parameters are not called during the inference process, such as `output_params`. Is it still necessary to configure them?
|
||||
>
|
||||
> A: These parameters will not affect the inference process, but they will affect some experimental features. Therefore, we recommend configuring them properly. For example, "split training" - we can complete the preprocessing offline during training, but some model calculations that require gradient backpropagation cannot be split. These parameters are used to build computational graphs to infer which calculations can be split.
|
||||
|
||||
## `model_fn`
|
||||
|
||||
`model_fn` is the unified `forward` interface during iteration. For models where the open-source ecosystem is not yet formed, you can directly use the denoising model's `forward`, for example:
|
||||
|
||||
```python
|
||||
def model_fn_new(dit=None, latents=None, timestep=None, prompt_emb=None, **kwargs):
|
||||
return dit(latents, prompt_emb, timestep)
|
||||
```
|
||||
|
||||
For models with rich open-source ecosystems, `model_fn` usually contains complex and chaotic cross-model inference. Taking `diffsynth/pipelines/qwen_image.py` as an example, the additional calculations implemented in this function include: entity partition control, three types of ControlNet, Gradient Checkpointing, etc. Developers need to be extra careful when implementing this part to avoid conflicts between module functions.
|
||||
455
docs/en/Developer_Guide/Enabling_VRAM_management.md
Normal file
455
docs/en/Developer_Guide/Enabling_VRAM_management.md
Normal file
@@ -0,0 +1,455 @@
|
||||
# Fine-Grained VRAM Management Scheme
|
||||
|
||||
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
|
||||
|
||||
## How Much VRAM Does a 20B Model Need?
|
||||
|
||||
Taking Qwen-Image's DiT model as an example, this model has reached 20B parameters. The following code will load this model and perform inference, requiring about 40G VRAM. This model obviously cannot run on consumer-grade GPUs with smaller VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
snapshot_download(
|
||||
model_id="Qwen/Qwen-Image",
|
||||
local_dir="models/Qwen/Qwen-Image",
|
||||
allow_file_pattern="transformer/*"
|
||||
)
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
## Writing Fine-Grained VRAM Management Scheme
|
||||
|
||||
To write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure:
|
||||
|
||||
```
|
||||
QwenImageDiT(
|
||||
(pos_embed): QwenEmbedRope()
|
||||
(time_text_embed): TimestepEmbeddings(
|
||||
(time_proj): TemporalTimesteps()
|
||||
(timestep_embedder): DiffusersCompatibleTimestepProj(
|
||||
(linear_1): Linear(in_features=256, out_features=3072, bias=True)
|
||||
(act): SiLU()
|
||||
(linear_2): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_norm): RMSNorm()
|
||||
(img_in): Linear(in_features=64, out_features=3072, bias=True)
|
||||
(txt_in): Linear(in_features=3584, out_features=3072, bias=True)
|
||||
(transformer_blocks): ModuleList(
|
||||
(0-59): 60 x QwenImageTransformerBlock(
|
||||
(img_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(attn): QwenDoubleStreamAttention(
|
||||
(to_q): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_k): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_v): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_q): RMSNorm()
|
||||
(norm_k): RMSNorm()
|
||||
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_added_q): RMSNorm()
|
||||
(norm_added_k): RMSNorm()
|
||||
(to_out): Sequential(
|
||||
(0): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(img_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(norm_out): AdaLayerNorm(
|
||||
(linear): Linear(in_features=3072, out_features=6144, bias=True)
|
||||
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
)
|
||||
(proj_out): Linear(in_features=3072, out_features=64, bias=True)
|
||||
)
|
||||
```
|
||||
|
||||
In VRAM management, we only care about layers containing parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU` and other Layers do not contain parameters. `LayerNorm` also does not contain parameters because `elementwise_affine=False` is set. Layers containing parameters are only `Linear` and `RMSNorm`.
|
||||
|
||||
`diffsynth.core.vram` provides two replacement modules for VRAM management:
|
||||
* `AutoWrappedLinear`: Used to replace `Linear` layers
|
||||
* `AutoWrappedModule`: Used to replace any other layer
|
||||
|
||||
Write a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules:
|
||||
|
||||
```python
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
}
|
||||
```
|
||||
|
||||
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
|
||||
|
||||
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
|
||||
enable_vram_management(
|
||||
model,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
The above code only requires 2G VRAM to run the `forward` of a 20B model.
|
||||
|
||||
## Disk Offload
|
||||
|
||||
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
|
||||
|
||||
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.
|
||||
|
||||
## Writing Default Configuration
|
||||
|
||||
To make it easier for users to use the VRAM management function, we write the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration information for the above model is:
|
||||
|
||||
```python
|
||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
}
|
||||
```# Fine-Grained VRAM Management Scheme
|
||||
|
||||
This document introduces how to write reasonable fine-grained VRAM management schemes for models, and how to use the VRAM management functions in `DiffSynth-Studio` for other external code libraries. Before reading this document, please read the document [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md).
|
||||
|
||||
## How Much VRAM Does a 20B Model Need?
|
||||
|
||||
Taking Qwen-Image's DiT model as an example, this model has reached 20B parameters. The following code will load this model and perform inference, requiring about 40G VRAM. This model obviously cannot run on consumer-grade GPUs with smaller VRAM.
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
|
||||
snapshot_download(
|
||||
model_id="Qwen/Qwen-Image",
|
||||
local_dir="models/Qwen/Qwen-Image",
|
||||
allow_file_pattern="transformer/*"
|
||||
)
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
## Writing Fine-Grained VRAM Management Scheme
|
||||
|
||||
To write a fine-grained VRAM management scheme, we need to use `print(model)` to observe and analyze the model structure:
|
||||
|
||||
```
|
||||
QwenImageDiT(
|
||||
(pos_embed): QwenEmbedRope()
|
||||
(time_text_embed): TimestepEmbeddings(
|
||||
(time_proj): TemporalTimesteps()
|
||||
(timestep_embedder): DiffusersCompatibleTimestepProj(
|
||||
(linear_1): Linear(in_features=256, out_features=3072, bias=True)
|
||||
(act): SiLU()
|
||||
(linear_2): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_norm): RMSNorm()
|
||||
(img_in): Linear(in_features=64, out_features=3072, bias=True)
|
||||
(txt_in): Linear(in_features=3584, out_features=3072, bias=True)
|
||||
(transformer_blocks): ModuleList(
|
||||
(0-59): 60 x QwenImageTransformerBlock(
|
||||
(img_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(img_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(attn): QwenDoubleStreamAttention(
|
||||
(to_q): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_k): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(to_v): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_q): RMSNorm()
|
||||
(norm_k): RMSNorm()
|
||||
(add_q_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_k_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(add_v_proj): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
(norm_added_q): RMSNorm()
|
||||
(norm_added_k): RMSNorm()
|
||||
(to_out): Sequential(
|
||||
(0): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(to_add_out): Linear(in_features=3072, out_features=3072, bias=True)
|
||||
)
|
||||
(img_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(img_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
(txt_mod): Sequential(
|
||||
(0): SiLU()
|
||||
(1): Linear(in_features=3072, out_features=18432, bias=True)
|
||||
)
|
||||
(txt_norm1): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_norm2): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
(txt_mlp): QwenFeedForward(
|
||||
(net): ModuleList(
|
||||
(0): ApproximateGELU(
|
||||
(proj): Linear(in_features=3072, out_features=12288, bias=True)
|
||||
)
|
||||
(1): Dropout(p=0.0, inplace=False)
|
||||
(2): Linear(in_features=12288, out_features=3072, bias=True)
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
(norm_out): AdaLayerNorm(
|
||||
(linear): Linear(in_features=3072, out_features=6144, bias=True)
|
||||
(norm): LayerNorm((3072,), eps=1e-06, elementwise_affine=False)
|
||||
)
|
||||
(proj_out): Linear(in_features=3072, out_features=64, bias=True)
|
||||
)
|
||||
```
|
||||
|
||||
In VRAM management, we only care about layers containing parameters. In this model structure, `QwenEmbedRope`, `TemporalTimesteps`, `SiLU` and other Layers do not contain parameters. `LayerNorm` also does not contain parameters because `elementwise_affine=False` is set. Layers containing parameters are only `Linear` and `RMSNorm`.
|
||||
|
||||
`diffsynth.core.vram` provides two replacement modules for VRAM management:
|
||||
* `AutoWrappedLinear`: Used to replace `Linear` layers
|
||||
* `AutoWrappedModule`: Used to replace any other layer
|
||||
|
||||
Write a `module_map` to map `Linear` and `RMSNorm` in the model to the corresponding modules:
|
||||
|
||||
```python
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
}
|
||||
```
|
||||
|
||||
In addition, `vram_config` and `vram_limit` are also required, which have been introduced in [VRAM Management](/docs/en/Pipeline_Usage/VRAM_management.md#more-usage-methods).
|
||||
|
||||
Call `enable_vram_management` to enable VRAM management. Note that the `device` when loading the model is `cpu`, consistent with `offload_device`:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(QwenImageDiT, model_path, torch_dtype=torch.bfloat16, device="cpu")
|
||||
enable_vram_management(
|
||||
model,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config = {
|
||||
"offload_dtype": torch.bfloat16,
|
||||
"offload_device": "cpu",
|
||||
"onload_dtype": torch.bfloat16,
|
||||
"onload_device": "cpu",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
The above code only requires 2G VRAM to run the `forward` of a 20B model.
|
||||
|
||||
## Disk Offload
|
||||
|
||||
[Disk Offload](/docs/en/Pipeline_Usage/VRAM_management.md#disk-offload) is a special VRAM management scheme that needs to be enabled during the model loading process, not after the model is loaded. Usually, when the above code can run smoothly, Disk Offload can be directly enabled:
|
||||
|
||||
```python
|
||||
from diffsynth.core import load_model, enable_vram_management, AutoWrappedLinear, AutoWrappedModule
|
||||
from diffsynth.models.qwen_image_dit import QwenImageDiT, RMSNorm
|
||||
import torch
|
||||
|
||||
prefix = "models/Qwen/Qwen-Image/transformer/diffusion_pytorch_model"
|
||||
model_path = [prefix + f"-0000{i}-of-00009.safetensors" for i in range(1, 10)]
|
||||
inputs = {
|
||||
"latents": torch.randn((1, 16, 128, 128), dtype=torch.bfloat16, device="cuda"),
|
||||
"timestep": torch.zeros((1,), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb": torch.randn((1, 5, 3584), dtype=torch.bfloat16, device="cuda"),
|
||||
"prompt_emb_mask": torch.ones((1, 5), dtype=torch.int64, device="cuda"),
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
}
|
||||
|
||||
model = load_model(
|
||||
QwenImageDiT,
|
||||
model_path,
|
||||
module_map={
|
||||
torch.nn.Linear: AutoWrappedLinear,
|
||||
RMSNorm: AutoWrappedModule,
|
||||
},
|
||||
vram_config={
|
||||
"offload_dtype": "disk",
|
||||
"offload_device": "disk",
|
||||
"onload_dtype": "disk",
|
||||
"onload_device": "disk",
|
||||
"preparing_dtype": torch.bfloat16,
|
||||
"preparing_device": "cuda",
|
||||
"computation_dtype": torch.bfloat16,
|
||||
"computation_device": "cuda",
|
||||
},
|
||||
vram_limit=0,
|
||||
)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
```
|
||||
|
||||
Disk Offload is an extremely special VRAM management scheme. It only supports `.safetensors` format files, not binary files such as `.bin`, `.pth`, `.ckpt`, and does not support [state dict converter](/docs/en/Developer_Guide/Integrating_Your_Model.md#step-2-model-file-format-conversion) with Tensor reshape.
|
||||
|
||||
If there are situations where Disk Offload cannot run normally but non-Disk Offload can run normally, please submit an issue to us on GitHub.
|
||||
|
||||
## Writing Default Configuration
|
||||
|
||||
To make it easier for users to use the VRAM management function, we write the fine-grained VRAM management configuration in `diffsynth/configs/vram_management_module_maps.py`. The configuration information for the above model is:
|
||||
|
||||
```python
|
||||
"diffsynth.models.qwen_image_dit.QwenImageDiT": {
|
||||
"diffsynth.models.qwen_image_dit.RMSNorm": "diffsynth.core.vram.layers.AutoWrappedModule",
|
||||
"torch.nn.Linear": "diffsynth.core.vram.layers.AutoWrappedLinear",
|
||||
}
|
||||
```
|
||||
186
docs/en/Developer_Guide/Integrating_Your_Model.md
Normal file
186
docs/en/Developer_Guide/Integrating_Your_Model.md
Normal file
@@ -0,0 +1,186 @@
|
||||
# Integrating Model Architecture
|
||||
|
||||
This document introduces how to integrate models into the `DiffSynth-Studio` framework for use by modules such as `Pipeline`.
|
||||
|
||||
## Step 1: Integrate Model Architecture Code
|
||||
|
||||
All model architecture implementations in `DiffSynth-Studio` are unified in `diffsynth/models`. Each `.py` code file implements a model architecture, and all models are loaded through `ModelPool` in `diffsynth/models/model_loader.py`. When integrating new model architectures, please create a new `.py` file under this path.
|
||||
|
||||
```shell
|
||||
diffsynth/models/
|
||||
├── general_modules.py
|
||||
├── model_loader.py
|
||||
├── qwen_image_controlnet.py
|
||||
├── qwen_image_dit.py
|
||||
├── qwen_image_text_encoder.py
|
||||
├── qwen_image_vae.py
|
||||
└── ...
|
||||
```
|
||||
|
||||
In most cases, we recommend integrating models in native `PyTorch` code form, with the model architecture class directly inheriting from `torch.nn.Module`, for example:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class NewDiffSynthModel(torch.nn.Module):
|
||||
def __init__(self, dim=1024):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(dim, dim)
|
||||
self.activation = torch.nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
x = self.activation(x)
|
||||
return x
|
||||
```
|
||||
|
||||
If the model architecture implementation contains additional dependencies, we strongly recommend removing them, otherwise this will cause heavy package dependency issues. In our existing models, Qwen-Image's Blockwise ControlNet is integrated in this way. The code is lightweight, please refer to `diffsynth/models/qwen_image_controlnet.py`.
|
||||
|
||||
If the model has been integrated by Huggingface Library ([`transformers`](https://huggingface.co/docs/transformers/main/index), [`diffusers`](https://huggingface.co/docs/diffusers/main/index), etc.), we can integrate the model in a simpler way:
|
||||
|
||||
<details>
|
||||
<summary>Integrating Huggingface Library Style Model Architecture Code</summary>
|
||||
|
||||
The loading method for these models in Huggingface Library is:
|
||||
|
||||
```python
|
||||
from transformers import XXX_Model
|
||||
|
||||
model = XXX_Model.from_pretrained("path_to_your_model")
|
||||
```
|
||||
|
||||
`DiffSynth-Studio` does not support loading models through `from_pretrained` because this conflicts with VRAM management and other functions. Please rewrite the model architecture in the following format:
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
class DiffSynth_XXX_Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
from transformers import XXX_Config, XXX_Model
|
||||
config = XXX_Config(**{
|
||||
"architectures": ["XXX_Model"],
|
||||
"other_configs": "Please copy and paste the other configs here.",
|
||||
})
|
||||
self.model = XXX_Model(config)
|
||||
|
||||
def forward(self, x):
|
||||
outputs = self.model(x)
|
||||
return outputs
|
||||
```
|
||||
|
||||
Where `XXX_Config` is the Config class corresponding to the model. For example, the Config class for `Qwen2_5_VLModel` is `Qwen2_5_VLConfig`, which can be found by consulting its source code. The content inside Config can usually be found in the `config.json` file in the model library. `DiffSynth-Studio` will not read the `config.json` file, so the content needs to be copied and pasted into the code.
|
||||
|
||||
In rare cases, version updates of `transformers` and `diffusers` may cause some models to be unable to import. Therefore, if possible, we still recommend using the model integration method in Step 1.1.
|
||||
|
||||
In our existing models, Qwen-Image's Text Encoder is integrated in this way. The code is lightweight, please refer to `diffsynth/models/qwen_image_text_encoder.py`.
|
||||
|
||||
</details>
|
||||
|
||||
## Step 2: Model File Format Conversion
|
||||
|
||||
Due to the variety of model file formats provided by developers in the open-source community, we sometimes need to convert model file formats to form correctly formatted [state dict](https://docs.pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html). This is common in the following situations:
|
||||
|
||||
* Model files built by different code libraries, for example [Wan-AI/Wan2.1-T2V-1.3B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B) and [Wan-AI/Wan2.1-T2V-1.3B-Diffusers](https://www.modelscope.cn/models/Wan-AI/Wan2.1-T2V-1.3B-Diffusers).
|
||||
* Models modified during integration, for example, the Text Encoder of [Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image) adds a `model.` prefix in `diffsynth/models/qwen_image_text_encoder.py`.
|
||||
* Model files containing multiple models, for example, the VACE Adapter and base DiT model of [Wan-AI/Wan2.1-VACE-14B](https://www.modelscope.cn/models/Wan-AI/Wan2.1-VACE-14B) are mixed and stored in the same set of model files.
|
||||
|
||||
In our development philosophy, we hope to respect the wishes of model authors as much as possible. If we repackage the model files, for example [Comfy-Org/Qwen-Image_ComfyUI](https://www.modelscope.cn/models/Comfy-Org/Qwen-Image_ComfyUI), although we can call the model more conveniently, traffic (model page views and downloads, etc.) will be directed elsewhere, and the original author of the model will also lose the power to delete the model. Therefore, we have added the `diffsynth/utils/state_dict_converters` module to the framework for file format conversion during model loading.
|
||||
|
||||
This part of logic is very simple. Taking Qwen-Image's Text Encoder as an example, only 10 lines of code are needed:
|
||||
|
||||
```python
|
||||
def QwenImageTextEncoderStateDictConverter(state_dict):
|
||||
state_dict_ = {}
|
||||
for k in state_dict:
|
||||
v = state_dict[k]
|
||||
if k.startswith("visual."):
|
||||
k = "model." + k
|
||||
elif k.startswith("model."):
|
||||
k = k.replace("model.", "model.language_model.")
|
||||
state_dict_[k] = v
|
||||
return state_dict_
|
||||
```
|
||||
|
||||
## Step 3: Writing Model Config
|
||||
|
||||
Model Config is located in `diffsynth/configs/model_configs.py`, used to identify model types and load them. The following fields need to be filled in:
|
||||
|
||||
* `model_hash`: Model file hash value, which can be obtained through the `hash_model_file` function. This hash value is only related to the keys and tensor shapes in the model file's state dict, and is unrelated to other information in the file.
|
||||
* `model_name`: Model name, used for `Pipeline` to identify the required model. If different structured models play the same role in `Pipeline`, the same `model_name` can be used. When integrating new models, just ensure that `model_name` is different from other existing functional models. The corresponding model is fetched through `model_name` in the `Pipeline`'s `from_pretrained`.
|
||||
* `model_class`: Model architecture import path, pointing to the model architecture class implemented in Step 1, for example `diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder`.
|
||||
* `state_dict_converter`: Optional parameter. If model file format conversion is needed, the import path of the model conversion logic needs to be filled in, for example `diffsynth.utils.state_dict_converters.qwen_image_text_encoder.QwenImageTextEncoderStateDictConverter`.
|
||||
* `extra_kwargs`: Optional parameter. If additional parameters need to be passed when initializing the model, these parameters need to be filled in. For example, models [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Canny) and [DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Blockwise-ControlNet-Inpaint) both adopt the `QwenImageBlockWiseControlNet` structure in `diffsynth/models/qwen_image_controlnet.py`, but the latter also needs additional configuration `additional_in_dim=4`. Therefore, this configuration information needs to be filled in the `extra_kwargs` field.
|
||||
|
||||
We provide a piece of code to quickly understand how models are loaded through this configuration information:
|
||||
|
||||
```python
|
||||
from diffsynth.core import hash_model_file, load_state_dict, skip_model_initialization
|
||||
from diffsynth.models.qwen_image_text_encoder import QwenImageTextEncoder
|
||||
from diffsynth.utils.state_dict_converters.qwen_image_text_encoder import QwenImageTextEncoderStateDictConverter
|
||||
import torch
|
||||
|
||||
model_hash = "8004730443f55db63092006dd9f7110e"
|
||||
model_name = "qwen_image_text_encoder"
|
||||
model_class = QwenImageTextEncoder
|
||||
state_dict_converter = QwenImageTextEncoderStateDictConverter
|
||||
extra_kwargs = {}
|
||||
|
||||
model_path = [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
]
|
||||
if hash_model_file(model_path) == model_hash:
|
||||
with skip_model_initialization():
|
||||
model = model_class(**extra_kwargs)
|
||||
state_dict = load_state_dict(model_path, torch_dtype=torch.bfloat16, device="cuda")
|
||||
state_dict = state_dict_converter(state_dict)
|
||||
model.load_state_dict(state_dict, assign=True)
|
||||
print("Done!")
|
||||
```
|
||||
|
||||
> Q: The logic of the above code looks very simple, why is this part of code in `DiffSynth-Studio` extremely complex?
|
||||
>
|
||||
> A: Because we provide aggressive VRAM management functions that are coupled with the model loading logic, this leads to the complexity of the framework structure. We have tried our best to simplify the interface exposed to developers.
|
||||
|
||||
The `model_hash` in `diffsynth/configs/model_configs.py` is not uniquely existing. Multiple models may exist in the same model file. For this situation, please use multiple model Configs to load each model separately, and write the corresponding `state_dict_converter` to separate the parameters required by each model.
|
||||
|
||||
## Step 4: Verifying Whether the Model Can Be Recognized and Loaded
|
||||
|
||||
After model integration, the following code can be used to verify whether the model can be correctly recognized and loaded. The following code will attempt to load the model into memory:
|
||||
|
||||
```python
|
||||
from diffsynth.models.model_loader import ModelPool
|
||||
|
||||
model_pool = ModelPool()
|
||||
model_pool.auto_load_model(
|
||||
[
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors",
|
||||
],
|
||||
)
|
||||
```
|
||||
|
||||
If the model can be recognized and loaded, you will see the following output:
|
||||
|
||||
```
|
||||
Loading models from: [
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00001-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00002-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00003-of-00004.safetensors",
|
||||
"models/Qwen/Qwen-Image/text_encoder/model-00004-of-00004.safetensors"
|
||||
]
|
||||
Loaded model: {
|
||||
"model_name": "qwen_image_text_encoder",
|
||||
"model_class": "diffsynth.models.qwen_image_text_encoder.QwenImageTextEncoder",
|
||||
"extra_kwargs": null
|
||||
}
|
||||
```
|
||||
|
||||
## Step 5: Writing Model VRAM Management Scheme
|
||||
|
||||
`DiffSynth-Studio` supports complex VRAM management. See [Enabling VRAM Management](/docs/en/Developer_Guide/Enabling_VRAM_management.md) for details.
|
||||
66
docs/en/Developer_Guide/Training_Diffusion_Models.md
Normal file
66
docs/en/Developer_Guide/Training_Diffusion_Models.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Integrating Model Training
|
||||
|
||||
After [integrating models](/docs/en/Developer_Guide/Integrating_Your_Model.md) and [implementing Pipeline](/docs/en/Developer_Guide/Building_a_Pipeline.md), the next step is to integrate model training functionality.
|
||||
|
||||
## Training-Inference Consistent Pipeline Modification
|
||||
|
||||
To ensure strict consistency between training and inference processes, we will use most of the inference code during training, but still need to make minor modifications.
|
||||
|
||||
First, add extra logic during inference to switch the image-to-image/video-to-video logic based on the `scheduler` state. Taking Qwen-Image as an example:
|
||||
|
||||
```python
|
||||
class QwenImageUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("input_image", "noise", "tiled", "tile_size", "tile_stride"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",)
|
||||
)
|
||||
|
||||
def process(self, pipe: QwenImagePipeline, input_image, noise, tiled, tile_size, tile_stride):
|
||||
if input_image is None:
|
||||
return {"latents": noise, "input_latents": None}
|
||||
pipe.load_models_to_device(['vae'])
|
||||
image = pipe.preprocess_image(input_image).to(device=pipe.device, dtype=pipe.torch_dtype)
|
||||
input_latents = pipe.vae.encode(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
if pipe.scheduler.training:
|
||||
return {"latents": noise, "input_latents": input_latents}
|
||||
else:
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, timestep=pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
```
|
||||
|
||||
Then, enable Gradient Checkpointing in `model_fn`, which will significantly reduce the VRAM required for training at the cost of computational speed. This is not mandatory, but we strongly recommend doing so.
|
||||
|
||||
Taking Qwen-Image as an example, before modification:
|
||||
|
||||
```python
|
||||
text, image = block(
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
After modification:
|
||||
|
||||
```python
|
||||
from ..core import gradient_checkpoint_forward
|
||||
|
||||
text, image = gradient_checkpoint_forward(
|
||||
block,
|
||||
use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload,
|
||||
image=image,
|
||||
text=text,
|
||||
temb=conditioning,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
```
|
||||
|
||||
## Writing Training Scripts
|
||||
|
||||
`DiffSynth-Studio` does not strictly encapsulate the training framework, but exposes the script content to developers. This approach makes it more convenient to modify training scripts to implement additional functions. Developers can refer to existing training scripts, such as `examples/qwen_image/model_training/train.py`, for modification to adapt to new model training.
|
||||
Reference in New Issue
Block a user