Flux lora update (#237)

* update flux lora

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
This commit is contained in:
Zhongjie Duan
2024-10-11 18:41:24 +08:00
committed by GitHub
parent 75ab786afc
commit 22e4ae99e8
11 changed files with 63 additions and 24 deletions

View File

@@ -10,7 +10,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out",
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming",
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
@@ -27,10 +27,10 @@ class LightningModel(LightningModelForT2ILoRA):
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -97,6 +97,7 @@ if __name__ == '__main__':
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)