diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index c478e92..65e4e50 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -4,6 +4,7 @@ from PIL import Image import pandas as pd from tqdm import tqdm from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs @@ -364,12 +365,15 @@ class ModelLogger: self.output_path = output_path self.remove_prefix_in_ckpt = remove_prefix_in_ckpt self.state_dict_converter = state_dict_converter - - - def on_step_end(self, loss): - pass - - + self.num_steps = 0 + + + def on_step_end(self, accelerator, model, save_steps=None): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + def on_epoch_end(self, accelerator, model, epoch_id): accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -381,6 +385,21 @@ class ModelLogger: accelerator.save(state_dict, path, safe_serialization=True) + def on_training_end(self, accelerator, model, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def save_model(self, accelerator, model, file_name): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) + accelerator.save(state_dict, path, safe_serialization=True) + def launch_training_task( dataset: torch.utils.data.Dataset, @@ -388,11 +407,17 @@ def launch_training_task( model_logger: ModelLogger, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, + num_workers: 
int = 8, + save_steps: int = None, num_epochs: int = 1, gradient_accumulation_steps: int = 1, + find_unused_parameters: bool = False, ): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) - accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)], + ) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) for epoch_id in range(num_epochs): @@ -402,10 +427,11 @@ def launch_training_task( loss = model(data) accelerator.backward(loss) optimizer.step() - model_logger.on_step_end(loss) + model_logger.on_step_end(accelerator, model, save_steps) scheduler.step() - model_logger.on_epoch_end(accelerator, model, epoch_id) - + if save_steps is None: + model_logger.on_epoch_end(accelerator, model, epoch_id) + model_logger.on_training_end(accelerator, model, save_steps) def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): @@ -446,6 +472,9 @@ def wan_parser(): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving intervals. 
If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser @@ -474,6 +503,9 @@ def flux_parser(): parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving intervals. If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser @@ -503,4 +535,7 @@ def qwen_image_parser(): parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving intervals. 
If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser diff --git a/examples/flux/README.md b/examples/flux/README.md index 7361907..8137b70 100644 --- a/examples/flux/README.md +++ b/examples/flux/README.md @@ -249,6 +249,7 @@ The script includes the following parameters: * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--data_file_keys`: Data file keys in the metadata. Separate with commas. * `--dataset_repeat`: Number of times the dataset repeats per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Model * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas. @@ -257,6 +258,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. + * `--save_steps`: Number of checkpoint saving intervals. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model to add LoRA to. 
diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md index 6bb3100..6bbd6fe 100644 --- a/examples/flux/README_zh.md +++ b/examples/flux/README_zh.md @@ -249,6 +249,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloader 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。 @@ -257,6 +258,8 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index ca52ff4..5ee4dff 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -121,4 +121,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index 14b5a16..c9fd4ae 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -219,6 +219,7 @@ The script includes the following parameters: * `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--data_file_keys`: Data file keys in metadata. Separate with commas. * `--dataset_repeat`: Number of times the dataset repeats per epoch. 
+ * `--dataset_num_workers`: Number of workers for data loading. * Model * `--model_paths`: Model paths to load. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas. @@ -228,6 +229,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. + * `--save_steps`: Number of checkpoint saving intervals. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model to add LoRA to. diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index 84f69d0..0a311c1 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -219,6 +219,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloader 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。 @@ -228,6 +229,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/qwen_image/model_training/lora/Qwen-Image.sh b/examples/qwen_image/model_training/lora/Qwen-Image.sh index 0c94391..15084c2 100644 --- 
a/examples/qwen_image/model_training/lora/Qwen-Image.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image.sh @@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \ --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ --lora_rank 32 \ --align_to_opensource_format \ - --use_gradient_checkpointing + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 17b234a..4e8cf48 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -30,7 +30,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path)) else: self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) - + # Reset training scheduler (do it in each training step) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -50,7 +50,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] - + def forward_preprocess(self, data): # CFG-sensitive parameters @@ -117,4 +117,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 71ff60e..8af8dca 100644 --- 
a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -280,6 +280,7 @@ The script includes the following parameters: * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix. * `--data_file_keys`: Data file keys in the metadata. Comma-separated. * `--dataset_repeat`: Number of times to repeat the dataset per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Models * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated. @@ -290,6 +291,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Output save path. * `--remove_prefix_in_ckpt`: Remove prefix in ckpt. + * `--save_steps`: Number of checkpoint saving intervals. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model LoRA is added to. 
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 461a86f..06e81fa 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -282,6 +282,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloader 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。 @@ -292,6 +293,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 98c737f..1b79004 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -127,4 +127,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, )