reduce VRAM requirements in Kolors LoRA

This commit is contained in:
Artiprocher
2024-07-12 17:30:19 +08:00
parent 9c6607f78d
commit b1b2d50c0d
3 changed files with 79 additions and 36 deletions

View File

@@ -4,23 +4,27 @@ Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffus
## Download models
The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors).
The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). Due to precision overflow issues, we need to download an additional VAE model (from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix)).
```
models/kolors/Kolors
├── text_encoder
│   ├── config.json
│   ├── pytorch_model-00001-of-00007.bin
│   ├── pytorch_model-00002-of-00007.bin
│   ├── pytorch_model-00003-of-00007.bin
│   ├── pytorch_model-00004-of-00007.bin
│   ├── pytorch_model-00005-of-00007.bin
│   ├── pytorch_model-00006-of-00007.bin
│   ├── pytorch_model-00007-of-00007.bin
│   └── pytorch_model.bin.index.json
├── unet
│   └── diffusion_pytorch_model.safetensors
└── vae
models
├── kolors
│   └── Kolors
│       ├── text_encoder
│       │   ├── config.json
│       │   ├── pytorch_model-00001-of-00007.bin
│       │   ├── pytorch_model-00002-of-00007.bin
│       │   ├── pytorch_model-00003-of-00007.bin
│       │   ├── pytorch_model-00004-of-00007.bin
│       │   ├── pytorch_model-00005-of-00007.bin
│       │   ├── pytorch_model-00006-of-00007.bin
│       │   ├── pytorch_model-00007-of-00007.bin
│       │   └── pytorch_model.bin.index.json
│       ├── unet
│       │   └── diffusion_pytorch_model.safetensors
│       └── vae
│           └── diffusion_pytorch_model.safetensors
└── sdxl-vae-fp16-fix
    └── diffusion_pytorch_model.safetensors
```
@@ -29,7 +33,7 @@ You can use the following code to download these files:
```python
from diffsynth import download_models
download_models(["Kolors"])
download_models(["Kolors", "SDXL-vae-fp16-fix"])
```
## Train
@@ -70,24 +74,30 @@ file_name,text
We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project.
The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.**
The following settings are recommended. 22GB VRAM is required.
```
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
--pretrained_path models/kolors/Kolors \
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
--pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 10 \
--center_crop \
--use_gradient_checkpointing \
--precision 32
--precision "16-mixed"
```
Optional arguments:
```
-h, --help show this help message and exit
--pretrained_path PRETRAINED_PATH
Path to pretrained model. For example, `models/kolors/Kolors`.
--pretrained_unet_path PRETRAINED_UNET_PATH
Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.
--pretrained_text_encoder_path PRETRAINED_TEXT_ENCODER_PATH
Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.
--pretrained_fp16_vae_path PRETRAINED_FP16_VAE_PATH
Path to pretrained model (VAE). For example, `models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.
--dataset_path DATASET_PATH
The path of the Dataset.
--output_path OUTPUT_PATH

View File

@@ -1,4 +1,4 @@
from diffsynth import ModelManager, KolorsImagePipeline
from diffsynth import KolorsImagePipeline, load_state_dict, ChatGLMModel, SDXLUNet, SDXLVAEEncoder
from peft import LoraConfig, inject_adapter_in_model
from torchvision import transforms
from PIL import Image
@@ -40,23 +40,40 @@ class TextImageDataset(torch.utils.data.Dataset):
def load_model_from_diffsynth(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
    """Build a DiffSynth model and load Diffusers-format weights into it.

    The raw state dict is read from ``state_dict_path`` and passed through the
    model's ``state_dict_converter`` so Diffusers key names map onto the
    DiffSynth architecture before loading.
    """
    raw_weights = load_state_dict(state_dict_path, torch_dtype=torch_dtype)
    instance = ModelClass(**model_kwargs).to(dtype=torch_dtype, device=device)
    converted = instance.state_dict_converter().from_diffusers(raw_weights)
    instance.load_state_dict(converted)
    return instance
def load_model_from_transformers(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
    """Load a transformers-style model via ``from_pretrained``.

    ``model_kwargs`` is accepted only for signature parity with
    ``load_model_from_diffsynth`` and is not used here.
    """
    instance = ModelClass.from_pretrained(state_dict_path, torch_dtype=torch_dtype)
    return instance.to(dtype=torch_dtype, device=device)
class LightningModel(pl.LightningModule):
def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
def __init__(
self,
pretrained_unet_path, pretrained_text_encoder_path, pretrained_fp16_vae_path,
torch_dtype=torch.float16, learning_rate=1e-4, lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True
):
super().__init__()
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
model_manager.load_models(pretrained_weights)
self.pipe = KolorsImagePipeline.from_model_manager(model_manager)
self.pipe = KolorsImagePipeline(device=self.device, torch_dtype=torch_dtype)
self.pipe.text_encoder = load_model_from_transformers(ChatGLMModel, {}, pretrained_text_encoder_path, torch_dtype, self.device)
self.pipe.unet = load_model_from_diffsynth(SDXLUNet, {"is_kolors": True}, pretrained_unet_path, torch_dtype, self.device)
self.pipe.vae_encoder = load_model_from_diffsynth(SDXLVAEEncoder, {}, pretrained_fp16_vae_path, torch_dtype, self.device)
# Freeze parameters
self.pipe.text_encoder.requires_grad_(False)
self.pipe.unet.requires_grad_(False)
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder.eval()
self.pipe.unet.train()
self.pipe.vae_decoder.eval()
self.pipe.vae_encoder.eval()
# Add LoRA to UNet
@@ -88,7 +105,7 @@ class LightningModel(pl.LightningModule):
self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True,
)
height, width = image.shape[-2:]
latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype)
latents = self.pipe.vae_encoder(image.to(self.device))
noise = torch.randn_like(latents)
timestep = torch.randint(0, 1100, (1,), device=self.device)[0]
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
@@ -126,11 +143,25 @@ class LightningModel(pl.LightningModule):
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_path",
"--pretrained_unet_path",
type=str,
default=None,
required=True,
help="Path to pretrained model. For example, `models/kolors/Kolors`.",
help="Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.",
)
parser.add_argument(
"--pretrained_text_encoder_path",
type=str,
default=None,
required=True,
help="Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.",
)
parser.add_argument(
"--pretrained_fp16_vae_path",
type=str,
default=None,
required=True,
help="Path to pretrained model (VAE). For example, `models/kolors/Kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.",
)
parser.add_argument(
"--dataset_path",
@@ -267,11 +298,9 @@ if __name__ == '__main__':
# model
model = LightningModel(
pretrained_weights=[
os.path.join(args.pretrained_path, "text_encoder"),
os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"),
],
args.pretrained_unet_path,
args.pretrained_text_encoder_path,
args.pretrained_fp16_vae_path,
torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
learning_rate=args.learning_rate,
lora_rank=args.lora_rank,