Mirror of https://github.com/modelscope/DiffSynth-Studio.git
Merge pull request #1354 from mi804/low_vram_training_ds

Low VRAM training with DeepSpeed ZeRO-3
@@ -1,12 +1,32 @@
 import torch
 
+try:
+    import deepspeed
+    _HAS_DEEPSPEED = True
+except ModuleNotFoundError:
+    _HAS_DEEPSPEED = False
+
 
 def create_custom_forward(module):
     def custom_forward(*inputs, **kwargs):
         return module(*inputs, **kwargs)
     return custom_forward
 
 
+def create_custom_forward_use_reentrant(module):
+    def custom_forward(*inputs):
+        return module(*inputs)
+    return custom_forward
+
+
+def judge_args_requires_grad(*args):
+    for arg in args:
+        if isinstance(arg, torch.Tensor) and arg.requires_grad:
+            return True
+    return False
+
+
 def gradient_checkpoint_forward(
     model,
     use_gradient_checkpointing,
@@ -14,6 +34,17 @@ def gradient_checkpoint_forward(
     *args,
     **kwargs,
 ):
+    if use_gradient_checkpointing and _HAS_DEEPSPEED and deepspeed.checkpointing.is_configured():
+        all_args = args + tuple(kwargs.values())
+        if not judge_args_requires_grad(*all_args):
+            # no grad-enabled input yet: run un-checkpointed so the output becomes the first grad-enabled tensor
+            model_output = model(*args, **kwargs)
+        else:
+            model_output = deepspeed.checkpointing.checkpoint(
+                create_custom_forward_use_reentrant(model),
+                *all_args,
+            )
+        return model_output
     if use_gradient_checkpointing_offload:
         with torch.autograd.graph.save_on_cpu():
             model_output = torch.utils.checkpoint.checkpoint(
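The guard above matters because DeepSpeed's `checkpointing.checkpoint` is reentrant: it only records a backward graph when at least one input tensor requires grad, so grad-free inputs (such as the cached tensors produced by the data-process stage) must bypass checkpointing. A minimal sketch of the condition, not taken from the repository:

```python
# A minimal sketch (not repository code) of the dispatch guard: skip reentrant
# checkpointing when no input tensor requires grad, otherwise the checkpointed
# segment would be detached from autograd.
import torch

def judge_args_requires_grad(*args):
    return any(isinstance(a, torch.Tensor) and a.requires_grad for a in args)

x_cached = torch.randn(2, 8)                       # e.g. stage-1 tensors loaded from disk
x_live = torch.randn(2, 8, requires_grad=True)     # e.g. an activation inside the DiT

assert not judge_args_requires_grad(x_cached)      # -> plain, un-checkpointed forward
assert judge_args_requires_grad(x_cached, x_live)  # -> safe to checkpoint
```

Once the first module has run un-checkpointed, its output generally requires grad (the trainable LoRA weights do), so the remaining modules take the checkpointed path.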
@@ -29,7 +29,7 @@ def launch_training_task(
     dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
     model.to(device=accelerator.device)
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
+    initialize_deepspeed_gradient_checkpointing(accelerator)
     for epoch_id in range(num_epochs):
         for data in tqdm(dataloader):
             with accelerator.accumulate(model):
@@ -70,3 +70,19 @@ def launch_data_process_task(
         save_path = os.path.join(model_logger.output_path, str(accelerator.process_index), f"{data_id}.pth")
         data = model(data)
         torch.save(data, save_path)
+
+
+def initialize_deepspeed_gradient_checkpointing(accelerator: Accelerator):
+    if getattr(accelerator.state, "deepspeed_plugin", None) is not None:
+        ds_config = accelerator.state.deepspeed_plugin.deepspeed_config
+        if "activation_checkpointing" in ds_config:
+            import deepspeed
+            act_config = ds_config["activation_checkpointing"]
+            deepspeed.checkpointing.configure(
+                mpu_=None,
+                partition_activations=act_config.get("partition_activations", False),
+                checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
+                contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False)
+            )
+        else:
+            print("No activation_checkpointing entry found in the DeepSpeed config; skipping DeepSpeed gradient checkpointing initialization.")
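For reference, the branch above reduces to a direct `deepspeed.checkpointing.configure` call; a minimal sketch assuming `deepspeed` is installed, using the same config keys the helper reads:

```python
# Minimal sketch (assumes deepspeed is installed): what the helper above does
# once it finds an "activation_checkpointing" block in the DeepSpeed config.
import deepspeed

act_config = {
    "partition_activations": False,
    "cpu_checkpointing": False,
    "contiguous_memory_optimization": False,
}
deepspeed.checkpointing.configure(
    mpu_=None,
    partition_activations=act_config.get("partition_activations", False),
    checkpoint_in_cpu=act_config.get("cpu_checkpointing", False),
    contiguous_checkpointing=act_config.get("contiguous_memory_optimization", False),
)
# gradient_checkpoint_forward only takes the DeepSpeed path once this holds:
assert deepspeed.checkpointing.is_configured()
```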
@@ -123,7 +123,6 @@ Similar to [model loading during inference](../Pipeline_Usage/Model_Inference.md
 
 <details>
 
-<details>
 
 <summary>Load models from local file paths</summary>
 
@@ -244,4 +243,119 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
 * The training framework does not support batch size > 1. The reasons are complex. See [Q&A: Why doesn't the training framework support batch size > 1?](../QA.md#why-doesnt-the-training-framework-support-batch-size--1)
 * Some models contain redundant parameters, for example the text-encoding portion of the last layer of Qwen-Image's DiT. When training these models, `--find_unused_parameters` must be set to avoid errors in multi-GPU training. For compatibility with community models, we do not intend to remove these redundant parameters.
 * The loss value of diffusion models bears little relationship to actual output quality, so we do not record loss values during training. We recommend setting `--num_epochs` to a sufficiently large value, evaluating while training, and manually stopping the run once quality converges.
 * `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.
+
+## Low VRAM Training
+
+If you want to complete LoRA training on a GPU with low VRAM, you can combine [Two-Stage Split Training](../Training/Split_Training.md) with `deepspeed_zero3_offload` training. First, split the preprocessing steps into the first stage and store the computed results on disk. Second, read these results from disk and train the denoising model. With `deepspeed_zero3_offload`, the trainable parameters and optimizer states are offloaded to CPU or disk. We provide examples for some models, primarily by specifying the `deepspeed` configuration via `--config_file`.
+
+Please note that the `deepspeed_zero3_offload` mode is incompatible with PyTorch's native gradient checkpointing mechanism. To address this, we have adapted the `checkpointing` interface of `deepspeed`. Users need to fill in the `activation_checkpointing` field of the `deepspeed` configuration to enable gradient checkpointing.
+
+Below is the low VRAM training script for the Qwen-Image model:
+
+```shell
+accelerate launch examples/qwen_image/model_training/train.py \
+  --dataset_base_path data/example_image_dataset \
+  --dataset_metadata_path data/example_image_dataset/metadata.csv \
+  --max_pixels 1048576 \
+  --dataset_repeat 1 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Qwen-Image_lora-splited-cache" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
+  --lora_rank 32 \
+  --task "sft:data_process" \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --find_unused_parameters
+
+accelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \
+  --dataset_base_path "./models/train/Qwen-Image_lora-splited-cache" \
+  --max_pixels 1048576 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
+  --learning_rate 1e-4 \
+  --num_epochs 5 \
+  --remove_prefix_in_ckpt "pipe.dit." \
+  --output_path "./models/train/Qwen-Image_lora" \
+  --lora_base_model "dit" \
+  --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
+  --lora_rank 32 \
+  --task "sft:train" \
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --find_unused_parameters \
+  --initialize_model_on_cpu
+```
+
+The configurations for `accelerate` and `deepspeed` are as follows:
+
+```yaml
+compute_environment: LOCAL_MACHINE
+debug: true
+deepspeed_config:
+  deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json
+  zero3_init_flag: true
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+```json
+{
+  "fp16": {
+    "enabled": "auto",
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "zero_optimization": {
+    "stage": 3,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "overlap_comm": false,
+    "contiguous_gradients": true,
+    "sub_group_size": 1e9,
+    "reduce_bucket_size": 5e7,
+    "stage3_prefetch_bucket_size": 5e7,
+    "stage3_param_persistence_threshold": 1e5,
+    "stage3_max_live_parameters": 1e8,
+    "stage3_max_reuse_distance": 1e8,
+    "stage3_gather_16bit_weights_on_model_save": true
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "cpu_checkpointing": false,
+    "contiguous_memory_optimization": false
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
+```
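The second stage of the split consumes the `.pth` files that `sft:data_process` writes (one per sample, under a per-process subdirectory). A hypothetical sketch of that consumer side; `CachedDataset` is illustrative, not the library's actual dataset class:

```python
# Hypothetical sketch of stage 2's input side (CachedDataset is not the
# library's class): stage 1 cached preprocessed samples as .pth files, so
# training can run without the text encoder or VAE in memory.
import glob
import os
import torch

class CachedDataset(torch.utils.data.Dataset):
    def __init__(self, cache_dir):
        self.paths = sorted(glob.glob(os.path.join(cache_dir, "**", "*.pth"), recursive=True))
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        # Each file holds the model-ready tensors computed in the "sft:data_process" stage.
        return torch.load(self.paths[idx], weights_only=False)

dataset = CachedDataset("./models/train/Qwen-Image_lora-splited-cache")
```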
@@ -243,3 +243,116 @@ accelerate launch --config_file examples/qwen_image/model_training/full/accelera
 * A few models contain redundant parameters, for example the text-encoding portion of the last layer of Qwen-Image's DiT. When training these models, set `--find_unused_parameters` to avoid errors in multi-GPU training. For compatibility with open-source community models, we do not intend to remove these redundant parameters.
 * The loss value of diffusion models bears little relationship to actual output quality, so we do not record it during training. We recommend setting `--num_epochs` to a sufficiently large value, evaluating while training, and manually stopping the run once quality converges.
 * `--use_gradient_checkpointing` is usually enabled unless GPU VRAM is sufficient; `--use_gradient_checkpointing_offload` is enabled as needed. See [`diffsynth.core.gradient`](../API_Reference/core/gradient.md) for details.
+
+## Low VRAM Training
+
+To train a LoRA model on a low-VRAM GPU, combine [Two-Stage Split Training](../Training/Split_Training.md) with `deepspeed_zero3_offload` training. First, split the preprocessing into the first stage and store the computed results on disk. Then, in the second stage, read these results from disk and train the denoising model; with `deepspeed_zero3_offload`, the trainable parameters and optimizer states are offloaded to CPU or disk. We provide examples for some models, mainly by specifying the `deepspeed` configuration via `--config_file`.
+
+Note that the `deepspeed_zero3_offload` mode is incompatible with PyTorch's native gradient checkpointing mechanism; we therefore adapted DeepSpeed's `checkpointing` interface. Fill in the `activation_checkpointing` field of the `deepspeed` configuration to enable gradient checkpointing.
+
+Below is the low VRAM training script for the Qwen-Image model; the script and the `accelerate`/`deepspeed` configuration files are identical to those in the English documentation above.
@@ -4,35 +4,32 @@ accelerate launch examples/qwen_image/model_training/train.py \
   --max_pixels 1048576 \
   --dataset_repeat 1 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
-  --fp8_models "Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
   --learning_rate 1e-4 \
   --num_epochs 5 \
   --remove_prefix_in_ckpt "pipe.dit." \
-  --output_path "./models/train/Qwen-Image-LoRA-lowvram-cache" \
+  --output_path "./models/train/Qwen-Image_lora-splited-cache" \
   --lora_base_model "dit" \
   --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
   --lora_rank 32 \
+  --task "sft:data_process" \
   --use_gradient_checkpointing \
-  --use_gradient_checkpointing_offload \
   --dataset_num_workers 8 \
-  --find_unused_parameters \
-  --task "sft:data_process"
+  --find_unused_parameters
 
-accelerate launch examples/qwen_image/model_training/train.py \
+accelerate launch --config_file examples/qwen_image/model_training/special/low_vram_training/deepspeed_zero3_cpuoffload.yaml examples/qwen_image/model_training/train.py \
-  --dataset_base_path "./models/train/Qwen-Image-LoRA-lowvram-cache" \
+  --dataset_base_path "./models/train/Qwen-Image_lora-splited-cache" \
   --max_pixels 1048576 \
   --dataset_repeat 50 \
   --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
-  --fp8_models "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors" \
   --learning_rate 1e-4 \
   --num_epochs 5 \
   --remove_prefix_in_ckpt "pipe.dit." \
-  --output_path "./models/train/Qwen-Image-LoRA-lowvram" \
+  --output_path "./models/train/Qwen-Image_lora" \
   --lora_base_model "dit" \
   --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
   --lora_rank 32 \
+  --task "sft:train" \
   --use_gradient_checkpointing \
-  --use_gradient_checkpointing_offload \
   --dataset_num_workers 8 \
   --find_unused_parameters \
-  --task "sft:train"
+  --initialize_model_on_cpu
@@ -0,0 +1,18 @@
+compute_environment: LOCAL_MACHINE
+debug: true
+deepspeed_config:
+  deepspeed_config_file: examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json
+  zero3_init_flag: true
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+enable_cpu_affinity: false
+machine_rank: 0
+main_training_function: main
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
@@ -0,0 +1,43 @@
+{
+  "fp16": {
+    "enabled": "auto",
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "bf16": {
+    "enabled": "auto"
+  },
+  "zero_optimization": {
+    "stage": 3,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "overlap_comm": false,
+    "contiguous_gradients": true,
+    "sub_group_size": 1e9,
+    "reduce_bucket_size": 5e7,
+    "stage3_prefetch_bucket_size": 5e7,
+    "stage3_param_persistence_threshold": 1e5,
+    "stage3_max_live_parameters": 1e8,
+    "stage3_max_reuse_distance": 1e8,
+    "stage3_gather_16bit_weights_on_model_save": true
+  },
+  "activation_checkpointing": {
+    "partition_activations": false,
+    "cpu_checkpointing": false,
+    "contiguous_memory_optimization": false
+  },
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "train_batch_size": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "wall_clock_breakdown": false
+}
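Because `initialize_deepspeed_gradient_checkpointing` silently skips configuration when the `activation_checkpointing` block is absent, it can be worth sanity-checking the JSON before launching; a small sketch:

```python
# Sanity-check sketch: confirm the DeepSpeed config enables ZeRO-3 and carries
# the "activation_checkpointing" block that gradient checkpointing depends on.
import json

path = "examples/qwen_image/model_training/special/low_vram_training/ds_z3_cpuoffload.json"
with open(path) as f:
    ds_config = json.load(f)

assert ds_config["zero_optimization"]["stage"] == 3
assert "activation_checkpointing" in ds_config, "gradient checkpointing would be skipped"
print(ds_config["activation_checkpointing"])
```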