From b1b2d50c0d5411d3afaded5a050e366f5b96b902 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 12 Jul 2024 17:30:19 +0800 Subject: [PATCH 1/2] reduce VRAM requirements in Kolors LoRA --- diffsynth/models/__init__.py | 4 ++ examples/train/kolors/README.md | 52 +++++++++++-------- examples/train/kolors/train_kolors_lora.py | 59 ++++++++++++++++------ 3 files changed, 79 insertions(+), 36 deletions(-) diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index 057ed89..8f2190d 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -174,6 +174,9 @@ preset_models_on_modelscope = { ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), ], + "SDXL-vae-fp16-fix": [ + ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "sdxl-vae-fp16-fix") + ], } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -201,6 +204,7 @@ Preset_model_id: TypeAlias = Literal[ "StableDiffusion3", "StableDiffusion3_without_T5", "Kolors", + "SDXL-vae-fp16-fix", ] Preset_model_website: TypeAlias = Literal[ "HuggingFace", diff --git a/examples/train/kolors/README.md b/examples/train/kolors/README.md index be50561..ba505b4 100644 --- a/examples/train/kolors/README.md +++ b/examples/train/kolors/README.md @@ -4,23 +4,27 @@ Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffus ## Download models -The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). +The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). 
Due to precision overflow issues, we need to download an additional VAE model (from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix)). ``` -models/kolors/Kolors -├── text_encoder -│ ├── config.json -│ ├── pytorch_model-00001-of-00007.bin -│ ├── pytorch_model-00002-of-00007.bin -│ ├── pytorch_model-00003-of-00007.bin -│ ├── pytorch_model-00004-of-00007.bin -│ ├── pytorch_model-00005-of-00007.bin -│ ├── pytorch_model-00006-of-00007.bin -│ ├── pytorch_model-00007-of-00007.bin -│ └── pytorch_model.bin.index.json -├── unet -│ └── diffusion_pytorch_model.safetensors -└── vae +models +├── kolors +│ └── Kolors +│ ├── text_encoder +│ │ ├── config.json +│ │ ├── pytorch_model-00001-of-00007.bin +│ │ ├── pytorch_model-00002-of-00007.bin +│ │ ├── pytorch_model-00003-of-00007.bin +│ │ ├── pytorch_model-00004-of-00007.bin +│ │ ├── pytorch_model-00005-of-00007.bin +│ │ ├── pytorch_model-00006-of-00007.bin +│ │ ├── pytorch_model-00007-of-00007.bin +│ │ └── pytorch_model.bin.index.json +│ ├── unet +│ │ └── diffusion_pytorch_model.safetensors +│ └── vae +│ └── diffusion_pytorch_model.safetensors +└── sdxl-vae-fp16-fix └── diffusion_pytorch_model.safetensors ``` @@ -29,7 +33,7 @@ You can use the following code to download these files: ```python from diffsynth import download_models -download_models(["Kolors"]) +download_models(["Kolors", "SDXL-vae-fp16-fix"]) ``` ## Train @@ -70,24 +74,30 @@ file_name,text We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project. -The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.** +The following settings are recommended. 22GB VRAM is required. 
``` CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \ - --pretrained_path models/kolors/Kolors \ + --pretrained_unet_path models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors \ + --pretrained_text_encoder_path models/kolors/Kolors/text_encoder \ + --pretrained_fp16_vae_path models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors \ --dataset_path data/dog \ --output_path ./models \ --max_epochs 10 \ --center_crop \ --use_gradient_checkpointing \ - --precision 32 + --precision "16-mixed" ``` Optional arguments: ``` -h, --help show this help message and exit - --pretrained_path PRETRAINED_PATH - Path to pretrained model. For example, `models/kolors/Kolors`. + --pretrained_unet_path PRETRAINED_UNET_PATH + Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`. + --pretrained_text_encoder_path PRETRAINED_TEXT_ENCODER_PATH + Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`. + --pretrained_fp16_vae_path PRETRAINED_FP16_VAE_PATH + Path to pretrained model (VAE). For example, `models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`. --dataset_path DATASET_PATH The path of the Dataset. 
--output_path OUTPUT_PATH diff --git a/examples/train/kolors/train_kolors_lora.py b/examples/train/kolors/train_kolors_lora.py index 0a0af9d..7f5c4ca 100644 --- a/examples/train/kolors/train_kolors_lora.py +++ b/examples/train/kolors/train_kolors_lora.py @@ -1,4 +1,4 @@ -from diffsynth import ModelManager, KolorsImagePipeline +from diffsynth import KolorsImagePipeline, load_state_dict, ChatGLMModel, SDXLUNet, SDXLVAEEncoder from peft import LoraConfig, inject_adapter_in_model from torchvision import transforms from PIL import Image @@ -40,23 +40,40 @@ class TextImageDataset(torch.utils.data.Dataset): +def load_model_from_diffsynth(ModelClass, model_kwargs, state_dict_path, torch_dtype, device): + model = ModelClass(**model_kwargs).to(dtype=torch_dtype, device=device) + state_dict = load_state_dict(state_dict_path, torch_dtype=torch_dtype) + model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) + return model + + +def load_model_from_transformers(ModelClass, model_kwargs, state_dict_path, torch_dtype, device): + model = ModelClass.from_pretrained(state_dict_path, torch_dtype=torch_dtype) + model = model.to(dtype=torch_dtype, device=device) + return model + + + class LightningModel(pl.LightningModule): - def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True): + def __init__( + self, + pretrained_unet_path, pretrained_text_encoder_path, pretrained_fp16_vae_path, + torch_dtype=torch.float16, learning_rate=1e-4, lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True + ): super().__init__() # Load models - model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device) - model_manager.load_models(pretrained_weights) - self.pipe = KolorsImagePipeline.from_model_manager(model_manager) + self.pipe = KolorsImagePipeline(device=self.device, torch_dtype=torch_dtype) + self.pipe.text_encoder = load_model_from_transformers(ChatGLMModel, {}, 
pretrained_text_encoder_path, torch_dtype, self.device) + self.pipe.unet = load_model_from_diffsynth(SDXLUNet, {"is_kolors": True}, pretrained_unet_path, torch_dtype, self.device) + self.pipe.vae_encoder = load_model_from_diffsynth(SDXLVAEEncoder, {}, pretrained_fp16_vae_path, torch_dtype, self.device) # Freeze parameters self.pipe.text_encoder.requires_grad_(False) self.pipe.unet.requires_grad_(False) - self.pipe.vae_decoder.requires_grad_(False) self.pipe.vae_encoder.requires_grad_(False) self.pipe.text_encoder.eval() self.pipe.unet.train() - self.pipe.vae_decoder.eval() self.pipe.vae_encoder.eval() # Add LoRA to UNet @@ -88,7 +105,7 @@ class LightningModel(pl.LightningModule): self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True, ) height, width = image.shape[-2:] - latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype) + latents = self.pipe.vae_encoder(image.to(self.device)) noise = torch.randn_like(latents) timestep = torch.randint(0, 1100, (1,), device=self.device)[0] add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) @@ -126,11 +143,25 @@ class LightningModel(pl.LightningModule): def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( - "--pretrained_path", + "--pretrained_unet_path", type=str, default=None, required=True, - help="Path to pretrained model. For example, `models/kolors/Kolors`.", + help="Path to pretrained model (UNet). For example, `models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors`.", + ) + parser.add_argument( + "--pretrained_text_encoder_path", + type=str, + default=None, + required=True, + help="Path to pretrained model (Text Encoder). For example, `models/kolors/Kolors/text_encoder`.", + ) + parser.add_argument( + "--pretrained_fp16_vae_path", + type=str, + default=None, + required=True, + help="Path to pretrained model (VAE). 
For example, `models/sdxl-vae-fp16-fix/diffusion_pytorch_model.safetensors`.", ) parser.add_argument( "--dataset_path", @@ -267,11 +298,9 @@ if __name__ == '__main__': # model model = LightningModel( - pretrained_weights=[ - os.path.join(args.pretrained_path, "text_encoder"), - os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"), - os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"), - ], + args.pretrained_unet_path, + args.pretrained_text_encoder_path, + args.pretrained_fp16_vae_path, torch_dtype=torch.float32 if args.precision == "32" else torch.float16, learning_rate=args.learning_rate, lora_rank=args.lora_rank, From 3f8eea46879ab7213ae89b0ab806d50d8c3b2e30 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Fri, 12 Jul 2024 17:39:26 +0800 Subject: [PATCH 2/2] update downloader --- diffsynth/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index 8f2190d..e634518 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -175,7 +175,7 @@ preset_models_on_modelscope = { ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), ], "SDXL-vae-fp16-fix": [ - ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "sdxl-vae-fp16-fix") + ("AI-ModelScope/sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors", "models/sdxl-vae-fp16-fix") ], } Preset_model_id: TypeAlias = Literal[