update sd training scripts

This commit is contained in:
Artiprocher
2026-04-24 14:30:09 +08:00
parent 5cdab9ed01
commit 3799bdc23a
23 changed files with 323 additions and 612 deletions

View File

@@ -1,6 +1,5 @@
import torch, os, argparse, accelerate
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import ToAbsolutePath, LoadImage, ImageCropAndResize, RouteByType, SequencialProcess
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -23,59 +22,55 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
task="sft",
):
super().__init__()
# ===== 解析模型配置 =====
# Load models
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
# ===== Tokenizer 配置 =====
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
# ===== 构建 Pipeline =====
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
# ===== 拆分 Pipeline Units =====
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.float32, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
# ===== 切换到训练模式 =====
# Training mode
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
# ===== 其他配置 =====
# Other configs
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
# ===== 任务模式路由 =====
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
}
def get_pipeline_inputs(self, data):
# ===== 正向提示词 =====
inputs_posi = {"prompt": data["prompt"]}
# ===== 负向提示词:训练不需要 =====
inputs_nega = {"negative_prompt": ""}
# ===== 共享参数 =====
inputs_shared = {
# ===== 核心字段映射 =====
# Assume you are using this pipeline for inference,
# please fill in the input parameters.
"input_image": data["image"],
"height": data["image"].size[1],
"width": data["image"].size[0],
# ===== 框架控制参数 =====
# Please do not modify the following parameters
# unless you clearly know what this will cause.
"cfg_scale": 1,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
}
# ===== 额外字段注入 =====
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
# ===== 标准实现,不要修改 =====
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
@@ -84,24 +79,21 @@ class StableDiffusionTrainingModule(DiffusionTrainingModule):
return loss
def stable_diffusion_parser():
parser = argparse.ArgumentParser(description="Stable Diffusion training.")
def parser():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser = add_general_config(parser)
parser = add_image_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser
if __name__ == "__main__":
parser = stable_diffusion_parser()
parser = parser()
args = parser.parse_args()
# ===== Accelerator 配置 =====
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
# ===== 数据集定义 =====
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
@@ -112,17 +104,10 @@ if __name__ == "__main__":
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=8,
width_division_factor=8,
),
special_operator_map={
"image": RouteByType(operator_map=[
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8)),
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8))),
]),
},
height_division_factor=32,
width_division_factor=32,
)
)
# ===== TrainingModule =====
model = StableDiffusionTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
@@ -140,17 +125,18 @@ if __name__ == "__main__":
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
device=accelerator.device,
)
# ===== ModelLogger =====
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
# ===== 任务路由 =====
launcher_map = {
"sft:data_process": launch_data_process_task,
"direct_distill:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)