From 6bae70eee0083de62c857e1bd1aa20e80e667a02 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 6 Aug 2025 10:52:59 +0800 Subject: [PATCH 1/4] support num_workers,save_steps,find_unused_parameters --- README.md | 4 +- README_zh.md | 4 +- diffsynth/trainers/utils.py | 44 +++++++++++++++---- examples/flux/README.md | 3 ++ examples/flux/README_zh.md | 3 ++ examples/flux/model_training/train.py | 3 ++ examples/qwen_image/README.md | 3 ++ examples/qwen_image/README_zh.md | 3 ++ .../model_training/lora/Qwen-Image.sh | 4 +- examples/qwen_image/model_training/train.py | 7 ++- examples/wanvideo/README.md | 3 ++ examples/wanvideo/README_zh.md | 3 ++ examples/wanvideo/model_training/train.py | 3 ++ 13 files changed, 71 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index ecf5276..67ca862 100644 --- a/README.md +++ b/README.md @@ -362,10 +362,10 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History -- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/). - - **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family! +- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/). + - **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/). - **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. diff --git a/README_zh.md b/README_zh.md index 0e2385e..feb9759 100644 --- a/README_zh.md +++ b/README_zh.md @@ -378,10 +378,10 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 -- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。 - - **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员! +- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。 + - **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 - **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index c478e92..e0c20b9 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -4,6 +4,7 @@ from PIL import Image import pandas as pd from tqdm import tqdm from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs @@ -369,34 +370,42 @@ class ModelLogger: def on_step_end(self, loss): pass - - def on_epoch_end(self, accelerator, model, epoch_id): + + def on_model_save(self, accelerator, model, step_id=None, epoch_id=None): accelerator.wait_for_everyone() if accelerator.is_main_process: state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) - path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + if step_id is not None: + path = os.path.join(self.output_path, f"step-{step_id}.safetensors") + else: + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") accelerator.save(state_dict, path, safe_serialization=True) - def launch_training_task( dataset: torch.utils.data.Dataset, model: DiffusionTrainingModule, model_logger: ModelLogger, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, + num_workers: int = 8, + save_steps: int = None, num_epochs: int = 1, gradient_accumulation_steps: int = 1, + find_unused_parameters: bool = False, ): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) - accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)], + ) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) for epoch_id in range(num_epochs): - for data in tqdm(dataloader): + for step_id, data in enumerate(tqdm(dataloader)): with accelerator.accumulate(model): optimizer.zero_grad() loss = model(data) @@ -404,8 +413,16 @@ def launch_training_task( optimizer.step() model_logger.on_step_end(loss) scheduler.step() - model_logger.on_epoch_end(accelerator, model, epoch_id) - + global_steps = epoch_id * len(dataloader) + step_id + 1 + # save every `save_steps` steps + if save_steps is not None and global_steps % save_steps == 0: + model_logger.on_model_save(accelerator, model, step_id=global_steps) + # save the model at the end of each epoch if save_steps is None + if save_steps is None: + model_logger.on_model_save(accelerator, model, epoch_id=epoch_id) + # save the final model if save_steps is not None + if save_steps is not None: + model_logger.on_model_save(accelerator, model, step_id=global_steps) def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): @@ -446,6 +463,9 @@ def wan_parser(): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser @@ -474,6 +494,9 @@ def flux_parser(): parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser @@ -503,4 +526,7 @@ def qwen_image_parser(): parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser diff --git a/examples/flux/README.md b/examples/flux/README.md index 7361907..8137b70 100644 --- a/examples/flux/README.md +++ b/examples/flux/README.md @@ -249,6 +249,7 @@ The script includes the following parameters: * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--data_file_keys`: Data file keys in the metadata. Separate with commas. * `--dataset_repeat`: Number of times the dataset repeats per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Model * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas. @@ -257,6 +258,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. + * `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model to add LoRA to. diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md index 6bb3100..6bbd6fe 100644 --- a/examples/flux/README_zh.md +++ b/examples/flux/README_zh.md @@ -249,6 +249,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。 @@ -257,6 +258,8 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index ca52ff4..5ee4dff 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -121,4 +121,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index 533b49d..3a4d0cd 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -218,6 +218,7 @@ The script includes the following parameters: * `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--data_file_keys`: Data file keys in metadata. Separate with commas. * `--dataset_repeat`: Number of times the dataset repeats per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Model * `--model_paths`: Model paths to load. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas. @@ -227,6 +228,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. + * `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model to add LoRA to. diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index bf91142..77d19c4 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -218,6 +218,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。 @@ -227,6 +228,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/qwen_image/model_training/lora/Qwen-Image.sh b/examples/qwen_image/model_training/lora/Qwen-Image.sh index 0c94391..7359f8c 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image.sh @@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \ --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ --lora_rank 32 \ --align_to_opensource_format \ - --use_gradient_checkpointing + --use_gradient_checkpointing \ + --num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 48d2d1a..0f82e26 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -29,7 +29,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path)) else: self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) - + # Reset training scheduler (do it in each training step) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -49,7 +49,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] - + def forward_preprocess(self, data): # CFG-sensitive parameters @@ -115,4 +115,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md index 71ff60e..8af8dca 100644 --- a/examples/wanvideo/README.md +++ b/examples/wanvideo/README.md @@ -280,6 +280,7 @@ The script includes the following parameters: * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix. * `--data_file_keys`: Data file keys in the metadata. Comma-separated. * `--dataset_repeat`: Number of times to repeat the dataset per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Models * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated. @@ -290,6 +291,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Output save path. * `--remove_prefix_in_ckpt`: Remove prefix in ckpt. + * `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model LoRA is added to. diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md index 461a86f..06e81fa 100644 --- a/examples/wanvideo/README_zh.md +++ b/examples/wanvideo/README_zh.md @@ -282,6 +282,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 Dataloder 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。 @@ -292,6 +293,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次 + * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 98c737f..1b79004 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -127,4 +127,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) From 4299c999b55a1c6d803a89104770f0a01c840006 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 6 Aug 2025 10:56:46 +0800 Subject: [PATCH 2/4] restore readme --- README.md | 4 ++-- README_zh.md | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 67ca862..ecf5276 100644 --- a/README.md +++ b/README.md @@ -362,10 +362,10 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History -- **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family! - - **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/). +- **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family! + - **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/). - **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. diff --git a/README_zh.md b/README_zh.md index feb9759..0e2385e 100644 --- a/README_zh.md +++ b/README_zh.md @@ -378,10 +378,10 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 -- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员! - - **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。 +- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员! + - **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 - **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 From 3915bc3ee6c81c580a22cec8f8dfcac72462cc40 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 6 Aug 2025 10:58:53 +0800 Subject: [PATCH 3/4] minor fix --- diffsynth/trainers/utils.py | 1 + examples/qwen_image/model_training/lora/Qwen-Image.sh | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index e0c20b9..fff84d5 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -404,6 +404,7 @@ def launch_training_task( ) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) + global_steps = 0 for epoch_id in range(num_epochs): for step_id, data in enumerate(tqdm(dataloader)): with accelerator.accumulate(model): diff --git a/examples/qwen_image/model_training/lora/Qwen-Image.sh b/examples/qwen_image/model_training/lora/Qwen-Image.sh index 7359f8c..15084c2 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image.sh @@ -13,5 +13,5 @@ accelerate launch examples/qwen_image/model_training/train.py \ --lora_rank 32 \ --align_to_opensource_format \ --use_gradient_checkpointing \ - --num_workers 8 \ + --dataset_num_workers 8 \ --find_unused_parameters From ef09db69cd471b614a3e6e6cba247a80c9cb3c39 Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Wed, 6 Aug 2025 15:47:35 +0800 Subject: [PATCH 4/4] refactor model_logger --- diffsynth/trainers/utils.py | 52 +++++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index fff84d5..65e4e50 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -365,23 +365,39 @@ class ModelLogger: self.output_path = output_path self.remove_prefix_in_ckpt = remove_prefix_in_ckpt self.state_dict_converter = state_dict_converter - - - def on_step_end(self, loss): - pass - + self.num_steps = 0 - def on_model_save(self, accelerator, model, step_id=None, epoch_id=None): + + def on_step_end(self, accelerator, model, save_steps=None): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def on_epoch_end(self, accelerator, model, epoch_id): accelerator.wait_for_everyone() if accelerator.is_main_process: state_dict = accelerator.get_state_dict(model) state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) state_dict = self.state_dict_converter(state_dict) os.makedirs(self.output_path, exist_ok=True) - if step_id is not None: - path = os.path.join(self.output_path, f"step-{step_id}.safetensors") - else: - path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors") + accelerator.save(state_dict, path, safe_serialization=True) + + + def on_training_end(self, accelerator, model, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def save_model(self, accelerator, model, file_name): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) accelerator.save(state_dict, path, safe_serialization=True) @@ -404,26 +420,18 @@ def launch_training_task( ) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) - global_steps = 0 for epoch_id in range(num_epochs): - for step_id, data in enumerate(tqdm(dataloader)): + for data in tqdm(dataloader): with accelerator.accumulate(model): optimizer.zero_grad() loss = model(data) accelerator.backward(loss) optimizer.step() - model_logger.on_step_end(loss) + model_logger.on_step_end(accelerator, model, save_steps) scheduler.step() - global_steps = epoch_id * len(dataloader) + step_id + 1 - # save every `save_steps` steps - if save_steps is not None and global_steps % save_steps == 0: - model_logger.on_model_save(accelerator, model, step_id=global_steps) - # save the model at the end of each epoch if save_steps is None if save_steps is None: - model_logger.on_model_save(accelerator, model, epoch_id=epoch_id) - # save the final model if save_steps is not None - if save_steps is not None: - model_logger.on_model_save(accelerator, model, step_id=global_steps) + model_logger.on_epoch_end(accelerator, model, epoch_id) + model_logger.on_training_end(accelerator, model, save_steps) def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):