support num_workers, save_steps, find_unused_parameters

This commit is contained in:
mi804
2025-08-06 10:52:59 +08:00
parent 8d2f6ad32e
commit 6bae70eee0
13 changed files with 71 additions and 16 deletions


@@ -362,10 +362,10 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
 ## Update History
-- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/).
 - **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family!
+- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/).
 - **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/).
 - **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks.


@@ -378,10 +378,10 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
 ## Update History
-- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) is open-sourced: a text-to-image model focused on aesthetic photography. We provided full support right away, including low-VRAM layer-by-layer offload, LoRA training, and full training. See [./examples/flux/](./examples/flux/) for details.
 - **August 4, 2025** 🔥 Qwen-Image is open-sourced. Welcome the new member of the image generation model family!
+- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) is open-sourced: a text-to-image model focused on aesthetic photography. We provided full support right away, including low-VRAM layer-by-layer offload, LoRA training, and full training. See [./examples/flux/](./examples/flux/) for details.
 - **July 28, 2025** Wan 2.2 is open-sourced. We provided full support right away, including low-VRAM layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, and full training. See [./examples/wanvideo/](./examples/wanvideo/) for details.
 - **July 11, 2025** We propose Nexus-Gen, a unified framework that combines the language reasoning capabilities of large language models (LLMs) with the image generation capabilities of diffusion models. The framework supports seamless image understanding, generation, and editing tasks.


@@ -4,6 +4,7 @@ from PIL import Image
 import pandas as pd
 from tqdm import tqdm
 from accelerate import Accelerator
+from accelerate.utils import DistributedDataParallelKwargs
@@ -369,34 +370,42 @@ class ModelLogger:
     def on_step_end(self, loss):
         pass

-    def on_epoch_end(self, accelerator, model, epoch_id):
+    def on_model_save(self, accelerator, model, step_id=None, epoch_id=None):
         accelerator.wait_for_everyone()
         if accelerator.is_main_process:
             state_dict = accelerator.get_state_dict(model)
             state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
             state_dict = self.state_dict_converter(state_dict)
             os.makedirs(self.output_path, exist_ok=True)
-            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
+            if step_id is not None:
+                path = os.path.join(self.output_path, f"step-{step_id}.safetensors")
+            else:
+                path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
             accelerator.save(state_dict, path, safe_serialization=True)


 def launch_training_task(
     dataset: torch.utils.data.Dataset,
     model: DiffusionTrainingModule,
     model_logger: ModelLogger,
     optimizer: torch.optim.Optimizer,
     scheduler: torch.optim.lr_scheduler.LRScheduler,
+    num_workers: int = 8,
+    save_steps: int = None,
     num_epochs: int = 1,
     gradient_accumulation_steps: int = 1,
+    find_unused_parameters: bool = False,
 ):
-    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0])
-    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
+    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
+    accelerator = Accelerator(
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
+    )
     model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
     for epoch_id in range(num_epochs):
-        for data in tqdm(dataloader):
+        for step_id, data in enumerate(tqdm(dataloader)):
             with accelerator.accumulate(model):
                 optimizer.zero_grad()
                 loss = model(data)
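The step-or-epoch filename selection inside `on_model_save` above can be isolated as a minimal sketch (`checkpoint_path` is a hypothetical helper, not part of the codebase): a step-indexed name takes precedence whenever `step_id` is given.

```python
import os

def checkpoint_path(output_path, step_id=None, epoch_id=None):
    # Mirrors on_model_save: prefer a step-indexed filename over an epoch-indexed one.
    if step_id is not None:
        return os.path.join(output_path, f"step-{step_id}.safetensors")
    return os.path.join(output_path, f"epoch-{epoch_id}.safetensors")

print(checkpoint_path("./models", step_id=500))  # ./models/step-500.safetensors
print(checkpoint_path("./models", epoch_id=3))   # ./models/epoch-3.safetensors
```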
@@ -404,8 +413,16 @@ def launch_training_task(
                 optimizer.step()
                 model_logger.on_step_end(loss)
                 scheduler.step()
-        model_logger.on_epoch_end(accelerator, model, epoch_id)
+                global_steps = epoch_id * len(dataloader) + step_id + 1
+                # save every `save_steps` steps
+                if save_steps is not None and global_steps % save_steps == 0:
+                    model_logger.on_model_save(accelerator, model, step_id=global_steps)
+        # save the model at the end of each epoch if save_steps is None
+        if save_steps is None:
+            model_logger.on_model_save(accelerator, model, epoch_id=epoch_id)
+    # save the final model if save_steps is not None
+    if save_steps is not None:
+        model_logger.on_model_save(accelerator, model, step_id=global_steps)
 def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
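The save cadence introduced in the loop above can be traced with a small sketch (`save_points` is a hypothetical helper): with `save_steps` set, a checkpoint lands every `save_steps` training steps plus a final save after the last epoch, and the final save repeats when the last step already aligns with `save_steps`; with `save_steps=None`, one checkpoint per epoch.

```python
def save_points(num_epochs, steps_per_epoch, save_steps=None):
    """Return the (kind, index) save events the training loop would emit."""
    events = []
    global_steps = 0
    for epoch_id in range(num_epochs):
        for step_id in range(steps_per_epoch):
            global_steps = epoch_id * steps_per_epoch + step_id + 1
            # save every `save_steps` steps
            if save_steps is not None and global_steps % save_steps == 0:
                events.append(("step", global_steps))
        # save at the end of each epoch if save_steps is None
        if save_steps is None:
            events.append(("epoch", epoch_id))
    # save the final model if save_steps is not None
    if save_steps is not None:
        events.append(("step", global_steps))
    return events

print(save_points(2, 3, save_steps=2))  # [('step', 2), ('step', 4), ('step', 6), ('step', 6)]
print(save_points(2, 3))                # [('epoch', 0), ('epoch', 1)]
```

Because `step-{global_steps}.safetensors` is keyed by the step count, the duplicated final event simply overwrites the same file rather than producing a second checkpoint.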
@@ -446,6 +463,9 @@ def wan_parser():
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
     parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
     parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
     return parser
@@ -474,6 +494,9 @@ def flux_parser():
     parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
     parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
     return parser
@@ -503,4 +526,7 @@ def qwen_image_parser():
     parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
     parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
     parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
+    parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
+    parser.add_argument("--save_steps", type=int, default=None, help="Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.")
+    parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
     return parser


@@ -249,6 +249,7 @@ The script includes the following parameters:
 * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
 * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
 * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+* `--dataset_num_workers`: Number of workers for data loading.
 * Model
 * `--model_paths`: Paths to load models. In JSON format.
 * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
@@ -257,6 +258,8 @@ The script includes the following parameters:
 * `--num_epochs`: Number of epochs.
 * `--output_path`: Save path.
 * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+* `--save_steps`: Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.
+* `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
 * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
 * `--lora_base_model`: Which model to add LoRA to.


@@ -249,6 +249,7 @@ FLUX-series model training uses the unified [`./model_training/train.py`](./model_tra
 * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
 * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
 * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+* `--dataset_num_workers`: Number of worker processes per DataLoader.
 * Model
 * `--model_paths`: Paths to load models. In JSON format.
 * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
@@ -257,6 +258,8 @@ FLUX-series model training uses the unified [`./model_training/train.py`](./model_tra
 * `--num_epochs`: Number of epochs.
 * `--output_path`: Save path.
 * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+* `--save_steps`: Number of steps between checkpoint saves. If None, the model is saved once per epoch.
+* `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
 * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
 * `--lora_base_model`: Which model to add LoRA to.


@@ -121,4 +121,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )


@@ -218,6 +218,7 @@ The script includes the following parameters:
 * `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution.
 * `--data_file_keys`: Data file keys in metadata. Separate with commas.
 * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+* `--dataset_num_workers`: Number of workers for data loading.
 * Model
 * `--model_paths`: Model paths to load. In JSON format.
 * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -227,6 +228,8 @@ The script includes the following parameters:
 * `--num_epochs`: Number of epochs.
 * `--output_path`: Save path.
 * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+* `--save_steps`: Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.
+* `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
 * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
 * `--lora_base_model`: Which model to add LoRA to.


@@ -218,6 +218,7 @@ Qwen-Image-series model training uses the unified [`./model_training/train.py`](./mod
 * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
 * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
 * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+* `--dataset_num_workers`: Number of worker processes per DataLoader.
 * Model
 * `--model_paths`: Paths to load models. In JSON format.
 * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -227,6 +228,8 @@ Qwen-Image-series model training uses the unified [`./model_training/train.py`](./mod
 * `--num_epochs`: Number of epochs.
 * `--output_path`: Save path.
 * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+* `--save_steps`: Number of steps between checkpoint saves. If None, the model is saved once per epoch.
+* `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
 * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
 * `--lora_base_model`: Which model to add LoRA to.


@@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
     --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
     --lora_rank 32 \
     --align_to_opensource_format \
-    --use_gradient_checkpointing
+    --use_gradient_checkpointing \
+    --dataset_num_workers 8 \
+    --find_unused_parameters


@@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
             self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path))
         else:
             self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs)

         # Reset training scheduler (do it in each training step)
         self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -49,7 +49,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule):
         self.use_gradient_checkpointing = use_gradient_checkpointing
         self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
         self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []

     def forward_preprocess(self, data):
         # CFG-sensitive parameters
@@ -115,4 +115,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )


@@ -280,6 +280,7 @@ The script includes the following parameters:
 * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
 * `--data_file_keys`: Data file keys in the metadata. Comma-separated.
 * `--dataset_repeat`: Number of times to repeat the dataset per epoch.
+* `--dataset_num_workers`: Number of workers for data loading.
 * Models
 * `--model_paths`: Paths to load models. In JSON format.
 * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -290,6 +291,8 @@ The script includes the following parameters:
 * `--num_epochs`: Number of epochs.
 * `--output_path`: Output save path.
 * `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
+* `--save_steps`: Number of steps between checkpoint saves. If None, checkpoints will be saved every epoch.
+* `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
 * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
 * `--lora_base_model`: Which model LoRA is added to.


@@ -282,6 +282,7 @@ Wan-series model training uses the unified [`./model_training/train.py`](./model_trai
 * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
 * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
 * `--dataset_repeat`: Number of times the dataset repeats per epoch.
+* `--dataset_num_workers`: Number of worker processes per DataLoader.
 * Models
 * `--model_paths`: Paths to load models. In JSON format.
 * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -292,6 +293,8 @@ Wan-series model training uses the unified [`./model_training/train.py`](./model_trai
 * `--num_epochs`: Number of epochs.
 * `--output_path`: Save path.
 * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
+* `--save_steps`: Number of steps between checkpoint saves. If None, the model is saved once per epoch.
+* `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
 * `--trainable_models`: Trainable models, e.g., dit, vae, text_encoder.
 * `--lora_base_model`: Which model to add LoRA to.


@@ -127,4 +127,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )