Merge pull request #194 from modelscope/flux-lora

support flux training
This commit is contained in:
Zhongjie Duan
2024-09-06 19:15:42 +08:00
committed by GitHub
5 changed files with 218 additions and 46 deletions

View File

@@ -218,9 +218,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
self.dim = dim self.dim = dim
self.norm = AdaLayerNormSingle(dim) self.norm = AdaLayerNormSingle(dim)
# self.proj_in = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), torch.nn.GELU(approximate="tanh")) self.to_qkv_mlp = torch.nn.Linear(dim, dim * (3 + 4))
# self.attn = FluxSingleAttention(dim, dim, num_attention_heads, dim // num_attention_heads)
self.linear = torch.nn.Linear(dim, dim * (3 + 4))
self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6) self.norm_q_a = RMSNorm(self.head_dim, eps=1e-6)
self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6) self.norm_k_a = RMSNorm(self.head_dim, eps=1e-6)
@@ -253,7 +251,7 @@ class FluxSingleTransformerBlock(torch.nn.Module):
def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb): def forward(self, hidden_states_a, hidden_states_b, temb, image_rotary_emb):
residual = hidden_states_a residual = hidden_states_a
norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb) norm_hidden_states, gate = self.norm(hidden_states_a, emb=temb)
hidden_states_a = self.linear(norm_hidden_states) hidden_states_a = self.to_qkv_mlp(norm_hidden_states)
attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:] attn_output, mlp_hidden_states = hidden_states_a[:, :, :self.dim * 3], hidden_states_a[:, :, self.dim * 3:]
attn_output = self.process_attention(attn_output, image_rotary_emb) attn_output = self.process_attention(attn_output, image_rotary_emb)
@@ -295,8 +293,8 @@ class FluxDiT(torch.nn.Module):
self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)]) self.blocks = torch.nn.ModuleList([FluxJointTransformerBlock(3072, 24) for _ in range(19)])
self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)]) self.single_blocks = torch.nn.ModuleList([FluxSingleTransformerBlock(3072, 24) for _ in range(38)])
self.norm_out = AdaLayerNormContinuous(3072) self.final_norm_out = AdaLayerNormContinuous(3072)
self.proj_out = torch.nn.Linear(3072, 64) self.final_proj_out = torch.nn.Linear(3072, 64)
def patchify(self, hidden_states): def patchify(self, hidden_states):
@@ -350,6 +348,7 @@ class FluxDiT(torch.nn.Module):
hidden_states, hidden_states,
timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None, timestep, prompt_emb, pooled_prompt_emb, guidance, text_ids, image_ids=None,
tiled=False, tile_size=128, tile_stride=64, tiled=False, tile_size=128, tile_stride=64,
use_gradient_checkpointing=False,
**kwargs **kwargs
): ):
if tiled: if tiled:
@@ -373,16 +372,35 @@ class FluxDiT(torch.nn.Module):
hidden_states = self.patchify(hidden_states) hidden_states = self.patchify(hidden_states)
hidden_states = self.x_embedder(hidden_states) hidden_states = self.x_embedder(hidden_states)
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for block in self.blocks: for block in self.blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = torch.cat([prompt_emb, hidden_states], dim=1) hidden_states = torch.cat([prompt_emb, hidden_states], dim=1)
for block in self.single_blocks: for block in self.single_blocks:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb) if self.training and use_gradient_checkpointing:
hidden_states, prompt_emb = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, prompt_emb, conditioning, image_rotary_emb,
use_reentrant=False,
)
else:
hidden_states, prompt_emb = block(hidden_states, prompt_emb, conditioning, image_rotary_emb)
hidden_states = hidden_states[:, prompt_emb.shape[1]:] hidden_states = hidden_states[:, prompt_emb.shape[1]:]
hidden_states = self.norm_out(hidden_states, conditioning) hidden_states = self.final_norm_out(hidden_states, conditioning)
hidden_states = self.proj_out(hidden_states) hidden_states = self.final_proj_out(hidden_states)
hidden_states = self.unpatchify(hidden_states, height, width) hidden_states = self.unpatchify(hidden_states, height, width)
return hidden_states return hidden_states
@@ -399,7 +417,7 @@ class FluxDiTStateDictConverter:
pass pass
def from_diffusers(self, state_dict): def from_diffusers(self, state_dict):
rename_dict = { global_rename_dict = {
"context_embedder": "context_embedder", "context_embedder": "context_embedder",
"x_embedder": "x_embedder", "x_embedder": "x_embedder",
"time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0", "time_text_embed.timestep_embedder.linear_1": "time_embedder.timestep_embedder.0",
@@ -408,9 +426,11 @@ class FluxDiTStateDictConverter:
"time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2", "time_text_embed.guidance_embedder.linear_2": "guidance_embedder.timestep_embedder.2",
"time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0", "time_text_embed.text_embedder.linear_1": "pooled_text_embedder.0",
"time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2", "time_text_embed.text_embedder.linear_2": "pooled_text_embedder.2",
"norm_out.linear": "norm_out.linear", "norm_out.linear": "final_norm_out.linear",
"proj_out": "final_proj_out",
}
rename_dict = {
"proj_out": "proj_out", "proj_out": "proj_out",
"norm1.linear": "norm1_a.linear", "norm1.linear": "norm1_a.linear",
"norm1_context.linear": "norm1_b.linear", "norm1_context.linear": "norm1_b.linear",
"attn.to_q": "attn.a_to_q", "attn.to_q": "attn.a_to_q",
@@ -442,13 +462,11 @@ class FluxDiTStateDictConverter:
} }
state_dict_ = {} state_dict_ = {}
for name, param in state_dict.items(): for name, param in state_dict.items():
if name in rename_dict: if name.endswith(".weight") or name.endswith(".bias"):
state_dict_[rename_dict[name]] = param
elif name.endswith(".weight") or name.endswith(".bias"):
suffix = ".weight" if name.endswith(".weight") else ".bias" suffix = ".weight" if name.endswith(".weight") else ".bias"
prefix = name[:-len(suffix)] prefix = name[:-len(suffix)]
if prefix in rename_dict: if prefix in global_rename_dict:
state_dict_[rename_dict[prefix] + suffix] = param state_dict_[global_rename_dict[prefix] + suffix] = param
elif prefix.startswith("transformer_blocks."): elif prefix.startswith("transformer_blocks."):
names = prefix.split(".") names = prefix.split(".")
names[0] = "blocks" names[0] = "blocks"
@@ -469,7 +487,7 @@ class FluxDiTStateDictConverter:
pass pass
for name in list(state_dict_.keys()): for name in list(state_dict_.keys()):
if ".proj_in_besides_attn." in name: if ".proj_in_besides_attn." in name:
name_ = name.replace(".proj_in_besides_attn.", ".linear.") name_ = name.replace(".proj_in_besides_attn.", ".to_qkv_mlp.")
param = torch.concat([ param = torch.concat([
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")], state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_q.")],
state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")], state_dict_[name.replace(".proj_in_besides_attn.", f".a_to_k.")],
@@ -508,16 +526,16 @@ class FluxDiTStateDictConverter:
"vector_in.in_layer.weight": "pooled_text_embedder.0.weight", "vector_in.in_layer.weight": "pooled_text_embedder.0.weight",
"vector_in.out_layer.bias": "pooled_text_embedder.2.bias", "vector_in.out_layer.bias": "pooled_text_embedder.2.bias",
"vector_in.out_layer.weight": "pooled_text_embedder.2.weight", "vector_in.out_layer.weight": "pooled_text_embedder.2.weight",
"final_layer.linear.bias": "proj_out.bias", "final_layer.linear.bias": "final_proj_out.bias",
"final_layer.linear.weight": "proj_out.weight", "final_layer.linear.weight": "final_proj_out.weight",
"guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias", "guidance_in.in_layer.bias": "guidance_embedder.timestep_embedder.0.bias",
"guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight", "guidance_in.in_layer.weight": "guidance_embedder.timestep_embedder.0.weight",
"guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias", "guidance_in.out_layer.bias": "guidance_embedder.timestep_embedder.2.bias",
"guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight", "guidance_in.out_layer.weight": "guidance_embedder.timestep_embedder.2.weight",
"img_in.bias": "x_embedder.bias", "img_in.bias": "x_embedder.bias",
"img_in.weight": "x_embedder.weight", "img_in.weight": "x_embedder.weight",
"final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight", "final_layer.adaLN_modulation.1.weight": "final_norm_out.linear.weight",
"final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias", "final_layer.adaLN_modulation.1.bias": "final_norm_out.linear.bias",
} }
suffix_rename_dict = { suffix_rename_dict = {
"img_attn.norm.key_norm.scale": "attn.norm_k_a.weight", "img_attn.norm.key_norm.scale": "attn.norm_k_a.weight",
@@ -545,8 +563,8 @@ class FluxDiTStateDictConverter:
"txt_mod.lin.bias": "norm1_b.linear.bias", "txt_mod.lin.bias": "norm1_b.linear.bias",
"txt_mod.lin.weight": "norm1_b.linear.weight", "txt_mod.lin.weight": "norm1_b.linear.weight",
"linear1.bias": "linear.bias", "linear1.bias": "to_qkv_mlp.bias",
"linear1.weight": "linear.weight", "linear1.weight": "to_qkv_mlp.weight",
"linear2.bias": "proj_out.bias", "linear2.bias": "proj_out.bias",
"linear2.weight": "proj_out.weight", "linear2.weight": "proj_out.weight",
"modulation.lin.bias": "norm.linear.bias", "modulation.lin.bias": "norm.linear.bias",

View File

@@ -185,10 +185,19 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
class GeneralLoRAFromPeft: class GeneralLoRAFromPeft:
def __init__(self): def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT] self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT]
def convert_state_dict(self, state_dict, alpha=1.0, device="cuda", torch_dtype=torch.float16): def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
state_dict_ = {} state_dict_ = {}
for key in state_dict: for key in state_dict:
if ".lora_B." not in key: if ".lora_B." not in key:
@@ -202,25 +211,26 @@ class GeneralLoRAFromPeft:
else: else:
lora_weight = alpha * torch.mm(weight_up, weight_down) lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".") keys = key.split(".")
keys.pop(keys.index("lora_B") + 1) if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B")) keys.pop(keys.index("lora_B"))
target_name = ".".join(keys) target_name = ".".join(keys)
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu() state_dict_[target_name] = lora_weight.cpu()
return state_dict_ return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""): def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict() state_dict_model = model.state_dict()
for name, param in state_dict_model.items(): state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
torch_dtype = param.dtype
device = param.device
break
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, device=device, torch_dtype=torch_dtype)
if len(state_dict_lora) > 0: if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.") print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora: for name in state_dict_lora:
state_dict_model[name] += state_dict_lora[name].to( state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype, device=state_dict_model[name].device) dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model) model.load_state_dict(state_dict_model)
@@ -230,13 +240,8 @@ class GeneralLoRAFromPeft:
continue continue
state_dict_model = model.state_dict() state_dict_model = model.state_dict()
try: try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0) state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) == 0: if len(state_dict_lora_) > 0:
continue
for name in state_dict_lora_:
if name not in state_dict_model:
break
else:
return "", "" return "", ""
except: except:
pass pass

