From 250ebf5c72ffccc5ba32b0de535b4b24dec2271c Mon Sep 17 00:00:00 2001 From: Qianyi Zhao <49068354+Qing112@users.noreply.github.com> Date: Tue, 5 Nov 2024 03:40:33 -0600 Subject: [PATCH] Update train_flux_lora.md --- docs/source/finetune/train_flux_lora.md | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/docs/source/finetune/train_flux_lora.md b/docs/source/finetune/train_flux_lora.md index 9410a66..1923e6e 100644 --- a/docs/source/finetune/train_flux_lora.md +++ b/docs/source/finetune/train_flux_lora.md @@ -22,7 +22,7 @@ models/FLUX/ └── model.safetensors.index.json ``` -使用以下命令启动训练任务: +使用以下命令启动训练任务(需要39G显存): ``` CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \ @@ -33,26 +33,32 @@ CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \ --dataset_path data/dog \ --output_path ./models \ --max_epochs 1 \ - --steps_per_epoch 500 \ + --steps_per_epoch 100 \ --height 1024 \ --width 1024 \ --center_crop \ --precision "bf16" \ --learning_rate 1e-4 \ - --lora_rank 4 \ - --lora_alpha 4 \ - --use_gradient_checkpointing + --lora_rank 16 \ + --lora_alpha 16 \ + --use_gradient_checkpointing \ + --align_to_opensource_format ``` +通过添加参数 `--quantize "float8_e4m3fn"`,你可以节省大约 10G 的显存。 + +**`--align_to_opensource_format` 表示此脚本将以开源格式导出 LoRA 权重。此格式可以在 DiffSynth-Studio 和其他代码库中加载。** + 有关参数的更多信息,请使用 `python examples/train/flux/train_flux_lora.py -h` 查看详细信息。 -训练完成后,使用 `model_manager.load_lora` 加载 LoRA 以进行推理。 +训练完成后,使用 model_manager.load_lora 来加载 LoRA 以进行推理。 + ```python from diffsynth import ModelManager, FluxImagePipeline import torch -model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", +model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", file_path_list=[ "models/FLUX/FLUX.1-dev/text_encoder/model.safetensors", "models/FLUX/FLUX.1-dev/text_encoder_2", @@ -60,11 +66,11 @@ model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", "models/FLUX/FLUX.1-dev/flux1-dev.safetensors" ]) model_manager.load_lora("models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0) -pipe = SDXLImagePipeline.from_model_manager(model_manager) +pipe = FluxImagePipeline.from_model_manager(model_manager) torch.manual_seed(0) image = pipe( - prompt=prompt, + prompt="a dog is jumping, flowers around the dog, the background is mountains and clouds", num_inference_steps=30, embedded_guidance=3.5 ) image.save("image_with_lora.jpg")