Merge branch 'main' into qwen-image-eligen

This commit is contained in:
mi804
2025-08-06 15:57:13 +08:00
23 changed files with 220 additions and 29 deletions

View File

@@ -90,6 +90,7 @@ image.save("image.jpg")
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-| |-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
</details> </details>
@@ -362,10 +363,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## Update History ## Update History
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/).
- **August 5, 2025** We open-sourced the distilled acceleration model of Qwen-Image, [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full), achieving approximately 5x speedup.
- **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family! - **August 4, 2025** 🔥 Qwen-Image is now open source. Welcome the new member to the image generation model family!
- **August 1, 2025** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) with a focus on aesthetic photography is comprehensively supported, including low-GPU-memory layer-by-layer offload, LoRA training and full training. See [./examples/flux/](./examples/flux/).
- **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/). - **July 28, 2025** With the open-sourcing of Wan 2.2, we immediately provided comprehensive support, including low-GPU-memory layer-by-layer offload, FP8 quantization, sequence parallelism, LoRA training, full training. See [./examples/wanvideo/](./examples/wanvideo/).
- **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks. - **July 11, 2025** We propose Nexus-Gen, a unified model that synergizes the language reasoning capabilities of LLMs with the image synthesis power of diffusion models. This framework enables seamless image understanding, generation, and editing tasks.
@@ -375,13 +379,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
- Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) - Training Dataset: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) - Online Demo: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
<details>
<summary>More</summary>
- **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide. - **June 15, 2025** ModelScope's official evaluation framework, [EvalScope](https://github.com/modelscope/evalscope), now supports text-to-image generation evaluation. Try it with the [Best Practices](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html) guide.
- **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality. - **March 25, 2025** Our new open-source project, [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine), is now open-sourced! Focused on stable model deployment. Geared towards industry. Offers better engineering support, higher computational performance, and more stable functionality.
<details>
<summary>More</summary>
- **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details. - **March 31, 2025** We support InfiniteYou, an identity preserving method for FLUX. Please refer to [./examples/InfiniteYou/](./examples/InfiniteYou/) for more details.
- **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details. - **March 13, 2025** We support HunyuanVideo-I2V, the image-to-video generation version of HunyuanVideo open-sourced by Tencent. Please refer to [./examples/HunyuanVideo/](./examples/HunyuanVideo/) for more details.

View File

