mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support num_workers,save_steps,find_unused_parameters
@@ -280,6 +280,7 @@ The script includes the following parameters:
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
* `--data_file_keys`: Data file keys in the metadata. Comma-separated.
* `--dataset_repeat`: Number of times to repeat the dataset per epoch.
* `--dataset_num_workers`: Number of workers for data loading.
* Models
* `--model_paths`: Paths to load models. In JSON format.
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -290,6 +291,8 @@ The script includes the following parameters:
* `--num_epochs`: Number of epochs.
* `--output_path`: Output save path.
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
* `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints will be saved every epoch.
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model LoRA is added to.
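The `--save_steps` semantics above (save every N steps, or once per epoch when unset) can be sketched as a small helper; `should_save_checkpoint` is a hypothetical illustration, not a function in DiffSynth-Studio:

```python
def should_save_checkpoint(global_step, steps_per_epoch, save_steps=None):
    """Decide whether to save at this 1-indexed global step.

    Mirrors the documented --save_steps behavior: when save_steps is
    None, save once per epoch (at epoch boundaries); otherwise save
    every `save_steps` optimization steps.
    """
    if save_steps is None:
        return global_step % steps_per_epoch == 0
    return global_step % save_steps == 0
```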
@@ -282,6 +282,7 @@ Wan-series model training goes through the unified [`./model_training/train.py`](./model_trai
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
* `--data_file_keys`: Data file keys in the metadata. Comma-separated.
* `--dataset_repeat`: Number of times the dataset is repeated per epoch.
* `--dataset_num_workers`: Number of worker processes per DataLoader.
* Models
* `--model_paths`: Paths of models to load. In JSON format.
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -292,6 +293,8 @@ Wan-series model training goes through the unified [`./model_training/train.py`](./model_trai
* `--num_epochs`: Number of epochs.
* `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
* `--save_steps`: Number of steps between model saves. If set to None, a checkpoint is saved once per epoch.
* `--find_unused_parameters`: Whether unused parameters exist during DDP training.
* Trainable Modules
* `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model LoRA is added to.
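Since `--model_paths` is documented as a JSON value, the argument is presumably parsed with the standard `json` module. A minimal sketch, where the checkpoint paths are purely illustrative:

```python
import json

# Hypothetical command-line value for --model_paths (paths are made up):
model_paths_arg = '["models/dit.safetensors", "models/vae.safetensors"]'

# A JSON list decodes directly into a Python list of path strings.
model_paths = json.loads(model_paths_arg)
```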
@@ -127,4 +127,7 @@ if __name__ == "__main__":
        dataset, model, model_logger, optimizer, scheduler,
        num_epochs=args.num_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        save_steps=args.save_steps,
        find_unused_parameters=args.find_unused_parameters,
        num_workers=args.dataset_num_workers,
    )
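The `args.*` values above come from the script's argument parser. One plausible way the three flags added by this commit could be declared with `argparse` — flag names match the diff, while the defaults and help strings are assumptions:

```python
import argparse

parser = argparse.ArgumentParser()
# Flags added in this commit; default values here are assumptions.
parser.add_argument("--dataset_num_workers", type=int, default=0,
                    help="Number of DataLoader worker processes.")
parser.add_argument("--save_steps", type=int, default=None,
                    help="Save a checkpoint every N steps; every epoch if unset.")
parser.add_argument("--find_unused_parameters", action="store_true", default=False,
                    help="Pass find_unused_parameters=True to DDP.")

# Example invocation: --save_steps 500 --find_unused_parameters
args = parser.parse_args(["--save_steps", "500", "--find_unused_parameters"])
```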