mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 15:48:20 +00:00
Flux lora update (#237)
* update flux lora --------- Co-authored-by: tc2000731 <tc2000731@163.com>
This commit is contained in:
@@ -32,12 +32,15 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
self.pipe.denoising_model().train()
|
||||
|
||||
|
||||
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"):
|
||||
def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian"):
|
||||
# Add LoRA to UNet
|
||||
if init_lora_weights == "kaiming":
|
||||
init_lora_weights = True
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
init_lora_weights=init_lora_weights,
|
||||
target_modules=lora_target_modules.split(","),
|
||||
)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
@@ -67,7 +70,8 @@ class LightningModelForT2ILoRA(pl.LightningModule):
|
||||
noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred, training_target)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
@@ -179,6 +183,13 @@ def add_general_parsers(parser):
|
||||
default=4.0,
|
||||
help="The weight of the LoRA update matrices.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--init_lora_weights",
|
||||
type=str,
|
||||
default="kaiming",
|
||||
choices=["gaussian", "kaiming"],
|
||||
help="The initializing method of LoRA weight.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_gradient_checkpointing",
|
||||
default=False,
|
||||
|
||||
Reference in New Issue
Block a user