View File

@@ -152,7 +152,7 @@ def add_general_parsers(parser):
"--precision", "--precision",
type=str, type=str,
default="16-mixed", default="16-mixed",
choices=["32", "16", "16-mixed"], choices=["32", "16", "16-mixed", "bf16"],
help="Training precision", help="Training precision",
) )
parser.add_argument( parser.add_argument(

View File

@@ -8,10 +8,10 @@ We have implemented a training framework for text-to-image Diffusion models, ena
Image Examples of fine-tuned LoRA. The prompt is "一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉" (for Chinese models) or "a dog is jumping, flowers around the dog, the background is mountains and clouds" (for English models). Image Examples of fine-tuned LoRA. The prompt is "一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉" (for Chinese models) or "a dog is jumping, flowers around the dog, the background is mountains and clouds" (for English models).
||Kolors|Stable Diffusion 3|Hunyuan-DiT| ||FLUX.1-dev|Kolors|Stable Diffusion 3|Hunyuan-DiT|
|-|-|-|-| |-|-|-|-|-|
|Without LoRA|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/9d79ed7a-e8cf-4d98-800a-f182809db318)|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/ddb834a5-6366-412b-93dc-6d957230d66e)|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)| |Without LoRA|![image_without_lora](https://github.com/user-attachments/assets/df62cef6-d54f-4e3d-a602-5dd290079d49)|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/9d79ed7a-e8cf-4d98-800a-f182809db318)|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/ddb834a5-6366-412b-93dc-6d957230d66e)|![image_without_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/1aa21de5-a992-4b66-b14f-caa44e08876e)|
|With LoRA|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/02f62323-6ee5-4788-97a1-549732dbe4f0)|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/8e7b2888-d874-4da4-a75b-11b6b214b9bf)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)| |With LoRA|![image_with_lora](https://github.com/user-attachments/assets/4fd39890-0291-4d19-8a88-d70d0ae18533)|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/02f62323-6ee5-4788-97a1-549732dbe4f0)|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/8e7b2888-d874-4da4-a75b-11b6b214b9bf)|![image_with_lora](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/83a0a41a-691f-4610-8e7b-d8e17c50a282)|
## Install additional packages ## Install additional packages
@@ -99,6 +99,78 @@ General options:
Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope. Access key on ModelScope (https://www.modelscope.cn/). Required if you want to upload the model to ModelScope.
``` ```
### FLUX
The following files will be used for constructing FLUX. You can download them from [huggingface](https://huggingface.co/black-forest-labs/FLUX.1-dev) or [modelscope](https://www.modelscope.cn/models/ai-modelscope/flux.1-dev). You can use the following code to download these files:
```python
from diffsynth import download_models
download_models(["FLUX.1-dev"])
```
```
models/FLUX/
└── FLUX.1-dev
├── ae.safetensors
├── flux1-dev.safetensors
├── text_encoder
│ └── model.safetensors
└── text_encoder_2
├── config.json
├── model-00001-of-00002.safetensors
├── model-00002-of-00002.safetensors
└── model.safetensors.index.json
```
Launch the training task using the following command:
```
CUDA_VISIBLE_DEVICES="0" python examples/train/flux/train_flux_lora.py \
--pretrained_text_encoder_path models/FLUX/FLUX.1-dev/text_encoder/model.safetensors \
--pretrained_text_encoder_2_path models/FLUX/FLUX.1-dev/text_encoder_2 \
--pretrained_dit_path models/FLUX/FLUX.1-dev/flux1-dev.safetensors \
--pretrained_vae_path models/FLUX/FLUX.1-dev/ae.safetensors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 1 \
--steps_per_epoch 500 \
--height 1024 \
--width 1024 \
--center_crop \
--precision "bf16" \
--learning_rate 1e-4 \
--lora_rank 4 \
--lora_alpha 4 \
--use_gradient_checkpointing
```
For more information about the parameters, please use `python examples/train/flux/train_flux_lora.py -h` to see the details.
After training, use `model_manager.load_lora` to load the LoRA for inference.
```python
from diffsynth import ModelManager, FluxImagePipeline
import torch
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
model_manager.load_lora("models/lightning_logs/version_0/checkpoints/epoch=0-step=500.ckpt", lora_alpha=1.0)
pipe = FluxImagePipeline.from_model_manager(model_manager)
torch.manual_seed(0)
image = pipe(
    prompt="a dog is jumping, flowers around the dog, the background is mountains and clouds",
num_inference_steps=30, embedded_guidance=3.5
)
image.save("image_with_lora.jpg")
```
### Kolors ### Kolors
The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). Due to precision overflow issues, we need to download an additional VAE model (from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix)). You can use the following code to download these files: The following files will be used for constructing Kolors. You can download Kolors from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). Due to precision overflow issues, we need to download an additional VAE model (from [huggingface](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) or [modelscope](https://modelscope.cn/models/AI-ModelScope/sdxl-vae-fp16-fix)). You can use the following code to download these files:

View File

@@ -0,0 +1,77 @@
from diffsynth import ModelManager, FluxImagePipeline
from diffsynth.trainers.text_to_image import LightningModelForT2ILoRA, add_general_parsers, launch_training_task
import torch, os, argparse
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class LightningModel(LightningModelForT2ILoRA):
    """Lightning training module that fine-tunes a FLUX DiT with LoRA adapters.

    Builds a `FluxImagePipeline` from the given pretrained weight files, freezes
    the base parameters, and injects LoRA layers into the denoising model.
    """

    def __init__(
        self,
        torch_dtype=torch.float16, pretrained_weights=None,
        learning_rate=1e-4, use_gradient_checkpointing=True,
        lora_rank=4, lora_alpha=4, lora_target_modules="to_q,to_k,to_v,to_out"
    ):
        """
        Args:
            torch_dtype: dtype used to load the pretrained models.
            pretrained_weights: list of paths/ids passed to `ModelManager.load_models`.
                Defaults to an empty list (a `None` sentinel is used to avoid the
                mutable-default-argument pitfall).
            learning_rate: optimizer learning rate, forwarded to the base class.
            use_gradient_checkpointing: enable activation checkpointing during training.
            lora_rank: rank of the LoRA decomposition.
            lora_alpha: LoRA scaling factor.
            lora_target_modules: comma-separated module names receiving LoRA layers.
        """
        super().__init__(learning_rate=learning_rate, use_gradient_checkpointing=use_gradient_checkpointing)
        # Fix: the original used `pretrained_weights=[]`, a shared mutable default.
        if pretrained_weights is None:
            pretrained_weights = []
        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        model_manager.load_models(pretrained_weights)
        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        # Training samples noise levels over the full 1000-step schedule.
        self.pipe.scheduler.set_timesteps(1000)
        self.freeze_parameters()
        self.add_lora_to_model(self.pipe.denoising_model(), lora_rank=lora_rank, lora_alpha=lora_alpha, lora_target_modules=lora_target_modules)
def parse_args():
    """Build and parse the command-line arguments for FLUX LoRA training.

    Returns:
        argparse.Namespace with the model paths, LoRA options, and the
        general training options added by `add_general_parsers`.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    # The four pretrained-weight paths share identical argparse settings,
    # so declare them as data and register them in one loop.
    required_model_paths = (
        (
            "--pretrained_text_encoder_path",
            "Path to pretrained text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder/model.safetensors`.",
        ),
        (
            "--pretrained_text_encoder_2_path",
            "Path to pretrained t5 text encoder model. For example, `models/FLUX/FLUX.1-dev/text_encoder_2`.",
        ),
        (
            "--pretrained_dit_path",
            "Path to pretrained dit model. For example, `models/FLUX/FLUX.1-dev/flux1-dev.safetensors`.",
        ),
        (
            "--pretrained_vae_path",
            "Path to pretrained vae model. For example, `models/FLUX/FLUX.1-dev/ae.safetensors`.",
        ),
    )
    for flag, help_text in required_model_paths:
        parser.add_argument(flag, type=str, default=None, required=True, help=help_text)
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="a_to_qkv,b_to_qkv,ff_a.0,ff_a.2,ff_b.0,ff_b.2,a_to_out,b_to_out,proj_out,norm.linear,norm1_a.linear,norm1_b.linear,to_qkv_mlp",
        help="Layers with LoRA modules.",
    )
    # Append the shared training options (precision, dataset, epochs, ...).
    parser = add_general_parsers(parser)
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    # Map the CLI precision flag to a torch dtype; anything else falls back to fp16.
    precision_to_dtype = {"32": torch.float32, "bf16": torch.bfloat16}
    weight_paths = [
        args.pretrained_text_encoder_path,
        args.pretrained_text_encoder_2_path,
        args.pretrained_dit_path,
        args.pretrained_vae_path,
    ]
    model = LightningModel(
        torch_dtype=precision_to_dtype.get(args.precision, torch.float16),
        pretrained_weights=weight_paths,
        learning_rate=args.learning_rate,
        use_gradient_checkpointing=args.use_gradient_checkpointing,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        lora_target_modules=args.lora_target_modules,
    )
    launch_training_task(model, args)