From 8c2671ce400e235fd69df82cc26cc813e74d3d39 Mon Sep 17 00:00:00 2001
From: Artiprocher
Date: Mon, 16 Dec 2024 11:08:14 +0800
Subject: [PATCH] support resume training

---
 diffsynth/trainers/text_to_image.py            | 18 +++++++++++++++++-
 examples/train/flux/train_flux_lora.py         | 12 ++++++++++--
 .../hunyuan_dit/train_hunyuan_dit_lora.py      | 12 ++++++++++--
 examples/train/kolors/train_kolors_lora.py     | 12 ++++++++++--
 .../train/stable_diffusion/train_sd_lora.py    | 12 ++++++++++--
 .../train/stable_diffusion_3/train_sd3_lora.py | 12 ++++++++++--
 .../stable_diffusion_xl/train_sdxl_lora.py     | 12 ++++++++++--
 7 files changed, 77 insertions(+), 13 deletions(-)

diff --git a/diffsynth/trainers/text_to_image.py b/diffsynth/trainers/text_to_image.py
index 3177474..5e49c98 100644
--- a/diffsynth/trainers/text_to_image.py
+++ b/diffsynth/trainers/text_to_image.py
@@ -3,6 +3,7 @@ from peft import LoraConfig, inject_adapter_in_model
 import torch, os
 from ..data.simple_text_image import TextImageDataset
 from modelscope.hub.api import HubApi
+from ..models.utils import load_state_dict
 
 
 
@@ -33,7 +34,7 @@ class LightningModelForT2ILoRA(pl.LightningModule):
         self.pipe.denoising_model().train()
 
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
@@ -51,6 +52,15 @@ class LightningModelForT2ILoRA(pl.LightningModule):
             if param.requires_grad:
                 param.data = param.to(torch.float32)
 
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
+
 
     def training_step(self, batch, batch_idx):
         # Data
@@ -229,6 +239,12 @@ def add_general_parsers(parser):
         default=None,
         help="Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.",
     )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
+    )
     return parser
 
 
diff --git a/examples/train/flux/train_flux_lora.py b/examples/train/flux/train_flux_lora.py
index 4efeed3..681d6ba 100644
--- a/examples/train/flux/train_flux_lora.py
+++ b/examples/train/flux/train_flux_lora.py
@@ -10,7 +10,7 @@ class LightningModel(LightningModelForT2ILoRA):
         self,
         torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
         learning_rate=1e-4, use_gradient_checkpointing=True,
-        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming",
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming", pretrained_lora_path=None,
         state_dict_converter=None, quantize = None
     ):
         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
@@ -34,7 +34,14 @@ class LightningModel(LightningModelForT2ILoRA):
         self.pipe.scheduler.set_timesteps(1000, training=True)
 
         self.freeze_parameters()
-        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path
+        )
 
 
 def parse_args():
@@ -109,6 +116,7 @@ if __name__ == '__main__':
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
+        pretrained_lora_path=args.pretrained_lora_path,
         state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
         quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
     )
diff --git a/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py b/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py
index 6ceba42..7764ab5 100644
--- a/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py
+++ b/examples/train/hunyuan_dit/train_hunyuan_dit_lora.py
@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
         self,
         torch_dtype=torch.float16, pretrained_weights=[],
         learning_rate=1e-4, use_gradient_checkpointing=True,
-        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None,
     ):
         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
         # Load models
@@ -19,7 +19,14 @@ class LightningModel(LightningModelForT2ILoRA):
         self.pipe.scheduler.set_timesteps(1000)
 
         self.freeze_parameters()
-        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path,
+        )
 
 
 def parse_args():
@@ -57,6 +64,7 @@ if __name__ == '__main__':
         lora_rank=args.lora_rank,
         lora_alpha=args.lora_alpha,
         init_lora_weights=args.init_lora_weights,
+        pretrained_lora_path=args.pretrained_lora_path,
         lora_target_modules=args.lora_target_modules
     )
     launch_training_task(model, args)
diff --git a/examples/train/kolors/train_kolors_lora.py b/examples/train/kolors/train_kolors_lora.py
index 120e41d..48a9892 100644
--- a/examples/train/kolors/train_kolors_lora.py
+++ b/examples/train/kolors/train_kolors_lora.py
@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
         self,
         torch_dtype=torch.float16, pretrained_weights=[],
         learning_rate=1e-4, use_gradient_checkpointing=True,
-        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None,
     ):
         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
         # Load models
@@ -22,7 +22,14 @@ class LightningModel(LightningModelForT2ILoRA):
         self.pipe.vae_encoder.to(torch_dtype)
 
         self.freeze_parameters()
-        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path,
+        )
 
 
 def parse_args():
@@ -73,6 +80,7 @@ if __name__ == '__main__':
         lora_rank=args.lora_rank,
         lora_alpha=args.lora_alpha,
         init_lora_weights=args.init_lora_weights,
+        pretrained_lora_path=args.pretrained_lora_path,
         lora_target_modules=args.lora_target_modules
     )
     launch_training_task(model, args)
diff --git a/examples/train/stable_diffusion/train_sd_lora.py b/examples/train/stable_diffusion/train_sd_lora.py
index 8dcaf7a..dc24520 100644
--- a/examples/train/stable_diffusion/train_sd_lora.py
+++ b/examples/train/stable_diffusion/train_sd_lora.py
@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
         self,
         torch_dtype=torch.float16, pretrained_weights=[],
         learning_rate=1e-4, use_gradient_checkpointing=True,
-        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None,
     ):
         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
         # Load models
@@ -19,7 +19,14 @@ class LightningModel(LightningModelForT2ILoRA):
         self.pipe.scheduler.set_timesteps(1000)
 
         self.freeze_parameters()
-        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path,
+        )
 
 
 def parse_args():
@@ -52,6 +59,7 @@ if __name__ == '__main__':
         lora_rank=args.lora_rank,
         lora_alpha=args.lora_alpha,
         init_lora_weights=args.init_lora_weights,
+        pretrained_lora_path=args.pretrained_lora_path,
         lora_target_modules=args.lora_target_modules
     )
     launch_training_task(model, args)
diff --git a/examples/train/stable_diffusion_3/train_sd3_lora.py b/examples/train/stable_diffusion_3/train_sd3_lora.py
index a677bcb..c9abf2b 100644
--- a/examples/train/stable_diffusion_3/train_sd3_lora.py
+++ b/examples/train/stable_diffusion_3/train_sd3_lora.py
@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
         self,
         torch_dtype=torch.float16, pretrained_weights=[], preset_lora_path=None,
         learning_rate=1e-4, use_gradient_checkpointing=True,
-        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None,
     ):
         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
         # Load models
@@ -24,7 +24,14 @@ class LightningModel(LightningModelForT2ILoRA):
             model_manager.load_lora(path)
 
         self.freeze_parameters()
-        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path,
+        )
 
 
 def parse_args():
@@ -70,6 +77,7 @@ if __name__ == '__main__':
         lora_rank=args.lora_rank,
         lora_alpha=args.lora_alpha,
         init_lora_weights=args.init_lora_weights,
+        pretrained_lora_path=args.pretrained_lora_path,
         lora_target_modules=args.lora_target_modules
     )
     launch_training_task(model, args)
diff --git a/examples/train/stable_diffusion_xl/train_sdxl_lora.py b/examples/train/stable_diffusion_xl/train_sdxl_lora.py
index 69ca71d..de0241d 100644
--- a/examples/train/stable_diffusion_xl/train_sdxl_lora.py
+++ b/examples/train/stable_diffusion_xl/train_sdxl_lora.py
@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
         self,
         torch_dtype=torch.float16, pretrained_weights=[],
         learning_rate=1e-4, use_gradient_checkpointing=True,
-        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
+        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian", pretrained_lora_path=None,
     ):
         super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
         # Load models
@@ -19,7 +19,14 @@ class LightningModel(LightningModelForT2ILoRA):
         self.pipe.scheduler.set_timesteps(1000)
 
         self.freeze_parameters()
-        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
+        self.add_lora_to_model(
+            self.pipe.denoising_model(),
+            lora_rank=lora_rank,
+            lora_alpha=lora_alpha,
+            lora_target_modules=lora_target_modules,
+            init_lora_weights=init_lora_weights,
+            pretrained_lora_path=pretrained_lora_path,
+        )
 
 
 def parse_args():
@@ -52,6 +59,7 @@ if __name__ == '__main__':
         lora_rank=args.lora_rank,
         lora_alpha=args.lora_alpha,
         init_lora_weights=args.init_lora_weights,
+        pretrained_lora_path=args.pretrained_lora_path,
         lora_target_modules=args.lora_target_modules
     )
     launch_training_task(model, args)
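
Note on the resume mechanism: the patch restores LoRA weights with load_state_dict(strict=False), so a checkpoint containing only LoRA tensors loads cleanly into the full denoising model, and the (missing_keys, unexpected_keys) return value drives the log line in add_lora_to_model. The sketch below reproduces that pattern on a toy torch module; the two-layer model and the one-key checkpoint are hypothetical stand-ins, not DiffSynth-Studio code.

import torch

# Toy stand-in for a denoising model that already has LoRA layers injected.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))

# Hypothetical checkpoint holding only a subset of the parameters, the way
# an interrupted run would have saved just its LoRA weights.
lora_only_state_dict = {"0.weight": torch.zeros(8, 8)}

# strict=False mirrors add_lora_to_model above: keys absent from the
# checkpoint keep their freshly initialized values, and the return value
# reports what was and wasn't matched, which is what the patch logs.
missing, unexpected = model.load_state_dict(lora_only_state_dict, strict=False)
num_updated = len([name for name, _ in model.named_parameters()]) - len(missing)
print(f"{num_updated} parameters are loaded. {len(unexpected)} parameters are unexpected.")

With the patch applied, resuming a run is then a matter of passing the LoRA file saved by the interrupted run back in through the new flag, e.g. --pretrained_lora_path path/to/saved_lora.safetensors (placeholder path) alongside the script's usual arguments.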