Flux lora update (#237)

* update flux lora

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
Zhongjie Duan authored 2024-10-11 18:41:24 +08:00; committed by GitHub
parent 75ab786afc
commit 22e4ae99e8
11 changed files with 63 additions and 24 deletions

@@ -32,12 +32,15 @@ class LightningModelForT2ILoRA(pl.LightningModule):
         self.pipe.denoising_model().train()
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"):
         # Add LoRA to UNet
+        if init_lora_weights == "kaiming":
+            init_lora_weights = True
         lora_config = LoraConfig(
             r=lora_rank,
             lora_alpha=lora_alpha,
-            init_lora_weights="gaussian",
+            init_lora_weights=init_lora_weights,
             target_modules=lora_target_modules.split(","),
         )
         model = inject_adapter_in_model(lora_config, model)
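
Note: in peft, init_lora_weights=True selects the library's default Kaiming-uniform initialization of lora_A, while "gaussian" draws lora_A from a normal distribution scaled by 1/r; lora_B is zero-initialized in both modes, so the adapter starts as a no-op update. A minimal sketch of how the new argument maps onto LoraConfig (the helper name below is illustrative, not part of this diff):

from peft import LoraConfig

def make_lora_config(init_lora_weights="gaussian", lora_rank=4, lora_alpha=4):
    # Mirror the diff: "kaiming" is translated to True, which peft
    # interprets as its default Kaiming-uniform init for lora_A.
    if init_lora_weights == "kaiming":
        init_lora_weights = True
    return LoraConfig(
        r=lora_rank,
        lora_alpha=lora_alpha,
        init_lora_weights=init_lora_weights,  # True or "gaussian"
        target_modules=["to_q", "to_k", "to_v", "to_out"],
    )

# Because lora_B starts at zero in both modes, training begins from the
# unmodified base model regardless of which init is chosen.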
@@ -67,7 +70,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
             noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
             use_gradient_checkpointing=self.use_gradient_checkpointing
         )
-        loss = torch.nn.functional.mse_loss(noise_pred, training_target)
+        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
+        loss = loss * self.pipe.scheduler.training_weight(timestep)
         # Record log
         self.log("train_loss", loss, prog_bar=True)
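
Note: the loss changes in two ways here: predictions and targets are upcast to float32 before the MSE reduction, which avoids precision loss when training in bf16/fp16, and the result is scaled by a per-timestep weight from the scheduler (training_weight is the hook named in this diff). A self-contained sketch of the same computation:

import torch
import torch.nn.functional as F

def timestep_weighted_mse(noise_pred, training_target, timestep, scheduler):
    # Upcast to fp32 so the squared-error reduction is not accumulated in
    # half precision; gradients still flow back to the original dtype.
    loss = F.mse_loss(noise_pred.float(), training_target.float())
    # Reweight by timestep via the scheduler hook named in the diff.
    return loss * scheduler.training_weight(timestep)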
@@ -179,6 +183,13 @@ def add_general_parsers(parser):
         default=4.0,
         help="The weight of the LoRA update matrices.",
     )
+    parser.add_argument(
+        "--init_lora_weights",
+        type=str,
+        default="kaiming",
+        choices=["gaussian", "kaiming"],
+        help="The initialization method for LoRA weights.",
+    )
     parser.add_argument(
         "--use_gradient_checkpointing",
         default=False,
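
Note: the CLI default ("kaiming") differs from the function default in add_lora_to_model ("gaussian"), so scripts driven by this parser now use Kaiming init unless the flag is passed. An illustrative sketch of the round trip (the forwarding step is assumed, not shown in this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--init_lora_weights",
    type=str,
    default="kaiming",
    choices=["gaussian", "kaiming"],
    help="The initialization method for LoRA weights.",
)
args = parser.parse_args(["--init_lora_weights", "gaussian"])
# Forwarded into the Lightning module, e.g.:
# model.add_lora_to_model(denoising_model, init_lora_weights=args.init_lora_weights)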