Flux lora update (#237)

* update flux lora

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
Zhongjie Duan authored on 2024-10-11 18:41:24 +08:00, committed by GitHub
parent 75ab786afc
commit 22e4ae99e8
11 changed files with 63 additions and 24 deletions

View File

@@ -123,7 +123,7 @@ models/FLUX/
└── model.safetensors.index.json
```
Launch the training task using the following command:
Launch the training task using the following command (39 GB of VRAM required):
```
CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \
@@ -134,18 +134,20 @@ CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 1 \
--steps_per_epoch 500 \
--steps_per_epoch 100 \
--height 1024 \
--width 1024 \
--center_crop \
--precision "bf16" \
--learning_rate 1e-4 \
--lora_rank 4 \
--lora_alpha 4 \
--lora_rank 16 \
--lora_alpha 16 \
--use_gradient_checkpointing \
--align_to_opensource_format
```
By adding the parameter `--quantize "float8_e4m3fn"`, you can save approximately 10 GB of VRAM.
**`--align_to_opensource_format` makes the script export the LoRA weights in the open-source format, which can be loaded in DiffSynth-Studio as well as in other codebases.**
For more information about the parameters, please run `python examples/train/flux/train_flux_lora.py -h` to see the details.
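For reference, here is a minimal sketch of loading the exported LoRA outside DiffSynth-Studio. It assumes the open-source format is compatible with Diffusers' `load_lora_weights`; the LoRA file path below is hypothetical and should point at the file the training script actually writes.
```python
# Sketch only, assuming the --align_to_opensource_format export is
# Diffusers-compatible; "models/flux_lora.safetensors" is a hypothetical path.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("models/flux_lora.safetensors")
image = pipe("a photo of a dog in a garden", num_inference_steps=30).images[0]
image.save("dog_lora.png")
```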

View File

@@ -10,7 +10,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out",
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="kaiming",
state_dict_converter=None, quantize = None
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing, state_dict_converter=state_dict_converter)
@@ -27,10 +27,10 @@ class LightningModel(LightningModelForT2ILoRA):
if quantize is not None:
self.pipe.dit.quantize()
self.pipe.scheduler.set_timesteps(1000)
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -97,6 +97,7 @@ if __name__ == '__main__':
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
lora_target_modules=args.lora_target_modules,
init_lora_weights=args.init_lora_weights,
state_dict_converter=FluxLoRAConverter.align_to_opensource_format if args.align_to_opensource_format else None,
quantize={"float8_e4m3fn": torch.float8_e4m3fn}.get(args.quantize, None),
)
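The new `init_lora_weights` argument controls how the injected LoRA matrices are initialized ("kaiming" here, "gaussian" in the other training scripts). As a rough, illustrative sketch of what these modes conventionally mean, not the repository's `add_lora_to_model` implementation: the down-projection `lora_A` receives Kaiming-uniform or Gaussian noise, while the up-projection `lora_B` starts at zero so the adapter is a no-op before training.
```python
# Illustrative sketch of typical "kaiming" vs "gaussian" LoRA initialization;
# not the actual add_lora_to_model code in DiffSynth-Studio.
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, lora_rank=16, lora_alpha=16, init_lora_weights="kaiming"):
        super().__init__()
        self.base = base
        self.lora_A = nn.Parameter(torch.empty(lora_rank, base.in_features))
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, lora_rank))  # zeros -> adapter is identity at step 0
        self.scaling = lora_alpha / lora_rank
        if init_lora_weights == "kaiming":
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        elif init_lora_weights == "gaussian":
            nn.init.normal_(self.lora_A, std=1.0 / lora_rank)
        else:
            raise ValueError(f"Unknown init_lora_weights: {init_lora_weights}")

    def forward(self, x):
        # Base projection plus the scaled low-rank update.
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
```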

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -19,7 +19,7 @@ class LightningModel(LightningModelForT2ILoRA):
self.pipe.scheduler.set_timesteps(1000)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -56,6 +56,7 @@ if __name__ == '__main__':
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -22,7 +22,7 @@ class LightningModel(LightningModelForT2ILoRA):
self.pipe.vae_encoder.to(torch_dtype)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -72,6 +72,7 @@ if __name__ == '__main__':
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -19,7 +19,7 @@ class LightningModel(LightningModelForT2ILoRA):
self.pipe.scheduler.set_timesteps(1000)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -51,6 +51,7 @@ if __name__ == '__main__':
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -19,7 +19,7 @@ class LightningModel(LightningModelForT2ILoRA):
self.pipe.scheduler.set_timesteps(1000)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -51,6 +51,7 @@ if __name__ == '__main__':
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)

View File

@@ -9,7 +9,7 @@ class LightningModel(LightningModelForT2ILoRA):
self,
torch_dtype=torch.float16, pretrained_weights=[],
learning_rate=1e-4, use_gradient_checkpointing=True,
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out", init_lora_weights="gaussian",
):
super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
# Load models
@@ -19,7 +19,7 @@ class LightningModel(LightningModelForT2ILoRA):
self.pipe.scheduler.set_timesteps(1000)
self.freeze_parameters()
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules, init_lora_weights=init_lora_weights)
def parse_args():
@@ -51,6 +51,7 @@ if __name__ == '__main__':
use_gradient_checkpointing=args.use_gradient_checkpointing,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=args.init_lora_weights,
lora_target_modules=args.lora_target_modules
)
launch_training_task(model, args)