Update train_flux_lora.md

This commit is contained in:
Qianyi Zhao
2024-11-05 03:40:33 -06:00
committed by GitHub
parent 47a2f86f7b
commit 250ebf5c72

View File

@@ -22,7 +22,7 @@ models/FLUX/
└── model.safetensors.index.json
```
使用以下命令启动训练任务:
使用以下命令启动训练任务需要39G显存
```
CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \
@@ -33,26 +33,32 @@ CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 1 \
--steps_per_epoch 500 \
--steps_per_epoch 100 \
--height 1024 \
--width 1024 \
--center_crop \
--precision "bf16" \
--learning_rate 1e-4 \
--lora_rank 4 \
--lora_alpha 4 \
--use_gradient_checkpointing
--lora_rank 16 \
--lora_alpha 16 \
--use_gradient_checkpointing \
--align_to_opensource_format
```
通过添加参数 `--quantize "float8_e4m3fn"`,你可以节省大约 10G 的显存。
**`--align_to_opensource_format` 表示此脚本将以开源格式导出 LoRA 权重。此格式可以在 DiffSynth-Studio 和其他代码库中加载。**
有关参数的更多信息,请使用 `python examples/train/flux/train_flux_lora.py -h` 查看详细信息。
训练完成后,使用 `model_manager.load_lora` 加载 LoRA 以进行推理。
训练完成后,使用 model_manager.load_lora 加载 LoRA 以进行推理。
```python
from diffsynth import ModelManager, FluxImagePipeline
import torch
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda",
file_path_list=[
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
@@ -60,11 +66,11 @@ model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
model_manager.load_lora("models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
pipe = SDXLImagePipeline.from_model_manager(model_manager)
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(0)
image = pipe(
prompt=prompt,
prompt="a dog is jumping, flowers around the dog, the background is mountains and clouds",
num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_with_lora.jpg")