From 2d7d5137ead7f1e0b46d802efce9aebc0590121a Mon Sep 17 00:00:00 2001
From: mi804 <1576993271@qq.com>
Date: Fri, 24 Apr 2026 15:11:34 +0800
Subject: [PATCH] add full training

---
Review notes (commentary only; placed after the "---" marker, ahead of the
diffstat, so it is not part of the commit message or of any hunk):

  * Adds "full" training launch scripts and matching validation scripts for
    stable-diffusion-v1-5 and stable-diffusion-xl-base-1.0.
  * Each .sh first fetches its subset of
    DiffSynth-Studio/diffsynth_example_dataset with `modelscope download`,
    then runs the example train.py through `accelerate launch` with
    trainable_models "unet", learning_rate 1e-5, num_epochs 2 and gradient
    checkpointing enabled. v1-5 uses 512x512 with dataset_repeat 50; SDXL
    uses 1024x1024 with dataset_repeat 10.
  * Each validate_full .py rebuilds the base pipeline in float32, loads
    ./models/train/<model>_full/epoch-1.safetensors into pipe.unet, and
    renders the prompt "a dog" (seed 42, 50 inference steps) to a .jpg.
  * Assumption to confirm: remove_prefix_in_ckpt "pipe.unet." is what makes
    the saved keys directly loadable by pipe.unet.load_state_dict(...), and
    "epoch-1" is the checkpoint written after the second of the two epochs
    (zero-based naming). Neither is visible in this patch.
  * Observed difference to confirm as intentional: the v1-5 validation passes
    rand_device="cuda", a non-empty negative prompt and cfg_scale 7.5, while
    the SDXL validation passes no rand_device, an empty negative prompt and
    cfg_scale 7.0.

 .../full/stable-diffusion-v1-5.sh             | 15 ++++++++++
 .../validate_full/stable-diffusion-v1-5.py    | 27 ++++++++++++++++++
 .../full/stable-diffusion-xl-base-1.0.sh      | 15 ++++++++++
 .../stable-diffusion-xl-base-1.0.py           | 28 +++++++++++++++++++
 4 files changed, 85 insertions(+)
 create mode 100644 examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh
 create mode 100644 examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py
 create mode 100644 examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh
 create mode 100644 examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py

diff --git a/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh b/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh
new file mode 100644
index 0000000..0396ded
--- /dev/null
+++ b/examples/stable_diffusion/model_training/full/stable-diffusion-v1-5.sh
@@ -0,0 +1,15 @@
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion/stable-diffusion-v1-5/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/stable_diffusion/model_training/train.py \
+  --dataset_base_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5 \
+  --dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion/stable-diffusion-v1-5/metadata.csv \
+  --height 512 \
+  --width 512 \
+  --dataset_repeat 50 \
+  --model_id_with_origin_paths "AI-ModelScope/stable-diffusion-v1-5:text_encoder/model.safetensors,AI-ModelScope/stable-diffusion-v1-5:unet/diffusion_pytorch_model.safetensors,AI-ModelScope/stable-diffusion-v1-5:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --trainable_models "unet" \
+  --remove_prefix_in_ckpt "pipe.unet." \
+  --output_path "./models/train/stable-diffusion-v1-5_full" \
+  --use_gradient_checkpointing
diff --git a/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py b/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py
new file mode 100644
index 0000000..1daca48
--- /dev/null
+++ b/examples/stable_diffusion/model_training/validate_full/stable-diffusion-v1-5.py
@@ -0,0 +1,27 @@
+from diffsynth.pipelines.stable_diffusion import StableDiffusionPipeline, ModelConfig
+from diffsynth.core import load_state_dict
+import torch
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    torch_dtype=torch.float32,
+    model_configs=[
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
+        ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="AI-ModelScope/stable-diffusion-v1-5", origin_file_pattern="tokenizer/"),
+)
+state_dict = load_state_dict("./models/train/stable-diffusion-v1-5_full/epoch-1.safetensors", torch_dtype=torch.float32)
+pipe.unet.load_state_dict(state_dict)
+
+image = pipe(
+    prompt="a dog",
+    negative_prompt="blurry, low quality, deformed",
+    cfg_scale=7.5,
+    height=512,
+    width=512,
+    seed=42,
+    rand_device="cuda",
+    num_inference_steps=50,
+)
+image.save("image_stable-diffusion-v1-5_full.jpg")
diff --git a/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh b/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh
new file mode 100644
index 0000000..1cdb82d
--- /dev/null
+++ b/examples/stable_diffusion_xl/model_training/full/stable-diffusion-xl-base-1.0.sh
@@ -0,0 +1,15 @@
+modelscope download --dataset DiffSynth-Studio/diffsynth_example_dataset --include "stable_diffusion_xl/stable-diffusion-xl-base-1.0/*" --local_dir ./data/diffsynth_example_dataset
+
+accelerate launch examples/stable_diffusion_xl/model_training/train.py \
+  --dataset_base_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0 \
+  --dataset_metadata_path data/diffsynth_example_dataset/stable_diffusion_xl/stable-diffusion-xl-base-1.0/metadata.csv \
+  --height 1024 \
+  --width 1024 \
+  --dataset_repeat 10 \
+  --model_id_with_origin_paths "stabilityai/stable-diffusion-xl-base-1.0:text_encoder/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:text_encoder_2/model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:unet/diffusion_pytorch_model.safetensors,stabilityai/stable-diffusion-xl-base-1.0:vae/diffusion_pytorch_model.safetensors" \
+  --learning_rate 1e-5 \
+  --num_epochs 2 \
+  --trainable_models "unet" \
+  --remove_prefix_in_ckpt "pipe.unet." \
+  --output_path "./models/train/stable-diffusion-xl-base-1.0_full" \
+  --use_gradient_checkpointing
diff --git a/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py b/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py
new file mode 100644
index 0000000..06062ce
--- /dev/null
+++ b/examples/stable_diffusion_xl/model_training/validate_full/stable-diffusion-xl-base-1.0.py
@@ -0,0 +1,28 @@
+from diffsynth.pipelines.stable_diffusion_xl import StableDiffusionXLPipeline, ModelConfig
+from diffsynth.core import load_state_dict
+import torch
+
+pipe = StableDiffusionXLPipeline.from_pretrained(
+    torch_dtype=torch.float32,
+    model_configs=[
+        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder/model.safetensors"),
+        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="text_encoder_2/model.safetensors"),
+        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="unet/diffusion_pytorch_model.safetensors"),
+        ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="vae/diffusion_pytorch_model.safetensors"),
+    ],
+    tokenizer_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer/"),
+    tokenizer_2_config=ModelConfig(model_id="stabilityai/stable-diffusion-xl-base-1.0", origin_file_pattern="tokenizer_2/"),
+)
+state_dict = load_state_dict("./models/train/stable-diffusion-xl-base-1.0_full/epoch-1.safetensors", torch_dtype=torch.float32)
+pipe.unet.load_state_dict(state_dict)
+
+image = pipe(
+    prompt="a dog",
+    negative_prompt="",
+    cfg_scale=7.0,
+    height=1024,
+    width=1024,
+    seed=42,
+    num_inference_steps=50,
+)
+image.save("image_stable-diffusion-xl-base-1.0_full.jpg")