support num_workers,save_steps,find_unused_parameters

This commit is contained in:
mi804
2025-08-06 10:52:59 +08:00
parent 8d2f6ad32e
commit 6bae70eee0
13 changed files with 71 additions and 16 deletions

View File

@@ -280,6 +280,7 @@ The script includes the following parameters:
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
* `--data_file_keys`: Data file keys in the metadata. Comma-separated.
* `--dataset_repeat`: Number of times to repeat the dataset per epoch.
* `--dataset_num_workers`: Number of workers for data loading.
* Models
* `--model_paths`: Paths to load models. In JSON format.
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -290,6 +291,8 @@ The script includes the following parameters:
* `--num_epochs`: Number of epochs.
* `--output_path`: Output save path.
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model LoRA is added to.

View File

@@ -282,6 +282,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
* 模型
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
@@ -292,6 +293,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
* 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。

View File

@@ -127,4 +127,7 @@ if __name__ == "__main__":
dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
)