minor fix

This commit is contained in:
mi804
2025-08-14 13:59:04 +08:00
parent 49f9a11eb3
commit 3212c83398
2 changed files with 1 additions and 2 deletions

View File

@@ -82,7 +82,6 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
# Extra inputs
controlnet_input = {}
for extra_input in self.extra_inputs:
inputs_shared[extra_input] = data[extra_input]
if extra_input.startswith("blockwise_controlnet_"):
controlnet_input[extra_input.replace("blockwise_controlnet_", "")] = data[extra_input]
elif extra_input.startswith("controlnet_"):

View File

@@ -296,12 +296,12 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径LoRA 将从此检查点加载。
* 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。
* `--lora_target_modules`: LoRA 添加到哪一层上。
* `--lora_rank`: LoRA 的秩Rank
* `--lora_checkpoint`: LoRA 检查点的路径。如果提供此路径LoRA 将从此检查点加载。
* 额外模型输入
* `--extra_inputs`: 额外的模型输入,以逗号分隔。
* 显存管理