@@ -92,6 +92,7 @@ image.save("image.jpg")
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-| |-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./examples/qwen_image/model_inference/Qwen-Image.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./examples/qwen_image/model_inference/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/full/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./examples/qwen_image/model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./examples/qwen_image/model_training/validate_lora/Qwen-Image-Distill-Full.py)|
</details> </details>
@@ -378,10 +379,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
## 更新历史 ## 更新历史
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
- **2025年8月5日** 我们开源了 Qwen-Image 的蒸馏加速模型 [DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full),实现了约 5 倍加速。
- **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员! - **2025年8月4日** 🔥 Qwen-Image 开源,欢迎图像生成模型家族新成员!
- **2025年8月1日** [FLUX.1-Krea-dev](https://www.modelscope.cn/models/black-forest-labs/FLUX.1-Krea-dev) 开源,这是一个专注于美学摄影的文生图模型。我们第一时间提供了全方位支持,包括低显存逐层 offload、LoRA 训练、全量训练。详细信息请参考 [./examples/flux/](./examples/flux/)。
- **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。 - **2025年7月28日** Wan 2.2 开源,我们第一时间提供了全方位支持,包括低显存逐层 offload、FP8 量化、序列并行、LoRA 训练、全量训练。详细信息请参考 [./examples/wanvideo/](./examples/wanvideo/)。
- **2025年7月11日** 我们提出 Nexus-Gen一个将大语言模型LLM的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。 - **2025年7月11日** 我们提出 Nexus-Gen一个将大语言模型LLM的语言推理能力与扩散模型的图像生成能力相结合的统一框架。该框架支持无缝的图像理解、生成和编辑任务。
@@ -391,13 +395,13 @@ https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/59fb2f7b-8de0-44
- 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset) - 训练数据集: [ModelScope Dataset](https://www.modelscope.cn/datasets/DiffSynth-Studio/Nexus-Gen-Training-Dataset)
- 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen) - 在线体验: [ModelScope Nexus-Gen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/Nexus-Gen)
<details>
<summary>更多</summary>
- **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。 - **2025年6月15日** ModelScope 官方评测框架 [EvalScope](https://github.com/modelscope/evalscope) 现已支持文生图生成评测。请参考[最佳实践](https://evalscope.readthedocs.io/zh-cn/latest/best_practice/t2i_eval.html)指南进行尝试。
- **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。 - **2025年3月25日** 我们的新开源项目 [DiffSynth-Engine](https://github.com/modelscope/DiffSynth-Engine) 现已开源!专注于稳定的模型部署,面向工业界,提供更好的工程支持、更高的计算性能和更稳定的功能。
<details>
<summary>更多</summary>
- **2025年3月31日** 我们支持 InfiniteYou一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。 - **2025年3月31日** 我们支持 InfiniteYou一种用于 FLUX 的人脸特征保留方法。更多细节请参考 [./examples/InfiniteYou/](./examples/InfiniteYou/)。
- **2025年3月13日** 我们支持 HunyuanVideo-I2V即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。 - **2025年3月13日** 我们支持 HunyuanVideo-I2V即腾讯开源的 HunyuanVideo 的图像到视频生成版本。更多细节请参考 [./examples/HunyuanVideo/](./examples/HunyuanVideo/)。

View File

@@ -383,5 +383,20 @@ class WanLoRAConverter:
return state_dict return state_dict
class QwenImageLoRAConverter:
def __init__(self):
pass
@staticmethod
def align_to_opensource_format(state_dict, **kwargs):
state_dict = {name.replace(".default.", "."): param for name, param in state_dict.items()}
return state_dict
@staticmethod
def align_to_diffsynth_format(state_dict, **kwargs):
state_dict = {name.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight"): param for name, param in state_dict.items()}
return state_dict
def get_lora_loaders(): def get_lora_loaders():
return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()] return [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), FluxLoRAFromCivitai(), HunyuanVideoLoRAFromCivitai(), GeneralLoRAFromPeft()]

View File

@@ -335,7 +335,7 @@ class WanModel(torch.nn.Module):
else: else:
self.control_adapter = None self.control_adapter = None
def patchify(self, x: torch.Tensor,control_camera_latents_input: torch.Tensor = None): def patchify(self, x: torch.Tensor, control_camera_latents_input: Optional[torch.Tensor] = None):
x = self.patch_embedding(x) x = self.patch_embedding(x)
if self.control_adapter is not None and control_camera_latents_input is not None: if self.control_adapter is not None and control_camera_latents_input is not None:
y_camera = self.control_adapter(control_camera_latents_input) y_camera = self.control_adapter(control_camera_latents_input)

View File

@@ -4,6 +4,7 @@ from PIL import Image
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs
@@ -364,10 +365,13 @@ class ModelLogger:
self.output_path = output_path self.output_path = output_path
self.remove_prefix_in_ckpt = remove_prefix_in_ckpt self.remove_prefix_in_ckpt = remove_prefix_in_ckpt
self.state_dict_converter = state_dict_converter self.state_dict_converter = state_dict_converter
self.num_steps = 0
def on_step_end(self, loss): def on_step_end(self, accelerator, model, save_steps=None):
pass self.num_steps += 1
if save_steps is not None and self.num_steps % save_steps == 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def on_epoch_end(self, accelerator, model, epoch_id): def on_epoch_end(self, accelerator, model, epoch_id):
@@ -381,6 +385,21 @@ class ModelLogger:
accelerator.save(state_dict, path, safe_serialization=True) accelerator.save(state_dict, path, safe_serialization=True)
def on_training_end(self, accelerator, model, save_steps=None):
if save_steps is not None and self.num_steps % save_steps != 0:
self.save_model(accelerator, model, f"step-{self.num_steps}.safetensors")
def save_model(self, accelerator, model, file_name):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
state_dict = accelerator.get_state_dict(model)
state_dict = accelerator.unwrap_model(model).export_trainable_state_dict(state_dict, remove_prefix=self.remove_prefix_in_ckpt)
state_dict = self.state_dict_converter(state_dict)
os.makedirs(self.output_path, exist_ok=True)
path = os.path.join(self.output_path, file_name)
accelerator.save(state_dict, path, safe_serialization=True)
def launch_training_task( def launch_training_task(
dataset: torch.utils.data.Dataset, dataset: torch.utils.data.Dataset,
@@ -388,11 +407,17 @@ def launch_training_task(
model_logger: ModelLogger, model_logger: ModelLogger,
optimizer: torch.optim.Optimizer, optimizer: torch.optim.Optimizer,
scheduler: torch.optim.lr_scheduler.LRScheduler, scheduler: torch.optim.lr_scheduler.LRScheduler,
num_workers: int = 8,
save_steps: int = None,
num_epochs: int = 1, num_epochs: int = 1,
gradient_accumulation_steps: int = 1, gradient_accumulation_steps: int = 1,
find_unused_parameters: bool = False,
): ):
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0]) dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, collate_fn=lambda x: x[0], num_workers=num_workers)
accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps) accelerator = Accelerator(
gradient_accumulation_steps=gradient_accumulation_steps,
kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=find_unused_parameters)],
)
model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler) model, optimizer, dataloader, scheduler = accelerator.prepare(model, optimizer, dataloader, scheduler)
for epoch_id in range(num_epochs): for epoch_id in range(num_epochs):
@@ -402,10 +427,11 @@ def launch_training_task(
loss = model(data) loss = model(data)
accelerator.backward(loss) accelerator.backward(loss)
optimizer.step() optimizer.step()
model_logger.on_step_end(loss) model_logger.on_step_end(accelerator, model, save_steps)
scheduler.step() scheduler.step()
model_logger.on_epoch_end(accelerator, model, epoch_id) if save_steps is None:
model_logger.on_epoch_end(accelerator, model, epoch_id)
model_logger.on_training_end(accelerator, model, save_steps)
def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"): def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_path="./models"):
@@ -446,6 +472,9 @@ def wan_parser():
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser return parser
@@ -474,6 +503,9 @@ def flux_parser():
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser return parser
@@ -503,4 +535,7 @@ def qwen_image_parser():
parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.")
parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.")
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.")
parser.add_argument("--find_unused_parameters", default=False, action="store_true", help="Whether to find unused parameters in DDP.")
parser.add_argument("--save_steps", type=int, default=None, help="Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.")
parser.add_argument("--dataset_num_workers", type=int, default=0, help="Number of workers for data loading.")
return parser return parser

View File

@@ -249,6 +249,7 @@ The script includes the following parameters:
* `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--width`: Width of the image or video. Leave `height` and `width` empty to enable dynamic resolution.
* `--data_file_keys`: Data file keys in the metadata. Separate with commas. * `--data_file_keys`: Data file keys in the metadata. Separate with commas.
* `--dataset_repeat`: Number of times the dataset repeats per epoch. * `--dataset_repeat`: Number of times the dataset repeats per epoch.
* `--dataset_num_workers`: Number of workers for data loading.
* Model * Model
* `--model_paths`: Paths to load models. In JSON format. * `--model_paths`: Paths to load models. In JSON format.
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., black-forest-labs/FLUX.1-dev:flux1-dev.safetensors. Separate with commas.
@@ -257,6 +258,8 @@ The script includes the following parameters:
* `--num_epochs`: Number of epochs. * `--num_epochs`: Number of epochs.
* `--output_path`: Save path. * `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules * Trainable Modules
* `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder. * `--trainable_models`: Models that can be trained, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model to add LoRA to. * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -249,6 +249,7 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
* `--width`: 图像或视频的宽度。将 `height``width` 留空以启用动态分辨率。 * `--width`: 图像或视频的宽度。将 `height``width` 留空以启用动态分辨率。
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
* 模型 * 模型
* `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 black-forest-labs/FLUX.1-dev:flux1-dev.safetensors。用逗号分隔。
@@ -257,6 +258,8 @@ FLUX 系列模型训练通过统一的 [`./model_training/train.py`](./model_tra
* `--num_epochs`: 轮数Epoch * `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。 * `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
* 可训练模块 * 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。 * `--lora_base_model`: LoRA 添加到哪个模型上。

View File

@@ -121,4 +121,7 @@ if __name__ == "__main__":
dataset, model, model_logger, optimizer, scheduler, dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs, num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
) )

View File

@@ -43,6 +43,7 @@ image.save("image.jpg")
|Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training| |Model ID|Inference|Full Training|Validation after Full Training|LoRA Training|Validation after LoRA Training|
|-|-|-|-|-|-| |-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image )|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image )|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
## Model Inference ## Model Inference
@@ -218,6 +219,7 @@ The script includes the following parameters:
* `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution. * `--width`: Width of image or video. Leave `height` and `width` empty to enable dynamic resolution.
* `--data_file_keys`: Data file keys in metadata. Separate with commas. * `--data_file_keys`: Data file keys in metadata. Separate with commas.
* `--dataset_repeat`: Number of times the dataset repeats per epoch. * `--dataset_repeat`: Number of times the dataset repeats per epoch.
* `--dataset_num_workers`: Number of workers for data loading.
* Model * Model
* `--model_paths`: Model paths to load. In JSON format. * `--model_paths`: Model paths to load. In JSON format.
* `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas. * `--model_id_with_origin_paths`: Model ID with original paths, e.g., Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors. Separate with commas.
@@ -227,6 +229,8 @@ The script includes the following parameters:
* `--num_epochs`: Number of epochs. * `--num_epochs`: Number of epochs.
* `--output_path`: Save path. * `--output_path`: Save path.
* `--remove_prefix_in_ckpt`: Remove prefix in checkpoint. * `--remove_prefix_in_ckpt`: Remove prefix in checkpoint.
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules * Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model to add LoRA to. * `--lora_base_model`: Which model to add LoRA to.

View File

@@ -43,6 +43,7 @@ image.save("image.jpg")
|模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证| |模型 ID|推理|全量训练|全量训练后验证|LoRA 训练|LoRA 训练后验证|
|-|-|-|-|-|-| |-|-|-|-|-|-|
|[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)| |[Qwen/Qwen-Image](https://www.modelscope.cn/models/Qwen/Qwen-Image)|[code](./model_inference/Qwen-Image.py)|[code](./model_training/full/Qwen-Image.sh)|[code](./model_training/validate_full/Qwen-Image.py)|[code](./model_training/lora/Qwen-Image.sh)|[code](./model_training/validate_lora/Qwen-Image.py)|
|[DiffSynth-Studio/Qwen-Image-Distill-Full](https://www.modelscope.cn/models/DiffSynth-Studio/Qwen-Image-Distill-Full)|[code](./model_inference/Qwen-Image-Distill-Full.py)|[code](./model_training/full/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_full/Qwen-Image-Distill-Full.py)|[code](./model_training/lora/Qwen-Image-Distill-Full.sh)|[code](./model_training/validate_lora/Qwen-Image-Distill-Full.py)|
## 模型推理 ## 模型推理
@@ -218,6 +219,7 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--width`: 图像或视频的宽度。将 `height``width` 留空以启用动态分辨率。 * `--width`: 图像或视频的宽度。将 `height``width` 留空以启用动态分辨率。
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
* 模型 * 模型
* `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors。用逗号分隔。
@@ -227,6 +229,8 @@ Qwen-Image 系列模型训练通过统一的 [`./model_training/train.py`](./mod
* `--num_epochs`: 轮数Epoch * `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。 * `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
* 可训练模块 * 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。 * `--lora_base_model`: LoRA 添加到哪个模型上。

View File

@@ -0,0 +1,17 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.enable_vram_management()
prompt = "精致肖像,水下少女,蓝裙飘逸,发丝轻扬,光影透澈,气泡环绕,面容恬静,细节精致,梦幻唯美。"
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")

View File

@@ -0,0 +1,12 @@
accelerate launch --config_file examples/qwen_image/model_training/full/accelerate_config_zero2offload.yaml examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-5 \
--num_epochs 2 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Distill-Full_full" \
--trainable_models "dit" \
--use_gradient_checkpointing

View File

@@ -0,0 +1,15 @@
accelerate launch examples/qwen_image/model_training/train.py \
--dataset_base_path data/example_image_dataset \
--dataset_metadata_path data/example_image_dataset/metadata.csv \
--max_pixels 1048576 \
--dataset_repeat 50 \
--model_id_with_origin_paths "DiffSynth-Studio/Qwen-Image-Distill-Full:diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \
--learning_rate 1e-4 \
--num_epochs 5 \
--remove_prefix_in_ckpt "pipe.dit." \
--output_path "./models/train/Qwen-Image-Distill-Full_lora" \
--lora_base_model "dit" \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \
--align_to_opensource_format \
--use_gradient_checkpointing

View File

@@ -12,4 +12,6 @@ accelerate launch examples/qwen_image/model_training/train.py \
--lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \ --lora_target_modules "to_q,to_k,to_v,add_q_proj,add_k_proj,add_v_proj,to_out.0,to_add_out,img_mlp.net.2,img_mod.1,txt_mlp.net.2,txt_mod.1" \
--lora_rank 32 \ --lora_rank 32 \
--align_to_opensource_format \ --align_to_opensource_format \
--use_gradient_checkpointing --use_gradient_checkpointing \
--dataset_num_workers 8 \
--find_unused_parameters

View File

@@ -1,6 +1,7 @@
import torch, os, json import torch, os, json
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser
from diffsynth.models.lora import QwenImageLoRAConverter
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -108,6 +109,7 @@ if __name__ == "__main__":
model_logger = ModelLogger( model_logger = ModelLogger(
args.output_path, args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt, remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
state_dict_converter=QwenImageLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else lambda x:x,
) )
optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate) optimizer = torch.optim.AdamW(model.trainable_modules(), lr=args.learning_rate)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
@@ -115,4 +117,7 @@ if __name__ == "__main__":
dataset, model, model_logger, optimizer, scheduler, dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs, num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
) )

View File

@@ -0,0 +1,20 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
from diffsynth import load_state_dict
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
state_dict = load_state_dict("models/train/Qwen-Image-Distill-Full_full/epoch-1.safetensors")
pipe.dit.load_state_dict(state_dict)
prompt = "a dog"
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")

View File

@@ -7,9 +7,9 @@ pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device="cuda", device="cuda",
model_configs=[ model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
], ],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
) )

View File

@@ -0,0 +1,18 @@
from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig
import torch
pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16,
device="cuda",
model_configs=[
ModelConfig(model_id="DiffSynth-Studio/Qwen-Image-Distill-Full", origin_file_pattern="diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
)
pipe.load_lora(pipe.dit, "models/train/Qwen-Image-Distill-Full_lora/epoch-4.safetensors")
prompt = "a dog"
image = pipe(prompt, seed=0, num_inference_steps=15, cfg_scale=1)
image.save("image.jpg")

View File

@@ -6,9 +6,9 @@ pipe = QwenImagePipeline.from_pretrained(
torch_dtype=torch.bfloat16, torch_dtype=torch.bfloat16,
device="cuda", device="cuda",
model_configs=[ model_configs=[
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="transformer/diffusion_pytorch_model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="text_encoder/model*.safetensors"),
ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors", offload_device="cpu", offload_dtype=torch.float8_e4m3fn), ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
], ],
tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"), tokenizer_config=ModelConfig(model_id="Qwen/Qwen-Image", origin_file_pattern="tokenizer/"),
) )

View File

@@ -280,6 +280,7 @@ The script includes the following parameters:
* `--num_frames`: Number of frames per video. Frames are sampled from the video prefix. * `--num_frames`: Number of frames per video. Frames are sampled from the video prefix.
* `--data_file_keys`: Data file keys in the metadata. Comma-separated. * `--data_file_keys`: Data file keys in the metadata. Comma-separated.
* `--dataset_repeat`: Number of times to repeat the dataset per epoch. * `--dataset_repeat`: Number of times to repeat the dataset per epoch.
* `--dataset_num_workers`: Number of workers for data loading.
* Models * Models
* `--model_paths`: Paths to load models. In JSON format. * `--model_paths`: Paths to load models. In JSON format.
* `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated. * `--model_id_with_origin_paths`: Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.
@@ -290,6 +291,8 @@ The script includes the following parameters:
* `--num_epochs`: Number of epochs. * `--num_epochs`: Number of epochs.
* `--output_path`: Output save path. * `--output_path`: Output save path.
* `--remove_prefix_in_ckpt`: Remove prefix in ckpt. * `--remove_prefix_in_ckpt`: Remove prefix in ckpt.
* `--save_steps`: Number of checkpoint saving invervals. If None, checkpoints will be saved every epoch.
* `--find_unused_parameters`: Whether to find unused parameters in DDP.
* Trainable Modules * Trainable Modules
* `--trainable_models`: Models to train, e.g., dit, vae, text_encoder. * `--trainable_models`: Models to train, e.g., dit, vae, text_encoder.
* `--lora_base_model`: Which model LoRA is added to. * `--lora_base_model`: Which model LoRA is added to.

View File

@@ -282,6 +282,7 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。 * `--num_frames`: 每个视频中的帧数。帧从视频前缀中采样。
* `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。 * `--data_file_keys`: 元数据中的数据文件键。用逗号分隔。
* `--dataset_repeat`: 每个 epoch 中数据集重复的次数。 * `--dataset_repeat`: 每个 epoch 中数据集重复的次数。
* `--dataset_num_workers`: 每个 Dataloder 的进程数量。
* 模型 * 模型
* `--model_paths`: 要加载的模型路径。JSON 格式。 * `--model_paths`: 要加载的模型路径。JSON 格式。
* `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。 * `--model_id_with_origin_paths`: 带原始路径的模型 ID例如 Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors。用逗号分隔。
@@ -292,6 +293,8 @@ Wan 系列模型训练通过统一的 [`./model_training/train.py`](./model_trai
* `--num_epochs`: 轮数Epoch * `--num_epochs`: 轮数Epoch
* `--output_path`: 保存路径。 * `--output_path`: 保存路径。
* `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。 * `--remove_prefix_in_ckpt`: 在 ckpt 中移除前缀。
* `--save_steps`: 保存模型的间隔 step 数量,如果设置为 None ,则每个 epoch 保存一次
* `--find_unused_parameters`: DDP 训练中是否存在未使用的参数
* 可训练模块 * 可训练模块
* `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。 * `--trainable_models`: 可训练的模型,例如 dit、vae、text_encoder。
* `--lora_base_model`: LoRA 添加到哪个模型上。 * `--lora_base_model`: LoRA 添加到哪个模型上。

View File

@@ -127,4 +127,7 @@ if __name__ == "__main__":
dataset, model, model_logger, optimizer, scheduler, dataset, model, model_logger, optimizer, scheduler,
num_epochs=args.num_epochs, num_epochs=args.num_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps, gradient_accumulation_steps=args.gradient_accumulation_steps,
save_steps=args.save_steps,
find_unused_parameters=args.find_unused_parameters,
num_workers=args.dataset_num_workers,
) )