mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 15:06:17 +00:00
sd and sdxl training
This commit is contained in:
@@ -124,6 +124,8 @@ class DDIMScheduler:
|
||||
else:
|
||||
raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")
|
||||
|
||||
# Clamp timesteps to valid range [0, num_train_timesteps - 1]
|
||||
timesteps = np.clip(timesteps, 0, self.num_train_timesteps - 1)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
@@ -222,6 +224,8 @@ class DDIMScheduler:
|
||||
timestep = timestep[0].item()
|
||||
|
||||
timestep = int(timestep)
|
||||
# Defensive clamp: ensure timestep is within valid range
|
||||
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
|
||||
|
||||
@@ -238,6 +242,14 @@ class DDIMScheduler:
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
"""Return the training target for the given prediction type."""
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
if timestep.dim() == 0:
|
||||
timestep = timestep.item()
|
||||
elif timestep.dim() == 1:
|
||||
timestep = timestep[0].item()
|
||||
timestep = int(timestep)
|
||||
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
|
||||
if self.prediction_type == "epsilon":
|
||||
return noise
|
||||
elif self.prediction_type == "v_prediction":
|
||||
@@ -251,5 +263,7 @@ class DDIMScheduler:
|
||||
|
||||
def training_weight(self, timestep):
|
||||
"""Return training weight for the given timestep."""
|
||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||
timestep = max(0, min(int(timestep), self.num_train_timesteps - 1))
|
||||
timestep_tensor = torch.tensor(timestep, device=self.timesteps.device)
|
||||
timestep_id = torch.argmin((self.timesteps - timestep_tensor).abs())
|
||||
return self.linear_timesteps_weights[timestep_id]
|
||||
|
||||
@@ -189,13 +189,26 @@ class SDUnit_NoiseInitializer(PipelineUnit):
|
||||
class SDUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise",),
|
||||
output_params=("latents",),
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe: StableDiffusionPipeline, noise):
|
||||
# For Text-to-Image, latents = noise (scaled by scheduler)
|
||||
def process(self, pipe: StableDiffusionPipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_tensor = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
||||
latents = noise * pipe.scheduler.init_noise_sigma
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
else:
|
||||
# Inference mode: VAE encode input image, add noise for initial latent
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_tensor = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
|
||||
@@ -87,9 +87,7 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str,
|
||||
prompt_2: str = None,
|
||||
negative_prompt: str = "",
|
||||
negative_prompt_2: str = None,
|
||||
cfg_scale: float = 5.0,
|
||||
height: int = 1024,
|
||||
width: int = 1024,
|
||||
@@ -103,8 +101,6 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
target_size: tuple = None,
|
||||
progress_bar_cmd=tqdm,
|
||||
):
|
||||
prompt_2 = prompt_2 or prompt
|
||||
negative_prompt_2 = negative_prompt_2 or negative_prompt
|
||||
original_size = original_size or (height, width)
|
||||
target_size = target_size or (height, width)
|
||||
|
||||
@@ -116,11 +112,9 @@ class StableDiffusionXLPipeline(BasePipeline):
|
||||
# 2. Three-dict input preparation
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
"prompt_2": prompt_2,
|
||||
}
|
||||
inputs_nega = {
|
||||
"prompt": negative_prompt,
|
||||
"prompt_2": negative_prompt_2,
|
||||
}
|
||||
inputs_shared = {
|
||||
"cfg_scale": cfg_scale,
|
||||
@@ -221,8 +215,8 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
seperate_cfg=True,
|
||||
input_params_posi={"prompt": "prompt", "prompt_2": "prompt_2"},
|
||||
input_params_nega={"prompt": "prompt", "prompt_2": "prompt_2"},
|
||||
input_params_posi={"prompt": "prompt"},
|
||||
input_params_nega={"prompt": "prompt"},
|
||||
output_params=("prompt_embeds", "pooled_prompt_embeds"),
|
||||
onload_model_names=("text_encoder", "text_encoder_2")
|
||||
)
|
||||
@@ -231,10 +225,9 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
self,
|
||||
pipe: StableDiffusionXLPipeline,
|
||||
prompt: str,
|
||||
prompt_2: str,
|
||||
device: torch.device,
|
||||
) -> tuple:
|
||||
"""Encode prompt using both text encoders.
|
||||
"""Encode prompt using both text encoders (same prompt for both).
|
||||
|
||||
Returns (prompt_embeds, pooled_prompt_embeds):
|
||||
- prompt_embeds: concat(encoder1_output, encoder2_output) -> (B, 77, 2048)
|
||||
@@ -254,7 +247,7 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
# Text Encoder 2 (CLIP-bigG, 1280-dim) — uses penultimate hidden states + pooled
|
||||
text_input_ids_2 = pipe.tokenizer_2(
|
||||
prompt_2,
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=pipe.tokenizer_2.model_max_length,
|
||||
truncation=True,
|
||||
@@ -270,9 +263,9 @@ class SDXLUnit_PromptEmbedder(PipelineUnit):
|
||||
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
def process(self, pipe: StableDiffusionXLPipeline, prompt, prompt_2):
|
||||
def process(self, pipe: StableDiffusionXLPipeline, prompt):
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, prompt_2, pipe.device)
|
||||
prompt_embeds, pooled_prompt_embeds = self.encode_prompt(pipe, prompt, pipe.device)
|
||||
return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}
|
||||
|
||||
|
||||
@@ -294,13 +287,26 @@ class SDXLUnit_NoiseInitializer(PipelineUnit):
|
||||
class SDXLUnit_InputImageEmbedder(PipelineUnit):
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
input_params=("noise",),
|
||||
output_params=("latents",),
|
||||
input_params=("input_image", "noise"),
|
||||
output_params=("latents", "input_latents"),
|
||||
onload_model_names=("vae",),
|
||||
)
|
||||
|
||||
def process(self, pipe: StableDiffusionXLPipeline, noise):
|
||||
# For Text-to-Image, latents = noise (scaled by scheduler)
|
||||
def process(self, pipe: StableDiffusionXLPipeline, input_image, noise):
|
||||
if input_image is None:
|
||||
return {"latents": noise * pipe.scheduler.init_noise_sigma, "input_latents": None}
|
||||
if pipe.scheduler.training:
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_tensor = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
||||
latents = noise * pipe.scheduler.init_noise_sigma
|
||||
return {"latents": latents, "input_latents": input_latents}
|
||||
else:
|
||||
# Inference mode: VAE encode input image, add noise for initial latent
|
||||
pipe.load_models_to_device(self.onload_model_names)
|
||||
input_tensor = pipe.preprocess_image(input_image)
|
||||
input_latents = pipe.vae.encode(input_tensor).sample()
|
||||
latents = pipe.scheduler.add_noise(input_latents, noise, pipe.scheduler.timesteps[0])
|
||||
return {"latents": latents}
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/
|
||||
# Debug test: num_epochs=1, dataset_repeat=1 for quick validation
|
||||
|
||||
# ===== 固定参数(无需修改) =====
|
||||
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||
--learning_rate 1e-4 --num_epochs 1 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing --find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/metadata.csv" \
|
||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--lora_base_model "unet" \
|
||||
--remove_prefix_in_ckpt "pipe.unet." \
|
||||
--max_pixels 262144 \
|
||||
--height 512 --width 512 \
|
||||
--dataset_repeat 1 \
|
||||
--output_path "./models/train/StableDiffusion_lora_debug" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
||||
--data_file_keys "image"
|
||||
@@ -0,0 +1,19 @@
|
||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/StableDiffusion/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
# ===== 固定参数(无需修改) =====
|
||||
accelerate launch examples/stable_diffusion/model_training/train.py \
|
||||
--learning_rate 1e-4 --num_epochs 5 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing --find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion/StableDiffusion/metadata.csv" \
|
||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
|
||||
--lora_base_model "unet" \
|
||||
--remove_prefix_in_ckpt "pipe.unet." \
|
||||
--max_pixels 262144 \
|
||||
--height 512 --width 512 \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/StableDiffusion_lora" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
||||
--data_file_keys "image"
|
||||
156
examples/stable_diffusion/model_training/train.py
Normal file
156
examples/stable_diffusion/model_training/train.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import torch, os, argparse, accelerate
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.core.data.operators import ToAbsolutePath, LoadImage, ImageCropAndResize, RouteByType, SequencialProcess
|
||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class StableDiffusionTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
fp8_models=None,
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
):
|
||||
super().__init__()
|
||||
# ===== 解析模型配置 =====
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
# ===== Tokenizer 配置 =====
|
||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"))
|
||||
# ===== 构建 Pipeline =====
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config)
|
||||
# ===== 拆分 Pipeline Units =====
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# ===== 切换到训练模式 =====
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||
preset_lora_path, preset_lora_model,
|
||||
task=task,
|
||||
)
|
||||
|
||||
# ===== 其他配置 =====
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.fp8_models = fp8_models
|
||||
self.task = task
|
||||
# ===== 任务模式路由 =====
|
||||
self.task_to_loss = {
|
||||
"sft:data_process": lambda pipe, *args: args,
|
||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
}
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
# ===== 正向提示词 =====
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
# ===== 负向提示词:训练不需要 =====
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
# ===== 共享参数 =====
|
||||
inputs_shared = {
|
||||
# ===== 核心字段映射 =====
|
||||
"input_image": data["image"],
|
||||
"height": data["image"].size[1],
|
||||
"width": data["image"].size[0],
|
||||
# ===== 框架控制参数 =====
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
# ===== 额外字段注入 =====
|
||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
# ===== 标准实现,不要修改 =====
|
||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
for unit in self.pipe.units:
|
||||
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||
return loss
|
||||
|
||||
|
||||
def stable_diffusion_parser():
|
||||
parser = argparse.ArgumentParser(description="Stable Diffusion training.")
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = stable_diffusion_parser()
|
||||
args = parser.parse_args()
|
||||
# ===== Accelerator 配置 =====
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
# ===== 数据集定义 =====
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=8,
|
||||
width_division_factor=8,
|
||||
),
|
||||
special_operator_map={
|
||||
"image": RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8)),
|
||||
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8))),
|
||||
]),
|
||||
},
|
||||
)
|
||||
# ===== TrainingModule =====
|
||||
model = StableDiffusionTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
preset_lora_model=args.preset_lora_model,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
fp8_models=args.fp8_models,
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||
)
|
||||
# ===== ModelLogger =====
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
)
|
||||
# ===== 任务路由 =====
|
||||
launcher_map = {
|
||||
"sft:data_process": launch_data_process_task,
|
||||
"sft": launch_training_task,
|
||||
"sft:train": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||
@@ -0,0 +1,17 @@
|
||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
torch_dtype=torch.float32,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=7.5)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,18 @@
|
||||
from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
torch_dtype=torch.float32,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
|
||||
)
|
||||
pipe.load_lora(pipe.unet, "./models/train/StableDiffusion_lora/epoch-4.safetensors")
|
||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=7.5)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,21 @@
|
||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/
|
||||
# Debug test: num_epochs=1, dataset_repeat=1 for quick validation
|
||||
|
||||
# ===== 固定参数(无需修改) =====
|
||||
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
||||
--learning_rate 1e-4 --num_epochs 1 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing --find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/metadata.csv" \
|
||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer/" \
|
||||
--tokenizer_2_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer_2/" \
|
||||
--lora_base_model "unet" \
|
||||
--remove_prefix_in_ckpt "pipe.unet." \
|
||||
--max_pixels 1048576 \
|
||||
--height 1024 --width 1024 \
|
||||
--dataset_repeat 1 \
|
||||
--output_path "./models/train/StableDiffusionXL_lora_debug" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
||||
--data_file_keys "image"
|
||||
@@ -0,0 +1,21 @@
|
||||
# Dataset: data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/
|
||||
# Download: modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/StableDiffusionXL/*" --local_dir ./data/diffsynth_example_dataset
|
||||
|
||||
# ===== 固定参数(无需修改) =====
|
||||
accelerate launch examples/stable_diffusion_xl/model_training/train.py \
|
||||
--learning_rate 1e-4 --num_epochs 5 \
|
||||
--lora_rank 32 \
|
||||
--use_gradient_checkpointing --find_unused_parameters \
|
||||
--dataset_base_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL" \
|
||||
--dataset_metadata_path "./data/diffsynth_example_dataset/stable_diffusion_xl/StableDiffusionXL/metadata.csv" \
|
||||
--model_id_with_origin_paths "AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
|
||||
--tokenizer_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer/" \
|
||||
--tokenizer_2_path "AI-ModelScope/stable-diffusion-xl-base-1.0:tokenizer_2/" \
|
||||
--lora_base_model "unet" \
|
||||
--remove_prefix_in_ckpt "pipe.unet." \
|
||||
--max_pixels 1048576 \
|
||||
--height 1024 --width 1024 \
|
||||
--dataset_repeat 50 \
|
||||
--output_path "./models/train/StableDiffusionXL_lora" \
|
||||
--lora_target_modules "to_q,to_k,to_v,to_out.0" \
|
||||
--data_file_keys "image"
|
||||
174
examples/stable_diffusion_xl/model_training/train.py
Normal file
174
examples/stable_diffusion_xl/model_training/train.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import torch, os, argparse, accelerate
|
||||
from diffsynth.core import UnifiedDataset
|
||||
from diffsynth.core.data.operators import ToAbsolutePath, LoadImage, ImageCropAndResize, RouteByType, SequencialProcess
|
||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||
from diffsynth.diffusion import *
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
class StableDiffusionXLTrainingModule(DiffusionTrainingModule):
|
||||
def __init__(
|
||||
self,
|
||||
model_paths=None, model_id_with_origin_paths=None,
|
||||
tokenizer_path=None, tokenizer_2_path=None,
|
||||
trainable_models=None,
|
||||
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
|
||||
preset_lora_path=None, preset_lora_model=None,
|
||||
use_gradient_checkpointing=True,
|
||||
use_gradient_checkpointing_offload=False,
|
||||
extra_inputs=None,
|
||||
fp8_models=None,
|
||||
offload_models=None,
|
||||
device="cpu",
|
||||
task="sft",
|
||||
):
|
||||
super().__init__()
|
||||
# ===== 解析模型配置 =====
|
||||
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
|
||||
# ===== Tokenizer 配置 =====
|
||||
tokenizer_config = self.parse_path_or_model_id(tokenizer_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"))
|
||||
tokenizer_2_config = self.parse_path_or_model_id(tokenizer_2_path, default_value=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"))
|
||||
# ===== 构建 Pipeline =====
|
||||
self.pipe = StableDiffusionXLPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, tokenizer_2_config=tokenizer_2_config)
|
||||
# ===== 拆分 Pipeline Units =====
|
||||
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
|
||||
|
||||
# ===== 切换到训练模式 =====
|
||||
self.switch_pipe_to_training_mode(
|
||||
self.pipe, trainable_models,
|
||||
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
|
||||
preset_lora_path, preset_lora_model,
|
||||
task=task,
|
||||
)
|
||||
|
||||
# ===== 其他配置 =====
|
||||
self.use_gradient_checkpointing = use_gradient_checkpointing
|
||||
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
|
||||
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
|
||||
self.fp8_models = fp8_models
|
||||
self.task = task
|
||||
# ===== 任务模式路由 =====
|
||||
self.task_to_loss = {
|
||||
"sft:data_process": lambda pipe, *args: args,
|
||||
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
|
||||
}
|
||||
|
||||
def get_pipeline_inputs(self, data):
|
||||
# ===== 正向提示词 =====
|
||||
inputs_posi = {"prompt": data["prompt"]}
|
||||
# ===== 负向提示词:训练不需要 =====
|
||||
inputs_nega = {"negative_prompt": ""}
|
||||
# ===== 共享参数 =====
|
||||
height = data["image"].size[1]
|
||||
width = data["image"].size[0]
|
||||
inputs_shared = {
|
||||
# ===== 核心字段映射 =====
|
||||
"input_image": data["image"],
|
||||
"height": height,
|
||||
"width": width,
|
||||
# ===== 框架控制参数 =====
|
||||
"cfg_scale": 1,
|
||||
"rand_device": self.pipe.device,
|
||||
"use_gradient_checkpointing": self.use_gradient_checkpointing,
|
||||
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
|
||||
}
|
||||
# ===== SDXL 特有:add_time_ids (micro-conditioning) =====
|
||||
# 在 __call__ 中计算,但训练不跑 __call__,所以在这里注入
|
||||
text_encoder_projection_dim = self.pipe.text_encoder_2.config.projection_dim
|
||||
add_time_ids = [height, width, 0, 0, height, width]
|
||||
expected_add_embed_dim = self.pipe.unet.add_embedding.linear_1.in_features
|
||||
addition_time_embed_dim = self.pipe.unet.add_time_proj.num_channels
|
||||
passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
|
||||
f"but a vector of {passed_add_embed_dim} was created."
|
||||
)
|
||||
inputs_posi["add_time_ids"] = torch.tensor([add_time_ids], dtype=self.pipe.torch_dtype, device=self.pipe.device)
|
||||
# ===== 额外字段注入 =====
|
||||
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
|
||||
return inputs_shared, inputs_posi, inputs_nega
|
||||
|
||||
def forward(self, data, inputs=None):
|
||||
# ===== 标准实现,不要修改 =====
|
||||
if inputs is None: inputs = self.get_pipeline_inputs(data)
|
||||
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
|
||||
for unit in self.pipe.units:
|
||||
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
|
||||
loss = self.task_to_loss[self.task](self.pipe, *inputs)
|
||||
return loss
|
||||
|
||||
|
||||
def stable_diffusion_xl_parser():
|
||||
parser = argparse.ArgumentParser(description="Stable Diffusion XL training.")
|
||||
parser = add_general_config(parser)
|
||||
parser = add_image_size_config(parser)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
|
||||
parser.add_argument("--tokenizer_2_path", type=str, default=None, help="Path to tokenizer 2.")
|
||||
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = stable_diffusion_xl_parser()
|
||||
args = parser.parse_args()
|
||||
# ===== Accelerator 配置 =====
|
||||
accelerator = accelerate.Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
|
||||
)
|
||||
# ===== 数据集定义 =====
|
||||
dataset = UnifiedDataset(
|
||||
base_path=args.dataset_base_path,
|
||||
metadata_path=args.dataset_metadata_path,
|
||||
repeat=args.dataset_repeat,
|
||||
data_file_keys=args.data_file_keys.split(","),
|
||||
main_data_operator=UnifiedDataset.default_image_operator(
|
||||
base_path=args.dataset_base_path,
|
||||
max_pixels=args.max_pixels,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
height_division_factor=8,
|
||||
width_division_factor=8,
|
||||
),
|
||||
special_operator_map={
|
||||
"image": RouteByType(operator_map=[
|
||||
(str, ToAbsolutePath(args.dataset_base_path) >> LoadImage() >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8)),
|
||||
(list, SequencialProcess(ToAbsolutePath(args.dataset_base_path) >> LoadImage(convert_RGB=False, convert_RGBA=True) >> ImageCropAndResize(args.height, args.width, args.max_pixels, 8, 8))),
|
||||
]),
|
||||
},
|
||||
)
|
||||
# ===== TrainingModule =====
|
||||
model = StableDiffusionXLTrainingModule(
|
||||
model_paths=args.model_paths,
|
||||
model_id_with_origin_paths=args.model_id_with_origin_paths,
|
||||
tokenizer_path=args.tokenizer_path,
|
||||
tokenizer_2_path=args.tokenizer_2_path,
|
||||
trainable_models=args.trainable_models,
|
||||
lora_base_model=args.lora_base_model,
|
||||
lora_target_modules=args.lora_target_modules,
|
||||
lora_rank=args.lora_rank,
|
||||
lora_checkpoint=args.lora_checkpoint,
|
||||
preset_lora_path=args.preset_lora_path,
|
||||
preset_lora_model=args.preset_lora_model,
|
||||
use_gradient_checkpointing=args.use_gradient_checkpointing,
|
||||
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
|
||||
extra_inputs=args.extra_inputs,
|
||||
fp8_models=args.fp8_models,
|
||||
offload_models=args.offload_models,
|
||||
task=args.task,
|
||||
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
|
||||
)
|
||||
# ===== ModelLogger =====
|
||||
model_logger = ModelLogger(
|
||||
args.output_path,
|
||||
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
|
||||
)
|
||||
# ===== 任务路由 =====
|
||||
launcher_map = {
|
||||
"sft:data_process": launch_data_process_task,
|
||||
"sft": launch_training_task,
|
||||
"sft:train": launch_training_task,
|
||||
}
|
||||
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)
|
||||
@@ -0,0 +1,19 @@
|
||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
torch_dtype=torch.float32,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||
tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||
)
|
||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=5.0)
|
||||
image.save("image.jpg")
|
||||
@@ -0,0 +1,20 @@
|
||||
from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
|
||||
import torch
|
||||
|
||||
|
||||
pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
torch_dtype=torch.float32,
|
||||
device="cuda",
|
||||
model_configs=[
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
|
||||
ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
|
||||
tokenizer_2_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
|
||||
)
|
||||
pipe.load_lora(pipe.unet, "./models/train/StableDiffusionXL_lora/epoch-4.safetensors")
|
||||
prompt = "dog, white and brown dog, sitting on wall, under pink flowers"
|
||||
image = pipe(prompt=prompt, seed=42, rand_device="cuda", num_inference_steps=50, cfg_scale=5.0)
|
||||
image.save("image.jpg")
|
||||
Reference in New Issue
Block a user