reduce VRAM requirements in Kolors LoRA

This commit is contained in:
Artiprocher
2024-07-12 17:30:19 +08:00
parent 9c6607f78d
commit b1b2d50c0d
3 changed files with 79 additions and 36 deletions

View File

@@ -4,23 +4,27 @@ Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffus
## Download models
The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors).
The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). Due to precision overflow issues, we need to download an additional VAE model (from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix)).
```
models/kolors/Kolors
├── text_encoder
── config.json
├── pytorch_model-00001-of-00007.bin
├── pytorch_model-00002-of-00007.bin
│ ├── pytorch_model-00003-of-00007.bin
│ ├── pytorch_model-00004-of-00007.bin
│ ├── pytorch_model-00005-of-00007.bin
│ ├── pytorch_model-00006-of-00007.bin
│ ├── pytorch_model-00007-of-00007.bin
── pytorch_model.bin.index.json
├── unet
└── diffusion_pytorch_model.safetensors
└── vae
models
├── kolors
── Kolors
├── text_encoder
│ ├── config.json
│ ├── pytorch_model-00001-of-00007.bin
│ ├── pytorch_model-00002-of-00007.bin
│ ├── pytorch_model-00003-of-00007.bin
│ ├── pytorch_model-00004-of-00007.bin
│ ├── pytorch_model-00005-of-00007.bin
│ ├── pytorch_model-00006-of-00007.bin
│ │ ├── pytorch_model-00007-of-00007.bin
│ └── pytorch_model.bin.index.json
│ ├── unet
│ │ └── diffusion_pytorch_model.safetensors
│ └── vae
│ └── diffusion_pytorch_model.safetensors
└── sdxl-vae-fp16-fix
└── diffusion_pytorch_model.safetensors
```
@@ -29,7 +33,7 @@ You can use the following code to download these files:
```python
from diffsynth import download_models
download_models(["Kolors"])
download_models(["Kolors", "SDXL-vae-fp16-fix"])
```
## Train
@@ -70,24 +74,30 @@ file_name,text
We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project.
The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.**
The following settings are recommended. 22GB VRAM is required.
```
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
--pretrained_path models/kolors/Kolors \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 10 \
--center_crop \
--use_gradient_checkpointing \
--precision 32
--precision "16-mixed"
```
Optional arguments:
```
-h, --help show this help message and exit
--pretrained_path PRETRAINED_PATH
Path to pretrained model. For example, `models/kolors/Kolors`.
--pretrained_unet_path PRETRAINED_UNET_PATH
Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.
--pretrained_text_encoder_path PRETRAINED_TEXT_ENCODER_PATH
Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.
--pretrained_fp16_vae_path PRETRAINED_FP16_VAE_PATH
Path to pretrained model (VAE). For example, `models/kolors/Kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.
--dataset_path DATASET_PATH
The path of the Dataset.
--output_path OUTPUT_PATH