yrk111222 883d26abb4 Add files via upload
First draft of the translation is complete; the getStart directory was kept, and some terms still need to be rechecked.
2024-10-18 18:02:52 +08:00


Training Stable Diffusion LoRA

The training script requires only a single checkpoint file. Mainstream Stable Diffusion checkpoints from CivitAI are supported. By default, we use the base Stable Diffusion v1.5 model, which you can download from HuggingFace or ModelScope. The following code downloads it:

```python
from diffsynth import download_models

download_models(["StableDiffusion_v15"])
```

The checkpoint is saved to `models/stable_diffusion`:

```
models/stable_diffusion
├── Put Stable Diffusion checkpoints here.txt
└── v1-5-pruned-emaonly.safetensors
```
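Before launching training, it can help to confirm that the checkpoint is actually where the training command expects it. The snippet below is a minimal sketch using only the standard library; `checkpoint_ready` is a hypothetical helper (not part of DiffSynth), and its default arguments simply mirror the `--pretrained_path` used in the training command.

```python
from pathlib import Path


def checkpoint_ready(root="models/stable_diffusion",
                     filename="v1-5-pruned-emaonly.safetensors"):
    """Hypothetical helper: return True if the checkpoint exists and is non-empty."""
    path = Path(root) / filename
    return path.is_file() and path.stat().st_size > 0


if __name__ == "__main__":
    if checkpoint_ready():
        print("Checkpoint found.")
    else:
        print("Checkpoint missing: run download_models first.")
```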

To start training, run the following command:

```shell
CUDA_VISIBLE_DEVICES="0" python examples/train/stable_diffusion/train_sd_lora.py \
  --pretrained_path models/stable_diffusion/v1-5-pruned-emaonly.safetensors \
  --dataset_path data/dog \
  --output_path ./models \
  --max_epochs 1 \
  --steps_per_epoch 500 \
  --height 512 \
  --width 512 \
  --center_crop \
  --precision "16-mixed" \
  --learning_rate 1e-4 \
  --lora_rank 4 \
  --lora_alpha 4 \
  --use_gradient_checkpointing
```
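For intuition about `--lora_rank` and `--lora_alpha`, here is a minimal sketch of the standard LoRA formulation as used by libraries such as Hugging Face PEFT: a frozen weight `W` is augmented with a trainable low-rank product `B @ A`, scaled by `alpha / rank`. Whether DiffSynth's training script applies exactly this scaling is an assumption here, and `LoRALinear` is a hypothetical pure-Python illustration, not the script's code. Under this formulation, `--lora_rank 4 --lora_alpha 4` gives a scaling factor of 1.0.

```python
import random


def matmul(a, b):
    """Multiply matrices stored as lists of rows: (n x k) @ (k x m) -> (n x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


class LoRALinear:
    """Hypothetical sketch of a linear layer with a LoRA adapter.

    The base weight W (d_out x d_in) stays frozen; only A (rank x d_in) and
    B (d_out x rank) would be trained. The output is
        y = x W^T + (alpha / rank) * x A^T B^T
    """

    def __init__(self, weight, rank, alpha, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(weight), len(weight[0])
        self.weight = weight
        # A gets small random values; B starts at zero, so the adapter
        # initially leaves the base layer's output unchanged.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]
        self.scaling = alpha / rank

    def trainable_params(self):
        # (d_in + d_out) * rank, versus d_in * d_out for full fine-tuning
        return len(self.A) * len(self.A[0]) + len(self.B) * len(self.B[0])

    def forward(self, x):
        """x is a batch of row vectors (n x d_in); returns an n x d_out matrix."""
        base = matmul(x, transpose(self.weight))
        update = matmul(matmul(x, transpose(self.A)), transpose(self.B))
        return [[b + self.scaling * u for b, u in zip(rb, ru)]
                for rb, ru in zip(base, update)]


# With rank 4 and alpha 4 (as in the command above), the update is scaled by 1.0.
W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
layer = LoRALinear(W, rank=4, alpha=4)
print(layer.scaling)
print(layer.forward([[1.0, 0.0, -1.0]]))  # equals the base output while B is all zeros
```

Raising the rank adds trainable parameters linearly, while `alpha` only rescales the learned update; keeping `alpha` equal to the rank leaves the scale at 1.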

For a description of every parameter, run `python examples/train/stable_diffusion/train_sd_lora.py -h`.
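The checkpoint path passed to `load_lora` for inference can be predicted from the training arguments. The sketch below assumes the script relies on PyTorch Lightning's default logger and checkpoint naming (`lightning_logs/version_N/checkpoints/epoch=E-step=S.ckpt`, with epochs counted from 0), saves one checkpoint per epoch, and takes one optimizer step per batch (no gradient accumulation). `expected_checkpoints` is a hypothetical helper. With `--output_path ./models --max_epochs 1 --steps_per_epoch 500`, it produces the path used in the inference example.

```python
from pathlib import PurePosixPath


def expected_checkpoints(output_path, max_epochs, steps_per_epoch, version=0):
    """Hypothetical helper: list the checkpoint paths expected after training,
    assuming Lightning's default naming and one checkpoint saved per epoch."""
    ckpt_dir = PurePosixPath(output_path) / "lightning_logs" / f"version_{version}" / "checkpoints"
    paths = []
    for epoch in range(max_epochs):                   # Lightning counts epochs from 0
        global_step = (epoch + 1) * steps_per_epoch   # assumes one optimizer step per batch
        paths.append(str(ckpt_dir / f"epoch={epoch}-step={global_step}.ckpt"))
    return paths


print(expected_checkpoints("./models", max_epochs=1, steps_per_epoch=500))
```

Lightning increments `version_N` each time a new run writes to the same output directory, so a second training run would land under `version_1`.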

After training, load the LoRA with `model_manager.load_lora` and run inference:

```python
from diffsynth import ModelManager, SDImagePipeline
import torch

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=["models/stable_diffusion/v1-5-pruned-emaonly.safetensors"])
model_manager.load_lora("models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
pipe = SDImagePipeline.from_model_manager(model_manager)

torch.manual_seed(0)
image = pipe(
    prompt="a dog is jumping, flowers around the dog, the background is mountains and clouds",
    negative_prompt="bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi, extra tails",
    cfg_scale=7.5,
    num_inference_steps=100, width=512, height=512,
)
image.save("image_with_lora.jpg")
```
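The `lora_alpha` argument of `load_lora` acts as a strength multiplier for the learned update. Assuming the LoRA is merged into the base weights (a common approach; treat this as an assumption about DiffSynth's internals), each affected weight becomes `W + lora_alpha * (up @ down)`, where `up @ down` is the learned low-rank product. Under that assumption, `lora_alpha=0.0` reproduces the base model and values between 0 and 1 weaken the LoRA's effect. `merge_lora` is a hypothetical pure-Python helper for illustration only.

```python
def matmul(a, b):
    """(n x k) @ (k x m) for matrices stored as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora(weight, up, down, strength):
    """Hypothetical helper: fold a LoRA delta into a base weight.

    weight: d_out x d_in, up: d_out x rank, down: rank x d_in.
    Returns weight + strength * (up @ down).
    """
    delta = matmul(up, down)
    return [[w + strength * d for w, d in zip(row_w, row_d)]
            for row_w, row_d in zip(weight, delta)]


W = [[1.0, 0.0], [0.0, 1.0]]
up = [[1.0], [2.0]]    # d_out x rank (rank 1)
down = [[3.0, 4.0]]    # rank x d_in
for strength in (0.0, 0.5, 1.0):
    print(strength, merge_lora(W, up, down, strength))
```

If the merge assumption holds, loading the LoRA twice into the same `ModelManager` would stack the update, so comparing strengths is safest with a freshly built `ModelManager` per `lora_alpha` value and the same `torch.manual_seed(0)`.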