diff --git a/README.md b/README.md index ecf5276..86fb792 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,7 @@ image.save("image.jpg") |Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)| @@ -362,10 +363,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## Update History -- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/). + +- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup. - **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family! 
+- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/). + - **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/). - **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. @@ -375,13 +379,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) - Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) +
+More + - **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide. - **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality. -
-More - - **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details. - **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details. diff --git a/README_zh.md b/README_zh.md index 0e2385e..018f049 100644 --- a/README_zh.md +++ b/README_zh.md @@ -92,6 +92,7 @@ image.save("image.jpg") |模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
@@ -378,10 +379,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 ## 更新历史 -- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。 + +- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。 - **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员! +- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。 + - **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 - **2025年7月11日** 我们提出 Nexus-Gen,一个将大语言模型(LLM)的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 @@ -391,13 +395,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44 - 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) +
+更多 + - **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。 - **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。 -
-更多 - - **2025年3月31日** 我们支持 InfiniteYou,一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。 - **2025年3月13日** 我们支持 HunyuanVideo-I2V,即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。 diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index 11b34e3..0278bb1 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -383,5 +383,20 @@ class WanLoRAConverter: return state_dict +class QwenImageLoRAConverter: + def __init__(self): + pass + + @staticmethod + def align_to_opensource_format(state_dict, **kwargs): + state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()} + return state_dict + + @staticmethod + def align_to_diffsynth_format(state_dict, **kwargs): + state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()} + return state_dict + + def get_lora_loaders(): return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()] diff --git a/diffsynth/models/wan_video_dit.py b/diffsynth/models/wan_video_dit.py index ea473b0..419f8cf 100644 --- a/diffsynth/models/wan_video_dit.py +++ b/diffsynth/models/wan_video_dit.py @@ -335,7 +335,7 @@ class WanModel(torch.nn.Module): else: self.control_adapter = None - def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None): + def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None): x = self.patch_embedding(x) if self.control_adapter is not None and control_camera_latents_input is not None: y_camera = self.control_adapter(control_camera_latents_input) diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index c478e92..65e4e50 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -4,6 +4,7 @@ from PIL import Image import 
pandas as pd from tqdm import tqdm from accelerate import Accelerator +from accelerate.utils import DistributedDataParallelKwargs @@ -364,12 +365,15 @@ class ModelLogger: self.output_path = output_path self.remove_prefix_in_ckpt = remove_prefix_in_ckpt self.state_dict_converter = state_dict_converter - - - def on_step_end(self, loss): - pass - - + self.num_steps = 0 + + + def on_step_end(self, accelerator, model, save_steps=None): + self.num_steps += 1 + if save_steps is not None and self.num_steps % save_steps == 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + def on_epoch_end(self, accelerator, model, epoch_id): accelerator.wait_for_everyone() if accelerator.is_main_process: @@ -381,6 +385,21 @@ class ModelLogger: accelerator.save(state_dict, path, safe_serialization=True) + def on_training_end(self, accelerator, model, save_steps=None): + if save_steps is not None and self.num_steps % save_steps != 0: + self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors") + + + def save_model(self, accelerator, model, file_name): + accelerator.wait_for_everyone() + if accelerator.is_main_process: + state_dict = accelerator.get_state_dict(model) + state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt) + state_dict = self.state_dict_converter(state_dict) + os.makedirs(self.output_path, exist_ok=True) + path = os.path.join(self.output_path, file_name) + accelerator.save(state_dict, path, safe_serialization=True) + def launch_training_task( dataset: torch.utils.data.Dataset, @@ -388,11 +407,17 @@ def launch_training_task( model_logger: ModelLogger, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, + num_workers: int = 8, + save_steps: int = None, num_epochs: int = 1, gradient_accumulation_steps: int = 1, + find_unused_parameters: bool = False, ): - dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, 
collate_fn=lambda x: x[0]) - accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) + dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers) + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)], + ) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) for epoch_id in range(num_epochs): @@ -402,10 +427,11 @@ launch_training_task( loss = model(data) accelerator.backward(loss) optimizer.step() - model_logger.on_step_end(loss) + model_logger.on_step_end(accelerator, model, save_steps) scheduler.step() - model_logger.on_epoch_end(accelerator, model, epoch_id) - + if save_steps is None: + model_logger.on_epoch_end(accelerator, model, epoch_id) + model_logger.on_training_end(accelerator, model, save_steps) def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): @@ -446,6 +472,9 @@ def wan_parser(): parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Checkpoint saving interval in steps. 
If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser @@ -474,6 +503,9 @@ def flux_parser(): parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Checkpoint saving interval in steps. If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser @@ -503,4 +535,7 @@ def qwen_image_parser(): parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.") + parser.add_argument("--save_steps", type=int, default=None, help="Checkpoint saving interval in steps. 
If None, checkpoints will be saved every epoch.") + parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.") return parser diff --git a/examples/flux/README.md b/examples/flux/README.md index 7361907..8137b70 100644 --- a/examples/flux/README.md +++ b/examples/flux/README.md @@ -249,6 +249,7 @@ The script includes the following parameters: * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--data_file_keys`: Data file keys in the metadata. Separate with commas. * `--dataset_repeat`: Number of times the dataset repeats per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Model * `--model_paths`: Paths to load models. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas. @@ -257,6 +258,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. + * `--save_steps`: Checkpoint saving interval in steps. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model to add LoRA to. 
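The `--save_steps` behavior documented above comes from the new bookkeeping in `ModelLogger` (see the `diffsynth/trainers/utils.py` hunk). It is easy to verify in isolation; below is a minimal, dependency-free sketch of just that logic. The `StepCheckpointer` name and the list-recording `save_model` stub are ours for illustration, while the real method gathers the state dict through Accelerate and writes a safetensors file:

```python
# Minimal sketch of the save_steps bookkeeping added to ModelLogger in this PR.
# StepCheckpointer and the list-recording save_model stub are illustrative only;
# the real ModelLogger.save_model exports the state dict and saves safetensors.

class StepCheckpointer:
    def __init__(self):
        self.num_steps = 0
        self.saved = []  # file names that would be written

    def save_model(self, file_name):
        self.saved.append(file_name)

    def on_step_end(self, save_steps=None):
        # Save every `save_steps` optimizer steps; do nothing when save_steps is None.
        self.num_steps += 1
        if save_steps is not None and self.num_steps % save_steps == 0:
            self.save_model(f"step-{self.num_steps}.safetensors")

    def on_training_end(self, save_steps=None):
        # Flush a final checkpoint only if the last step wasn't already saved.
        if save_steps is not None and self.num_steps % save_steps != 0:
            self.save_model(f"step-{self.num_steps}.safetensors")


logger = StepCheckpointer()
for _ in range(10):
    logger.on_step_end(save_steps=4)
logger.on_training_end(save_steps=4)
print(logger.saved)  # saves at steps 4 and 8, plus the final flush at step 10
```

Note that when `save_steps` is set, `launch_training_task` skips the per-epoch save entirely and relies on the final `on_training_end` flush, so the last partial interval is never lost.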
diff --git a/examples/flux/README_zh.md b/examples/flux/README_zh.md index 6bb3100..6bbd6fe 100644 --- a/examples/flux/README_zh.md +++ b/examples/flux/README_zh.md @@ -249,6 +249,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 DataLoader 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。 @@ -257,6 +258,8 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量。如果设置为 None,则每个 epoch 保存一次。 + * `--find_unused_parameters`: DDP 训练中是否查找未使用的参数。 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index ca52ff4..5ee4dff 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -121,4 +121,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) diff --git a/examples/qwen_image/README.md b/examples/qwen_image/README.md index 533b49d..c9fd4ae 100644 --- a/examples/qwen_image/README.md +++ b/examples/qwen_image/README.md @@ -43,6 +43,7 @@ image.save("image.jpg") |Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image 
)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)| ## Model Inference @@ -218,6 +219,7 @@ The script includes the following parameters: * `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--data_file_keys`: Data file keys in metadata. Separate with commas. * `--dataset_repeat`: Number of times the dataset repeats per epoch. + * `--dataset_num_workers`: Number of workers for data loading. * Model * `--model_paths`: Model paths to load. In JSON format. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas. @@ -227,6 +229,8 @@ The script includes the following parameters: * `--num_epochs`: Number of epochs. * `--output_path`: Save path. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. + * `--save_steps`: Checkpoint saving interval in steps. If None, checkpoints will be saved every epoch. + * `--find_unused_parameters`: Whether to find unused parameters in DDP. * Trainable Modules * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--lora_base_model`: Which model to add LoRA to. 
diff --git a/examples/qwen_image/README_zh.md b/examples/qwen_image/README_zh.md index bf91142..0a311c1 100644 --- a/examples/qwen_image/README_zh.md +++ b/examples/qwen_image/README_zh.md @@ -43,6 +43,7 @@ image.save("image.jpg") |模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |-|-|-|-|-|-| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)| +|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)| ## 模型推理 @@ -218,6 +219,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--width`: 图像或视频的宽度。将 `height` 和 `width` 留空以启用动态分辨率。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 + * `--dataset_num_workers`: 每个 DataLoader 的进程数量。 * 模型 * `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。 @@ -227,6 +229,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod * `--num_epochs`: 轮数(Epoch)。 * `--output_path`: 保存路径。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 + * `--save_steps`: 保存模型的间隔 step 数量。如果设置为 None,则每个 epoch 保存一次。 + * `--find_unused_parameters`: DDP 训练中是否查找未使用的参数。 * 可训练模块 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--lora_base_model`: LoRA 添加到哪个模型上。 diff --git a/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py new file mode 
100644 index 0000000..c13a417 --- /dev/null +++ b/examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py @@ -0,0 +1,17 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py new file mode 100644 index 0000000..0839dd0 --- /dev/null +++ b/examples/qwen_image/model_inference_low_vram/Qwen-Image-Distill-Full.py @@ -0,0 +1,18 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", 
origin_file_pattern="tokenizer/"), +) +pipe.enable_vram_management() +prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh b/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh new file mode 100644 index 0000000..6343754 --- /dev/null +++ b/examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh @@ -0,0 +1,12 @@ +accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-5 \ + --num_epochs 2 \ + --remove_prefix_in_ckpt "pipe.dit." 
\ + --output_path "./models/train/Qwen-Image-Distill-Full_full" \ + --trainable_models "dit" \ + --use_gradient_checkpointing diff --git a/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh b/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh new file mode 100644 index 0000000..983638d --- /dev/null +++ b/examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh @@ -0,0 +1,15 @@ +accelerate launch examples/qwen_image/model_training/train.py \ + --dataset_base_path data/example_image_dataset \ + --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --max_pixels 1048576 \ + --dataset_repeat 50 \ + --model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ + --learning_rate 1e-4 \ + --num_epochs 5 \ + --remove_prefix_in_ckpt "pipe.dit." \ + --output_path "./models/train/Qwen-Image-Distill-Full_lora" \ + --lora_base_model "dit" \ + --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ + --lora_rank 32 \ + --align_to_opensource_format \ + --use_gradient_checkpointing diff --git a/examples/qwen_image/model_training/lora/Qwen-Image.sh b/examples/qwen_image/model_training/lora/Qwen-Image.sh index 0c94391..15084c2 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image.sh @@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \ --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ --lora_rank 32 \ --align_to_opensource_format \ - --use_gradient_checkpointing + --use_gradient_checkpointing \ + --dataset_num_workers 8 \ + --find_unused_parameters diff --git a/examples/qwen_image/model_training/train.py 
b/examples/qwen_image/model_training/train.py index 48d2d1a..4e8cf48 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -1,6 +1,7 @@ import torch, os, json from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.models.lora import QwenImageLoRAConverter os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -29,7 +30,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs, tokenizer_config=ModelConfig(tokenizer_path)) else: self.pipe = QwenImagePipeline.from_pretrained(torch_dtype=torch.bfloat16, device="cpu", model_configs=model_configs) - + # Reset training scheduler (do it in each training step) self.pipe.scheduler.set_timesteps(1000, training=True) @@ -49,7 +50,7 @@ class QwenImageTrainingModule(DiffusionTrainingModule): self.use_gradient_checkpointing = use_gradient_checkpointing self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else [] - + def forward_preprocess(self, data): # CFG-sensitive parameters @@ -108,6 +109,7 @@ if __name__ == "__main__": model_logger = ModelLogger( args.output_path, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, + state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x, ) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) @@ -115,4 +117,7 @@ if __name__ == "__main__": dataset, model, model_logger, optimizer, scheduler, num_epochs=args.num_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, + save_steps=args.save_steps, + 
find_unused_parameters=args.find_unused_parameters, + num_workers=args.dataset_num_workers, ) diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py new file mode 100644 index 0000000..07389c5 --- /dev/null +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py @@ -0,0 +1,20 @@ +from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig +from diffsynth import load_state_dict +import torch + + +pipe = QwenImagePipeline.from_pretrained( + torch_dtype=torch.bfloat16, + device="cuda", + model_configs=[ + ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"), + ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"), + ], + tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), +) +state_dict = load_state_dict("models/train/Qwen-Image-Distill-Full_full/epoch-1.safetensors") +pipe.dit.load_state_dict(state_dict) +prompt = "a dog" +image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1) +image.save("image.jpg") diff --git a/examples/qwen_image/model_training/validate_full/Qwen-Image.py b/examples/qwen_image/model_training/validate_full/Qwen-Image.py index ba4d989..8723218 100644 --- a/examples/qwen_image/model_training/validate_full/Qwen-Image.py +++ b/examples/qwen_image/model_training/validate_full/Qwen-Image.py @@ -7,9 +7,9 @@ pipe = QwenImagePipeline.from_pretrained( torch_dtype=torch.bfloat16, device="cuda", model_configs=[ - ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), - ModelConfig(model_id="Qwen/Qwen-Image", 
origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
     ],
     tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
 )
diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py
new file mode 100644
index 0000000..7f644aa
--- /dev/null
+++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py
@@ -0,0 +1,18 @@
+from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
+import torch
+
+
+pipe = QwenImagePipeline.from_pretrained(
+    torch_dtype=torch.bfloat16,
+    device="cuda",
+    model_configs=[
+        ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
+)
+pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Distill-Full_lora/epoch-4.safetensors")
+prompt = "a dog"
+image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
+image.save("image.jpg")
diff --git a/examples/qwen_image/model_training/validate_lora/Qwen-Image.py b/examples/qwen_image/model_training/validate_lora/Qwen-Image.py
index 640e140..16be2b4 100644
--- a/examples/qwen_image/model_training/validate_lora/Qwen-Image.py
+++ b/examples/qwen_image/model_training/validate_lora/Qwen-Image.py
@@ -6,9 +6,9 @@ pipe = QwenImagePipeline.from_pretrained(
     torch_dtype=torch.bfloat16,
     device="cuda",
     model_configs=[
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
-        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
+        ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
     ],
     tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
 )
diff --git a/examples/wanvideo/README.md b/examples/wanvideo/README.md
index 71ff60e..8af8dca 100644
--- a/examples/wanvideo/README.md
+++ b/examples/wanvideo/README.md
@@ -280,6 +280,7 @@ The script includes the following parameters:
   * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
   * `--data_file_keys`: Data file keys in the metadata. Comma-separated.
   * `--dataset_repeat`: Number of times to repeat the dataset per epoch.
+  * `--dataset_num_workers`: Number of workers for data loading.
 * Models
   * `--model_paths`: Paths to load models. In JSON format.
   * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -290,6 +291,8 @@ The script includes the following parameters:
   * `--num_epochs`: Number of epochs.
   * `--output_path`: Output save path.
   * `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
+  * `--save_steps`: Interval (in steps) between checkpoint saves. If None, checkpoints are saved every epoch.
+  * `--find_unused_parameters`: Whether to find unused parameters in DDP.
 * Trainable Modules
   * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
   * `--lora_base_model`: Which model LoRA is added to.
diff --git a/examples/wanvideo/README_zh.md b/examples/wanvideo/README_zh.md
index 461a86f..06e81fa 100644
--- a/examples/wanvideo/README_zh.md
+++ b/examples/wanvideo/README_zh.md
@@ -282,6 +282,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
   * `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。
   * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
   * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
+  * `--dataset_num_workers`: 每个 DataLoader 的进程数量。
 * 模型
   * `--model_paths`: 要加载的模型路径。JSON 格式。
   * `--model_id_with_origin_paths`: 带原始路径的模型 ID,例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
@@ -292,6 +293,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
   * `--num_epochs`: 轮数(Epoch)。
   * `--output_path`: 保存路径。
   * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
+  * `--save_steps`: 保存模型的间隔 step 数量。如果设置为 None,则每个 epoch 保存一次。
+  * `--find_unused_parameters`: DDP 训练中是否存在未使用的参数。
 * 可训练模块
   * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
   * `--lora_base_model`: LoRA 添加到哪个模型上。
diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py
index 98c737f..1b79004 100644
--- a/examples/wanvideo/model_training/train.py
+++ b/examples/wanvideo/model_training/train.py
@@ -127,4 +127,7 @@ if __name__ == "__main__":
         dataset, model, model_logger, optimizer, scheduler,
         num_epochs=args.num_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
+        save_steps=args.save_steps,
+        find_unused_parameters=args.find_unused_parameters,
+        num_workers=args.dataset_num_workers,
     )