Merge pull request #749 from mi804/training_args

Support num_workers, save_steps, and find_unused_parameters
Zhongjie Duan, 2025-08-06 15:54:04 +08:00 (committed via GitHub)
11 changed files with 78 additions and 14 deletions

View File

@@ -4,6 +4,7 @@ from PIL import Image
 import pandas as pd
 from tqdm import tqdm
 from accelerate import Accelerator
+from accelerate.utils import DistributedDataParallelKwargs
@@ -364,10 +365,13 @@ class ModelLogger:
         self.output_path = output_path
         self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
         self.state_dict_converter = state_dict_converter
+        self.num_steps = 0
 
-    def on_step_end(self, loss):
-        pass
+    def on_step_end(self, accelerator, model, save_steps=None):
+        self.num_steps += 1
+        if save_steps is not None and self.num_steps % save_steps == 0:
+            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
 
     def on_epoch_end(self, accelerator, model, epoch_id):
@@ -381,6 +385,21 @@ class ModelLogger:
             accelerator.save(state_dict, path, safe_serialization=True)
 
+    def on_training_end(self, accelerator, model, save_steps=None):
+        if save_steps is not None and self.num_steps % save_steps != 0:
+            self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
+
+    def save_model(self, accelerator, model, file_name):
+        accelerator.wait_for_everyone()
+        if accelerator.is_main_process:
+            state_dict = accelerator.get_state_dict(model)
+            state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
+            state_dict = self.state_dict_converter(state_dict)
+            os.makedirs(self.output_path, exist_ok=True)
+            path = os.path.join(self.output_path, file_name)
+            accelerator.save(state_dict, path, safe_serialization=True)
 
 
 def launch_training_task(
     dataset: torch.utils.data.Dataset,
@@ -388,11 +407,17 @@ def launch_training_task(
     model_logger: ModelLogger,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LRScheduler,
+    num_workers: int = 8,
+    save_steps: int = None,
     num_epochs: int = 1,
     gradient_accumulation_steps: int = 1,
+    find_unused_parameters: bool = False,
 ):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
-    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    accelerator = Accelerator(
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
+    )
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
 
     for epoch_id in range(num_epochs):
@@ -402,10 +427,11 @@ def launch_training_task(
                 loss = model(data)
                 accelerator.backward(loss)
                 optimizer.step()
-                model_logger.on_step_end(loss)
+                model_logger.on_step_end(accelerator, model, save_steps)
                 scheduler.step()
-        model_logger.on_epoch_end(accelerator, model, epoch_id)
+        if save_steps is None:
+            model_logger.on_epoch_end(accelerator, model, epoch_id)
+    model_logger.on_training_end(accelerator, model, save_steps)
 
 
 def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
@@ -446,6 +472,9 @@ def wan_parser():
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser return parser
@@ -474,6 +503,9 @@ def flux_parser():
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser return parser
@@ -503,4 +535,7 @@ def qwen_image_parser():
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser return parser
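
Read together, the hooks above define the checkpoint cadence: on_step_end saves every save_steps optimizer steps, on_training_end flushes any trailing partial interval, and on_epoch_end fires only when save_steps is unset. A standalone sketch of that cadence (illustration only, not code from this PR), assuming save_steps=500 and 1200 total steps:

```python
# Simulates the ModelLogger save cadence; no training involved.
save_steps, num_steps, saved = 500, 0, []

for _ in range(1200):                      # one iteration per optimizer step
    num_steps += 1
    # on_step_end: save at every multiple of save_steps
    if save_steps is not None and num_steps % save_steps == 0:
        saved.append(f"step-{num_steps}.safetensors")

# on_training_end: flush the final partial interval, if any
if save_steps is not None and num_steps % save_steps != 0:
    saved.append(f"step-{num_steps}.safetensors")

print(saved)  # ['step-500.safetensors', 'step-1000.safetensors', 'step-1200.safetensors']
```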

View File

@@ -249,6 +249,7 @@ The script includes the following parameters:
   * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
   * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
   * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+  * `--dataset_num_workers`: Number of workers for data loading.
 * Model
   * `--model_paths`: Paths to load models. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
@@ -257,6 +258,8 @@ The script includes the following parameters:
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints are saved every epoch.
+  * `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
   * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -249,6 +249,7 @@ FLUX-series models are trained with the unified [`./model_training/train.py`](./model_tra
   * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
   * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
   * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+  * `--dataset_num_workers`: Number of worker processes per DataLoader.
 * Model
   * `--model_paths`: Paths of models to load. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
@@ -257,6 +258,8 @@ FLUX-series models are trained with the unified [`./model_training/train.py`](./model_tra
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If set to None, a checkpoint is saved every epoch.
+  * `--find_unused_parameters`: Whether there are unused parameters during DDP training.
 * Trainable Modules
   * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -121,4 +121,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )
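
For readers unfamiliar with the flag being threaded through here: Accelerate's DistributedDataParallelKwargs handler forwards find_unused_parameters to torch's DDP wrapper. A minimal sketch, independent of this repository:

```python
# How the flag reaches DDP through Accelerate's kwargs handler.
from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

accelerator = Accelerator(
    kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)]
)
# With find_unused_parameters=True, DDP scans the autograd graph each step so that
# parameters skipped by the forward pass (frozen or conditionally used branches)
# do not stall the gradient all-reduce, at a small per-step overhead.
```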

View File

@@ -219,6 +219,7 @@ The script includes the following parameters:
   * `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution.
   * `--data_file_keys`: Data file keys in metadata. Separate with commas.
   * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+  * `--dataset_num_workers`: Number of workers for data loading.
 * Model
   * `--model_paths`: Model paths to load. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -228,6 +229,8 @@ The script includes the following parameters:
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints are saved every epoch.
+  * `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
   * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -219,6 +219,7 @@ Qwen-Image-series models are trained with the unified [`./model_training/train.py`](./mod
   * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
   * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
   * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+  * `--dataset_num_workers`: Number of worker processes per DataLoader.
 * Model
   * `--model_paths`: Paths of models to load. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -228,6 +229,8 @@ Qwen-Image-series models are trained with the unified [`./model_training/train.py`](./mod
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If set to None, a checkpoint is saved every epoch.
+  * `--find_unused_parameters`: Whether there are unused parameters during DDP training.
 * Trainable Modules
   * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
   --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
   --lora_rank 32 \
   --align_to_opensource_format \
-  --use_gradient_checkpointing
+  --use_gradient_checkpointing \
+  --dataset_num_workers 8 \
+  --find_unused_parameters

View File

@@ -117,4 +117,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )
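
The --dataset_num_workers value ends up in torch.utils.data.DataLoader(..., num_workers=...) as shown in the first file. A self-contained sketch of the same dataloader shape, using a toy dataset (ToyDataset is hypothetical, not part of this repo):

```python
# collate_fn=lambda x: x[0] keeps the per-process batch size at 1 and yields the
# raw sample; num_workers > 0 moves data loading into background worker processes.
import torch

class ToyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 4
    def __getitem__(self, idx):
        return {"idx": idx}

if __name__ == "__main__":  # guard needed once num_workers > 0 spawns processes
    dataloader = torch.utils.data.DataLoader(
        ToyDataset(), shuffle=True, collate_fn=lambda x: x[0], num_workers=2
    )
    for sample in dataloader:
        print(sample)  # e.g. {'idx': 3} -- one un-batched sample per step
```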

View File

@@ -280,6 +280,7 @@ The script includes the following parameters:
   * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
   * `--data_file_keys`: Data file keys in the metadata. Comma-separated.
   * `--dataset_repeat`: Number of times to repeat the dataset per epoch.
+  * `--dataset_num_workers`: Number of workers for data loading.
 * Models
   * `--model_paths`: Paths to load models. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -290,6 +291,8 @@ The script includes the following parameters:
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Output save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints are saved every epoch.
+  * `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
   * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model LoRA is added to.

View File

@@ -282,6 +282,7 @@ Wan-series models are trained with the unified [`./model_training/train.py`](./model_trai
   * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
   * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
   * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+  * `--dataset_num_workers`: Number of worker processes per DataLoader.
 * Models
   * `--model_paths`: Paths of models to load. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -292,6 +293,8 @@ Wan-series models are trained with the unified [`./model_training/train.py`](./model_trai
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If set to None, a checkpoint is saved every epoch.
+  * `--find_unused_parameters`: Whether there are unused parameters during DDP training.
 * Trainable Modules
   * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -127,4 +127,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )