Flux lora update (#237)

* update flux lora

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
This commit is contained in:
Zhongjie Duan
2024-10-11 18:41:24 +08:00
committed by GitHub
parent 75ab786afc
commit 22e4ae99e8
11 changed files with 63 additions and 24 deletions

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -19,7 +19,7 @@ class LightningModel(LightningModelForT2ILoRA):
self.pipe.scheduler.set_timesteps(1000)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -51,6 +51,7 @@ if __name__ == '__main__':
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)