support num_workers,save_steps,find_unused_parameters

This commit is contained in:
mi804
2025-08-06 10:52:59 +08:00
parent 8d2f6ad32e
commit 6bae70eee0
13 changed files with 71 additions and 16 deletions

View File

@@ -218,6 +218,7 @@ The script includes the following parameters:
* `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution.
* `--data_file_keys`: Data file keys in metadata. Separate with commas.
* `--dataset_repeat`: Number of times the dataset repeats per epoch.
* `--dataset_num_workers`: Number of workers for data loading.
* Model
* `--model_paths`: Model paths to load. In JSON format.
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -227,6 +228,8 @@ The script includes the following parameters:
* `--num_epochs`: Number of epochs.
* `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model to add LoRA to.

View File

@@ -218,6 +218,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--width`: 图像或视频的宽度。将 `height``width` 留空以启用动态分辨率。
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
* 模型
* `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。
@@ -227,6 +228,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
* 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。

View File

@@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--align_to_opensource_format \
--use_gradient_checkpointing
--use_gradient_checkpointing \
--num_workers 8 \
--find_unused_parameters

View File

@@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
else:
self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)
# Reset training scheduler (do it in each training step)
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -49,7 +49,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
def forward_preprocess(self, data):
# CFG-sensitive parameters
@@ -115,4 +115,7 @@ if __name__ == "__main__":
dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
)