lora merger

This commit is contained in:
Artiprocher
2025-04-21 15:48:25 +08:00
parent 04260801a2
commit 44da204dbd
7 changed files with 516 additions and 30 deletions

85
scripts/data_process.py Normal file
View File

@@ -0,0 +1,85 @@
import torch, os, dashscope
import pandas as pd
from tqdm import tqdm
from diffsynth import load_state_dict, hash_state_dict_keys
def search_for_model_file(path, allow_file_extensions=(".safetensors",)):
for file_name in os.listdir(path):
for file_extension in allow_file_extensions:
if file_name.endswith(file_extension):
return os.path.join(path, file_name)
def search_for_cover_images(path, allow_file_extensions=(".png", ".jpg", ".jpeg")):
image_files = []
for file_name in os.listdir(path):
for file_extension in allow_file_extensions:
if file_name.endswith(file_extension):
image_files.append(os.path.join(path, file_name))
break
return image_files
def search_for_lora_data(path):
model_file = search_for_model_file(path)
if "_cover_images_" not in os.listdir(path):
return None
image_files = search_for_cover_images(os.path.join(path, "_cover_images_"))
if model_file is None or len(image_files) == 0:
return None
state_dict = load_state_dict(model_file)
if hash_state_dict_keys(state_dict, with_shape=False) != "52544ae3076666228978b738fbb8b086":
return None
return model_file, image_files
def image_to_text(images=[], prompt="", system_prompt=None):
dashscope.api_key = "xxxxx" # TODO
messages = []
if system_prompt is not None:
messages.append({"role": "system", "content": system_prompt})
if not isinstance(images, list):
images = [images]
messages.append({"role": "user", "content": [{"text": prompt}] + [{"image": image} for image in images]})
response = dashscope.MultiModalConversation.call(model="qwen-vl-max-latest", messages=messages)
response = response["output"]["choices"][0]["message"]["content"][0]["text"]
return response
qwen_i2t_prompt = '''
You are a professional image captioner.
Generate a caption according to the image so that another image generation model can generate the image via the caption. Just return the string description, do not return anything else.
'''.strip()
def data_to_csv(model_file_list, image_file_list, text_list, save_path):
data_df = pd.DataFrame()
data_df["model_file"] = model_file_list
data_df["image_file"] = image_file_list
data_df["text"] = text_list
data_df.to_csv(save_path, index=False, encoding="utf-8-sig")
base_path = "/data/zhiwen/LoRA-Fusion/models/FLUXLoRA"
model_file_list = []
image_file_list = []
text_list = []
for lora_name in tqdm(os.listdir(base_path)):
lora_folder_path = os.path.join(base_path, lora_name)
if os.path.isdir(lora_folder_path):
data = search_for_lora_data(lora_folder_path)
if data is not None:
model_file, image_files = data
for image_file in image_files:
try:
text = image_to_text(image_file, prompt=qwen_i2t_prompt)
except:
continue
model_file_list.append(model_file)
image_file_list.append(image_file)
text_list.append(text)
data_to_csv(model_file_list, image_file_list, text_list, "data/loras.csv")

166
scripts/test.py Normal file
View File

