mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
reduce VRAM requirements in Kolors LoRA
This commit is contained in:
@@ -174,6 +174,9 @@ preset_models_on_modelscope = {
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
"SDXL-vae-fp16-fix": [
|
||||
("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "sdxl-vae-fp16-fix")
|
||||
],
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
@@ -201,6 +204,7 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
"SDXL-vae-fp16-fix",
|
||||
]
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
|
||||
@@ -4,23 +4,27 @@ Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffus
|
||||
|
||||
## Download models
|
||||
|
||||
The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors).
|
||||
The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). Due to precision overflow issues, an additional VAE model is also required; download it from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix).
|
||||
|
||||
```
|
||||
models/kolors/Kolors
|
||||
├── text_encoder
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model-00001-of-00007.bin
|
||||
│ ├── pytorch_model-00002-of-00007.bin
|
||||
│ ├── pytorch_model-00003-of-00007.bin
|
||||
│ ├── pytorch_model-00004-of-00007.bin
|
||||
│ ├── pytorch_model-00005-of-00007.bin
|
||||
│ ├── pytorch_model-00006-of-00007.bin
|
||||
│ ├── pytorch_model-00007-of-00007.bin
|
||||
│ └── pytorch_model.bin.index.json
|
||||
├── unet
|
||||
│ └── diffusion_pytorch_model.safetensors
|
||||
└── vae
|
||||
models
|
||||
├── kolors
|
||||
│ └── Kolors
|
||||
│ ├── text_encoder
|
||||
│ │ ├── config.json
|
||||
│ │ ├── pytorch_model-00001-of-00007.bin
|
||||
│ │ ├── pytorch_model-00002-of-00007.bin
|
||||
│ │ ├── pytorch_model-00003-of-00007.bin
|
||||
│ │ ├── pytorch_model-00004-of-00007.bin
|
||||
│ │ ├── pytorch_model-00005-of-00007.bin
|
||||
│ │ ├── pytorch_model-00006-of-00007.bin
|
||||
│ │ ├── pytorch_model-00007-of-00007.bin
|
||||
│ │ └── pytorch_model.bin.index.json
|
||||
│ ├── unet
|
||||
│ │ └── diffusion_pytorch_model.safetensors
|
||||
│ └── vae
|
||||
│ └── diffusion_pytorch_model.safetensors
|
||||
└── sdxl-vae-fp16-fix
|
||||
└── diffusion_pytorch_model.safetensors
|
||||
```
|
||||
|
||||
@@ -29,7 +33,7 @@ You can use the following code to download these files:
|
||||
```python
|
||||
from diffsynth import download_models
|
||||
|
||||
download_models(["Kolors"])
|
||||
download_models(["Kolors", "SDXL-vae-fp16-fix"])
|
||||
```
|
||||
|
||||
## Train
|
||||
@@ -70,24 +74,30 @@ file_name,text
|
||||
|
||||
We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project.
|
||||
|
||||
The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.**
|
||||
The following settings are recommended. 22GB VRAM is required.
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
|
||||
--pretrained_path models/kolors/Kolors \
|
||||
--pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \
|
||||
--pretrained_text_encoder_path models/kolors/Kolors/text_encoder \
|
||||
--pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \
|
||||
--dataset_path data/dog \
|
||||
--output_path ./models \
|
||||
--max_epochs 10 \
|
||||
--center_crop \
|
||||
--use_gradient_checkpointing \
|
||||
--precision 32
|
||||
--precision "16-mixed"
|
||||
```
|
||||
|
||||
Optional arguments:
|
||||
```
|
||||
-h, --help show this help message and exit
|
||||
--pretrained_path PRETRAINED_PATH
|
||||
Path to pretrained model. For example, `models/kolors/Kolors`.
|
||||
--pretrained_unet_path PRETRAINED_UNET_PATH
|
||||
Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.
|
||||
--pretrained_text_encoder_path PRETRAINED_TEXT_ENCODER_PATH
|
||||
Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.
|
||||
--pretrained_fp16_vae_path PRETRAINED_FP16_VAE_PATH
|
||||
Path to pretrained model (VAE). For example, `models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.
|
||||
--dataset_path DATASET_PATH
|
||||
The path of the Dataset.
|
||||
--output_path OUTPUT_PATH
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from diffsynth import ModelManager, KolorsImagePipeline
|
||||
from diffsynth import KolorsImagePipeline, load_state_dict, ChatGLMModel, SDXLUNet, SDXLVAEEncoder
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
@@ -40,23 +40,40 @@ class TextImageDataset(torch.utils.data.Dataset):
|
||||
|
||||
|
||||
|
||||
# Build a model from a DiffSynth model class and load its weights from a file.
#
# Steps: instantiate `ModelClass(**model_kwargs)`, move it to `torch_dtype` on
# `device`, read the state dict at `state_dict_path` with `load_state_dict`,
# convert it with the model's own `state_dict_converter().from_diffusers(...)`,
# load the converted weights into the model, and return the model.
#
# Args:
#     ModelClass: class to instantiate; must provide `state_dict_converter()`.
#     model_kwargs: keyword arguments forwarded to the `ModelClass` constructor.
#     state_dict_path: path of the weights file passed to `load_state_dict`.
#     torch_dtype: dtype used both for the model and for the loaded state dict.
#     device: device the model is moved to before the weights are loaded.
#
# Returns:
#     The instantiated model with the converted weights loaded.
def load_model_from_diffsynth(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
|
||||
model = ModelClass(**model_kwargs).to(dtype=torch_dtype, device=device)
|
||||
state_dict = load_state_dict(state_dict_path, torch_dtype=torch_dtype)
|
||||
# The checkpoint is passed through `from_diffusers` before loading, so it is
# presumably stored in Diffusers key layout - confirm against the converter.
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
return model
|
||||
|
||||
|
||||
# Load a model through its `from_pretrained` classmethod and move it to the
# requested dtype and device.
#
# Args:
#     ModelClass: class exposing `from_pretrained(path, torch_dtype=...)`.
#     model_kwargs: accepted for signature parity with
#         `load_model_from_diffsynth`, but not used by this function.
#     state_dict_path: path handed to `ModelClass.from_pretrained`.
#     torch_dtype: dtype passed to `from_pretrained` and applied again in `.to`.
#     device: device the loaded model is moved to.
#
# Returns:
#     The loaded model after the `.to(dtype=..., device=...)` call.
def load_model_from_transformers(ModelClass, model_kwargs, state_dict_path, torch_dtype, device):
|
||||
model = ModelClass.from_pretrained(state_dict_path, torch_dtype=torch_dtype)
|
||||
model = model.to(dtype=torch_dtype, device=device)
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
|
||||
def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
|
||||
def __init__(
|
||||
self,
|
||||
pretrained_unet_path, pretrained_text_encoder_path, pretrained_fp16_vae_path,
|
||||
torch_dtype=torch.float16, learning_rate=1e-4, lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
|
||||
model_manager.load_models(pretrained_weights)
|
||||
self.pipe = KolorsImagePipeline.from_model_manager(model_manager)
|
||||
self.pipe = KolorsImagePipeline(device=self.device, torch_dtype=torch_dtype)
|
||||
self.pipe.text_encoder = load_model_from_transformers(ChatGLMModel, {}, pretrained_text_encoder_path, torch_dtype, self.device)
|
||||
self.pipe.unet = load_model_from_diffsynth(SDXLUNet, {"is_kolors": True}, pretrained_unet_path, torch_dtype, self.device)
|
||||
self.pipe.vae_encoder = load_model_from_diffsynth(SDXLVAEEncoder, {}, pretrained_fp16_vae_path, torch_dtype, self.device)
|
||||
|
||||
# Freeze parameters
|
||||
self.pipe.text_encoder.requires_grad_(False)
|
||||
self.pipe.unet.requires_grad_(False)
|
||||
self.pipe.vae_decoder.requires_grad_(False)
|
||||
self.pipe.vae_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder.eval()
|
||||
self.pipe.unet.train()
|
||||
self.pipe.vae_decoder.eval()
|
||||
self.pipe.vae_encoder.eval()
|
||||
|
||||
# Add LoRA to UNet
|
||||
@@ -88,7 +105,7 @@ class LightningModel(pl.LightningModule):
|
||||
self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype)
|
||||
latents = self.pipe.vae_encoder(image.to(self.device))
|
||||
noise = torch.randn_like(latents)
|
||||
timestep = torch.randint(0, 1100, (1,), device=self.device)[0]
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
@@ -126,11 +143,25 @@ class LightningModel(pl.LightningModule):
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_path",
|
||||
"--pretrained_unet_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model. For example, `models/kolors/Kolors`.",
|
||||
help="Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_text_encoder_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_fp16_vae_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model (VAE). For example, `models/kolors/Kolors/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_path",
|
||||
@@ -267,11 +298,9 @@ if __name__ == '__main__':
|
||||
|
||||
# model
|
||||
model = LightningModel(
|
||||
pretrained_weights=[
|
||||
os.path.join(args.pretrained_path, "text_encoder"),
|
||||
os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"),
|
||||
os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"),
|
||||
],
|
||||
args.pretrained_unet_path,
|
||||
args.pretrained_text_encoder_path,
|
||||
args.pretrained_fp16_vae_path,
|
||||
torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
|
||||
learning_rate=args.learning_rate,
|
||||
lora_rank=args.lora_rank,
|
||||
|
||||
Reference in New Issue
Block a user