@@ -0,0 +1,166 @@
import torch, shutil, os
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter
import pandas as pd
import torch
import pandas as pd
from PIL import Image
import lightning as pl
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter
from diffsynth.data.video import crop_and_resize
from diffsynth.pipelines.flux_image import lets_dance_flux
from torchvision.transforms import v2
baseline = "trained"
class LoraMerger(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
self.bias = torch.nn.Parameter(torch.randn((dim,)))
self.activation = torch.nn.Sigmoid()
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
def forward(self, base_output, lora_outputs):
global baseline
if baseline == "nolora":
output = base_output
elif baseline == "lora1":
output = base_output + lora_outputs[0]
elif baseline == "lora2":
output = base_output + lora_outputs[1]
elif baseline == "alllora":
output = base_output + lora_outputs.sum(dim=0)
else:
norm_base_output = self.norm_base(base_output)
norm_lora_outputs = self.norm_lora(lora_outputs)
gate = self.activation(
norm_base_output * self.weight_base \
+ norm_lora_outputs * self.weight_lora \
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
)
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output
class LoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoraMerger(dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
return lora_patterns
def forward(self, base_output, lora_outputs, name):
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
class LoraDataset(torch.utils.data.Dataset):
def __init__(self, metadata_path, steps_per_epoch=1000):
data_df = pd.read_csv(metadata_path)
self.model_file = data_df["model_file"].tolist()
self.image_file = data_df["image_file"].tolist()
self.text = data_df["text"].tolist()
self.max_resolution = 1920 * 1080
self.steps_per_epoch = steps_per_epoch
def read_image(self, image_file):
image = Image.open(image_file)
width, height = image.size
if width * height > self.max_resolution:
scale = (width * height / self.max_resolution) ** 0.5
image = image.resize((int(width / scale), int(height / scale)))
width, height = image.size
if width % 16 != 0 or height % 16 != 0:
image = crop_and_resize(image, height // 16 * 16, width // 16 * 16)
image = v2.functional.to_image(image)
image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True)
image = v2.functional.normalize(image, [0.5], [0.5])
return image
def __getitem__(self, index):
data_id = torch.randint(0, len(self.model_file), (1,))[0]
data_id = (data_id + index) % len(self.model_file) # For fixed seed.
data_id_extra = torch.randint(0, len(self.model_file), (1,))[0]
return {
"model_file": self.model_file[data_id],
"model_file_extra": self.model_file[data_id_extra],
"image": self.read_image(self.image_file[data_id]),
"text": self.text[data_id]
}
def __len__(self):
return self.steps_per_epoch
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda")
model_manager.load_models([
"models/FLUX/FLUX.1-dev/text_encoder/model.safetensors",
"models/FLUX/FLUX.1-dev/text_encoder_2",
"models/FLUX/FLUX.1-dev/ae.safetensors",
"models/FLUX/FLUX.1-dev/flux1-dev.safetensors"
])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()
lora_alpahs = [1, 1]
lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lightning_logs/version_13/checkpoints/epoch=2-step=1500.ckpt"))
dataset = LoraDataset("data/loras_picked.csv")
for seed in range(100):
data = dataset[0]
lora_state_dicts = [
FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file"], torch_dtype=torch.bfloat16, device="cuda")),
FluxLoRAConverter.align_to_diffsynth_format(load_state_dict(data["model_file_extra"], torch_dtype=torch.bfloat16, device="cuda")),
]
lora_alpahs = [1, 1]
for pattern in ["nolora", "lora1", "lora2", "alllora", "loramerger"]:
baseline = pattern
image = pipe(
prompt=data["text"],
lora_state_dicts=lora_state_dicts,
lora_alpahs=lora_alpahs,
lora_patcher=lora_patcher,
seed=seed,
)
image.save(f"data/lora_outputs/image_{seed}_{pattern}.jpg")

207
scripts/train.py Normal file
View File

@@ -0,0 +1,207 @@
import torch
import pandas as pd
from PIL import Image
import lightning as pl
from diffsynth import ModelManager, FluxImagePipeline, download_models, load_state_dict
from diffsynth.models.lora import LoRAFromCivitai, FluxLoRAConverter
from diffsynth.data.video import crop_and_resize
from diffsynth.pipelines.flux_image import lets_dance_flux
from torchvision.transforms import v2
class LoraMerger(torch.nn.Module):
def __init__(self, dim):
super().__init__()
self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
self.bias = torch.nn.Parameter(torch.randn((dim,)))
self.activation = torch.nn.Sigmoid()
self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)
def forward(self, base_output, lora_outputs):
norm_base_output = self.norm_base(base_output)
norm_lora_outputs = self.norm_lora(lora_outputs)
gate = self.activation(
norm_base_output * self.weight_base \
+ norm_lora_outputs * self.weight_lora \
+ norm_base_output * norm_lora_outputs * self.weight_cross + self.bias
)
output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
return output
class LoraPatcher(torch.nn.Module):
def __init__(self, lora_patterns=None):
super().__init__()
if lora_patterns is None:
lora_patterns = self.default_lora_patterns()
model_dict = {}
for lora_pattern in lora_patterns:
name, dim = lora_pattern["name"], lora_pattern["dim"]
model_dict[name.replace(".", "___")] = LoraMerger(dim)
self.model_dict = torch.nn.ModuleDict(model_dict)
def default_lora_patterns(self):
lora_patterns = []
lora_dict = {
"attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
"attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
}
for i in range(19):
for suffix in lora_dict:
lora_patterns.append({
"name": f"blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
for i in range(38):
for suffix in lora_dict:
lora_patterns.append({
"name": f"single_blocks.{i}.{suffix}",
"dim": lora_dict[suffix]
})
return lora_patterns
def forward(self, base_output, lora_outputs, name):
return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
class LoraDataset(torch.utils.data.Dataset):
def __init__(self, metadata_path, steps_per_epoch=1000):
data_df = pd.read_csv(metadata_path)
self.model_file = data_df["model_file"].tolist()
self.image_file = data_df["image_file"].tolist()
self.text = data_df["text"].tolist()
self.max_resolution = 1920 * 1080
self.steps_per_epoch = steps_per_epoch
def read_image(self, image_file):
image = Image.open(image_file)
width, height = image.size
if width * height > self.max_resolution:
scale = (width * height / self.max_resolution) ** 0.5
image = image.resize((int(width / scale), int(height / scale)))
width, height = image.size
if width % 16 != 0 or height % 16 != 0:
image = crop_and_resize(image, height // 16 * 16, width // 16 * 16)
image = v2.functional.to_image(image)
image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True)
image = v2.functional.normalize(image, [0.5], [0.5])
return image
def __getitem__(self, index):
data_id = torch.randint(0, len(self.model_file), (1,))[0]
data_id = (data_id + index) % len(self.model_file) # For fixed seed.
data_id_extra = torch.randint(0, len(self.model_file), (1,))[0]
return {
"model_file": self.model_file[data_id],
"model_file_extra": self.model_file[data_id_extra],
"image": self.read_image(self.image_file[data_id]),
"text": self.text[data_id]
}
def __len__(self):
return self.steps_per_epoch
class LightningModel(pl.LightningModule):
def __init__(
self,
learning_rate=1e-4,
use_gradient_checkpointing=True,
state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format,
):
super().__init__()
model_manager = ModelManager(torch_dtype=torch.bfloat16, device=self.device, model_id_list=["FLUX.1-dev"])
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
self.lora_patcher = LoraPatcher()
self.pipe.enable_auto_lora()
self.pipe.scheduler.set_timesteps(1000, training=True)
self.freeze_parameters()
# Set parameters
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.state_dict_converter = state_dict_converter
def freeze_parameters(self):
# Freeze parameters
self.pipe.requires_grad_(False)
self.pipe.eval()
self.pipe.denoising_model().train()
def training_step(self, batch, batch_idx):
# Data
text, image = batch["text"], batch["image"]
lora_state_dicts = [
self.state_dict_converter(load_state_dict(batch["model_file"][0], torch_dtype=torch.bfloat16, device=self.device)),
self.state_dict_converter(load_state_dict(batch["model_file_extra"][0], torch_dtype=torch.bfloat16, device=self.device)),
]
lora_alpahs = [1, 1]
# Prepare input parameters
self.pipe.device = self.device
prompt_emb = self.pipe.encode_prompt(text, positive=True)
if "latents" in batch:
latents = batch["latents"].to(dtype=self.pipe.torch_dtype, device=self.device)
else:
latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
noise = torch.randn_like(latents)
timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
extra_input = self.pipe.prepare_extra_input(latents)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
training_target = self.pipe.scheduler.training_target(latents, noise, timestep)
# Compute loss
noise_pred = lets_dance_flux(
self.pipe.dit,
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
lora_state_dicts=lora_state_dicts, lora_alpahs=lora_alpahs, lora_patcher=self.lora_patcher,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.lora_patcher.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
checkpoint.update(self.lora_patcher.state_dict())
if __name__ == '__main__':
model = LightningModel(learning_rate=1e-4)
dataset = LoraDataset("data/loras.csv", steps_per_epoch=500)
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1)
trainer = pl.Trainer(
max_epochs=100000,
accelerator="gpu",
devices="auto",
precision="bf16",
strategy="auto",
default_root_dir="./models",
accumulate_grad_batches=1,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)],
)
trainer.fit(model=model, train_dataloaders=train_loader)