Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 16:50:47 +00:00)
lora retrieval
lora/__init__.py (Normal file, 0 lines)
lora/dataset.py (Normal file, 54 lines)
@@ -0,0 +1,54 @@
import torch, os
import pandas as pd
from PIL import Image
from torchvision.transforms import v2
from diffsynth.data.video import crop_and_resize


class LoraDataset(torch.utils.data.Dataset):
    def __init__(self, base_path, metadata_path, steps_per_epoch=1000, loras_per_item=1):
        self.base_path = base_path
        data_df = pd.read_csv(metadata_path)
        self.model_file = data_df["model_file"].tolist()
        self.image_file = data_df["image_file"].tolist()
        self.text = data_df["text"].tolist()
        self.max_resolution = 1920 * 1080
        self.steps_per_epoch = steps_per_epoch
        self.loras_per_item = loras_per_item

    def read_image(self, image_file):
        image = Image.open(image_file).convert("RGB")
        width, height = image.size
        # Downscale overly large images so that width * height <= max_resolution.
        if width * height > self.max_resolution:
            scale = (width * height / self.max_resolution) ** 0.5
            image = image.resize((int(width / scale), int(height / scale)))
            width, height = image.size
        # Crop and resize so both dimensions are multiples of 16.
        if width % 16 != 0 or height % 16 != 0:
            image = crop_and_resize(image, height // 16 * 16, width // 16 * 16)
        image = v2.functional.to_image(image)
        image = v2.functional.to_dtype(image, dtype=torch.float32, scale=True)
        image = v2.functional.normalize(image, [0.5], [0.5])
        return image

    def get_data(self, data_id):
        data = {
            "model_file": os.path.join(self.base_path, self.model_file[data_id]),
            "image": self.read_image(os.path.join(self.base_path, self.image_file[data_id])),
            "text": self.text[data_id]
        }
        return data

    def __getitem__(self, index):
        # Sample `loras_per_item` entries; offsetting by `index` keeps sampling
        # reproducible for a fixed random seed.
        data = []
        while len(data) < self.loras_per_item:
            data_id = torch.randint(0, len(self.model_file), (1,))[0]
            data_id = (data_id + index) % len(self.model_file)
            data.append(self.get_data(data_id))
        return data

    def __len__(self):
        return self.steps_per_epoch
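For reference, a minimal usage sketch of this dataset (paths and the batch_size=1, identity-style collate_fn mirror train_merger.py below; the final print is illustrative only):

import torch
from lora.dataset import LoraDataset

dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
# Each item is a list of `loras_per_item` dicts with keys "model_file", "image", "text".
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda x: x[0])
batch = next(iter(dataloader))
print(batch[0]["text"], batch[0]["image"].shape)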
lora/merger.py (Normal file, 61 lines)
@@ -0,0 +1,61 @@
import torch


class LoraMerger(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight_base = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_lora = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_cross = torch.nn.Parameter(torch.randn((dim,)))
        self.weight_out = torch.nn.Parameter(torch.ones((dim,)))
        self.bias = torch.nn.Parameter(torch.randn((dim,)))
        self.activation = torch.nn.Sigmoid()
        self.norm_base = torch.nn.LayerNorm(dim, eps=1e-5)
        self.norm_lora = torch.nn.LayerNorm(dim, eps=1e-5)

    def forward(self, base_output, lora_outputs):
        # Gate each LoRA output against the base output, then add the gated
        # sum of LoRA outputs back onto the base output.
        norm_base_output = self.norm_base(base_output)
        norm_lora_outputs = self.norm_lora(lora_outputs)
        gate = self.activation(
            norm_base_output * self.weight_base
            + norm_lora_outputs * self.weight_lora
            + norm_base_output * norm_lora_outputs * self.weight_cross
            + self.bias
        )
        output = base_output + (self.weight_out * gate * lora_outputs).sum(dim=0)
        return output


class LoraPatcher(torch.nn.Module):
    def __init__(self, lora_patterns=None):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()
        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"]
            model_dict[name.replace(".", "___")] = LoraMerger(dim)
        self.model_dict = torch.nn.ModuleDict(model_dict)

    def default_lora_patterns(self):
        # Output dimensions of the FLUX transformer layers covered by the default patterns.
        lora_patterns = []
        lora_dict = {
            "attn.a_to_qkv": 9216, "attn.a_to_out": 3072, "ff_a.0": 12288, "ff_a.2": 3072, "norm1_a.linear": 18432,
            "attn.b_to_qkv": 9216, "attn.b_to_out": 3072, "ff_b.0": 12288, "ff_b.2": 3072, "norm1_b.linear": 18432,
        }
        for i in range(19):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix]
                })
        lora_dict = {"to_qkv_mlp": 21504, "proj_out": 3072, "norm.linear": 9216}
        for i in range(38):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"single_blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix]
                })
        return lora_patterns

    def forward(self, base_output, lora_outputs, name):
        # ModuleDict keys cannot contain ".", so layer names are stored with "___".
        return self.model_dict[name.replace(".", "___")](base_output, lora_outputs)
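As a quick sanity check, LoraMerger can be exercised on dummy tensors; the sequence length and the choice of two LoRAs below are illustrative assumptions, only the last dimension has to match `dim`:

import torch
from lora.merger import LoraMerger

merger = LoraMerger(dim=3072)
base_output = torch.randn(1, 512, 3072)     # hypothetical base layer output
lora_outputs = torch.randn(2, 512, 3072)    # hypothetical outputs of two LoRAs, stacked on dim 0
merged = merger(base_output, lora_outputs)  # gated sum of LoRA outputs added to the base output
print(merged.shape)                         # torch.Size([1, 512, 3072])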
lora/retriever.py (Normal file, 149 lines)
@@ -0,0 +1,149 @@
import torch
from diffsynth import SDTextEncoder
from diffsynth.models.sd3_text_encoder import SD3TextEncoder1StateDictConverter
from diffsynth.models.sd_text_encoder import CLIPEncoderLayer


class LoRALayerBlock(torch.nn.Module):
    def __init__(self, L, dim_in):
        super().__init__()
        # Learnable probe of shape (1, L, dim_in) that is pushed through the LoRA weights.
        self.x = torch.nn.Parameter(torch.randn(1, L, dim_in))

    def forward(self, lora_A, lora_B):
        out = self.x @ lora_A.T @ lora_B.T
        return out


class LoRAEmbedder(torch.nn.Module):
    def __init__(self, lora_patterns=None, L=1, out_dim=2048):
        super().__init__()
        if lora_patterns is None:
            lora_patterns = self.default_lora_patterns()

        # One probe per patched layer.
        model_dict = {}
        for lora_pattern in lora_patterns:
            name, dim = lora_pattern["name"], lora_pattern["dim"][0]
            model_dict[name.replace(".", "___")] = LoRALayerBlock(L, dim)
        self.model_dict = torch.nn.ModuleDict(model_dict)

        # One projection per layer type, shared across blocks.
        proj_dict = {}
        for lora_pattern in lora_patterns:
            layer_type, dim = lora_pattern["type"], lora_pattern["dim"][1]
            if layer_type.replace(".", "___") not in proj_dict:
                proj_dict[layer_type.replace(".", "___")] = torch.nn.Linear(dim, out_dim)
        self.proj_dict = torch.nn.ModuleDict(proj_dict)

        self.lora_patterns = lora_patterns

    def default_lora_patterns(self):
        # (in_dim, out_dim) of each patched linear layer in the FLUX transformer.
        lora_patterns = []
        lora_dict = {
            "attn.a_to_qkv": (3072, 9216), "attn.a_to_out": (3072, 3072), "ff_a.0": (3072, 12288), "ff_a.2": (12288, 3072), "norm1_a.linear": (3072, 18432),
            "attn.b_to_qkv": (3072, 9216), "attn.b_to_out": (3072, 3072), "ff_b.0": (3072, 12288), "ff_b.2": (12288, 3072), "norm1_b.linear": (3072, 18432),
        }
        for i in range(19):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
        lora_dict = {"to_qkv_mlp": (3072, 21504), "proj_out": (15360, 3072), "norm.linear": (3072, 9216)}
        for i in range(38):
            for suffix in lora_dict:
                lora_patterns.append({
                    "name": f"single_blocks.{i}.{suffix}",
                    "dim": lora_dict[suffix],
                    "type": suffix,
                })
        return lora_patterns

    def forward(self, lora):
        # Embed a LoRA state dict by probing every (lora_A, lora_B) pair and
        # projecting the result into a shared embedding space.
        lora_emb = []
        for lora_pattern in self.lora_patterns:
            name, layer_type = lora_pattern["name"], lora_pattern["type"]
            lora_A = lora[name + ".lora_A.default.weight"]
            lora_B = lora[name + ".lora_B.default.weight"]
            lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
            lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
            lora_emb.append(lora_out)
        lora_emb = torch.concat(lora_emb, dim=1)
        return lora_emb


class TextEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, vocab_size=49408, max_position_embeddings=77, num_encoder_layers=12, encoder_intermediate_size=3072):
        super().__init__()

        # token_embedding
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)

        # position_embeds (this is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, input_ids, clip_skip=1):
        embeds = self.token_embedding(input_ids) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
            if encoder_id + clip_skip == len(self.encoders):
                break
        embeds = self.final_layer_norm(embeds)
        # Pool at the end-of-text token (the largest token id in CLIP's vocabulary).
        pooled_embeds = embeds[torch.arange(embeds.shape[0]), input_ids.to(dtype=torch.int).argmax(dim=-1)]
        return pooled_embeds

    @staticmethod
    def state_dict_converter():
        return SD3TextEncoder1StateDictConverter()


class LoRAEncoder(torch.nn.Module):
    def __init__(self, embed_dim=768, max_position_embeddings=304, num_encoder_layers=2, encoder_intermediate_size=3072, L=1):
        super().__init__()
        max_position_embeddings *= L

        # Embedder
        self.embedder = LoRAEmbedder(L=L, out_dim=embed_dim)

        # position_embeds (this is a fixed tensor)
        self.position_embeds = torch.nn.Parameter(torch.zeros(1, max_position_embeddings, embed_dim))

        # encoders
        self.encoders = torch.nn.ModuleList([CLIPEncoderLayer(embed_dim, encoder_intermediate_size) for _ in range(num_encoder_layers)])

        # attn_mask
        self.attn_mask = self.attention_mask(max_position_embeddings)

        # final_layer_norm
        self.final_layer_norm = torch.nn.LayerNorm(embed_dim)

    def attention_mask(self, length):
        mask = torch.empty(length, length)
        mask.fill_(float("-inf"))
        mask.triu_(1)
        return mask

    def forward(self, lora):
        embeds = self.embedder(lora) + self.position_embeds
        attn_mask = self.attn_mask.to(device=embeds.device, dtype=embeds.dtype)
        for encoder_id, encoder in enumerate(self.encoders):
            embeds = encoder(embeds, attn_mask=attn_mask)
        embeds = self.final_layer_norm(embeds)
        # Mean-pool over positions to obtain a single embedding per LoRA.
        embeds = embeds.mean(dim=1)
        return embeds
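A minimal sketch of how a LoRA file is turned into a retrieval embedding (the checkpoint path comes from test_retriever.py below; the LoRA file path is a placeholder, and a CUDA device is assumed as in the scripts):

import torch
from diffsynth import load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from lora.retriever import LoRAEncoder
from lora.utils import load_lora

lora_encoder = LoRAEncoder().to(dtype=torch.bfloat16, device="cuda")
lora_encoder.load_state_dict(load_state_dict("models/lora_retriever/epoch-3.safetensors"))
lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora("data/lora/models/example.safetensors", device="cuda"))  # placeholder file name
with torch.no_grad():
    lora_emb = lora_encoder(lora)  # shape (1, 768), comparable to TextEncoder's pooled text embedding
print(lora_emb.shape)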
lora/test_merger.py (Normal file, 46 lines)
@@ -0,0 +1,46 @@
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm


model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()

lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-3.safetensors"))

dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)

# For each seed, render a reference image without LoRA, one image per single LoRA,
# and one image with all sampled LoRAs combined through the trained LoraPatcher.
for seed in range(100):
    batch = dataset[0]
    num_lora = torch.randint(1, len(batch), (1,))[0]
    lora_state_dicts = [
        FluxLoRAConverter.align_to_diffsynth_format(load_lora(batch[i]["model_file"], device="cuda")) for i in range(num_lora)
    ]
    image = pipe(
        prompt=batch[0]["text"],
        seed=seed,
    )
    image.save(f"data/lora/lora_outputs/image_{seed}_nolora.jpg")
    for i in range(num_lora):
        image = pipe(
            prompt=batch[0]["text"],
            lora_state_dicts=[lora_state_dicts[i]],
            lora_patcher=lora_patcher,
            seed=seed,
        )
        image.save(f"data/lora/lora_outputs/image_{seed}_{i}.jpg")
    image = pipe(
        prompt=batch[0]["text"],
        lora_state_dicts=lora_state_dicts,
        lora_patcher=lora_patcher,
        seed=seed,
    )
    image.save(f"data/lora/lora_outputs/image_{seed}_merger.jpg")
lora/test_retriever.py (Normal file, 148 lines)
@@ -0,0 +1,148 @@
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.retriever import TextEncoder, LoRAEncoder
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPModel
import pandas as pd


class LoRARetrieverTrainingModel(torch.nn.Module):
    def __init__(self, pretrained_path):
        super().__init__()

        # Frozen CLIP text encoder initialized from the FLUX.1-dev text encoder weights.
        self.text_encoder = TextEncoder().to(torch.bfloat16)
        state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
        self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()

        # LoRA encoder restored from a previously trained checkpoint.
        self.lora_encoder = LoRAEncoder().to(torch.bfloat16)
        state_dict = load_state_dict(pretrained_path)
        self.lora_encoder.load_state_dict(state_dict)

        self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def forward(self, batch):
        text = [data["text"] for data in batch]
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True
        ).input_ids.to(self.device)
        text_emb = self.text_encoder(input_ids)
        text_emb = text_emb / text_emb.norm()

        lora_emb = []
        for data in batch:
            lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
            lora_emb.append(self.lora_encoder(lora))
        lora_emb = torch.concat(lora_emb)
        lora_emb = lora_emb / lora_emb.norm()

        # Symmetric contrastive loss over the text-LoRA similarity matrix.
        similarity = text_emb @ lora_emb.T
        print(similarity)
        loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
        loss = 10 * loss.mean()
        return loss

    def trainable_modules(self):
        return self.lora_encoder.parameters()

    @torch.no_grad()
    def process_lora_list(self, lora_list):
        # Pre-compute embeddings for every LoRA file available for retrieval.
        lora_emb = []
        for lora in tqdm(lora_list):
            lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(lora, device="cuda"))
            lora_emb.append(self.lora_encoder(lora))
        lora_emb = torch.concat(lora_emb)
        lora_emb = lora_emb / lora_emb.norm()
        self.lora_emb = lora_emb
        self.lora_list = lora_list

    @torch.no_grad()
    def retrieve(self, text, k=1):
        # Encode the prompt and return the k most similar LoRAs with their ModelScope URLs.
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True
        ).input_ids.to(self.device)
        text_emb = self.text_encoder(input_ids)
        text_emb = text_emb / text_emb.norm()

        similarity = text_emb @ self.lora_emb.T
        topk = torch.topk(similarity, k, dim=1).indices[0]

        lora_list = []
        model_url_list = []
        for lora_id in topk:
            print(self.lora_list[lora_id])
            lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(self.lora_list[lora_id], device="cuda"))
            lora_list.append(lora)
            model_id = self.lora_list[lora_id].split("/")[3:5]
            model_url_list.append(f"https://www.modelscope.cn/models/{model_id[0]}/{model_id[1]}")
        return lora_list, model_url_list


model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
pipe = FluxImagePipeline.from_model_manager(model_manager)
pipe.enable_auto_lora()

lora_patcher = LoraPatcher().to(dtype=torch.bfloat16, device="cuda")
lora_patcher.load_state_dict(load_state_dict("models/lora_merger/epoch-9.safetensors"))

retriever = LoRARetrieverTrainingModel("models/lora_retriever/epoch-3.safetensors").to(dtype=torch.bfloat16, device="cuda")
retriever.process_lora_list(list(set("data/lora/models/" + i for i in pd.read_csv("data/lora/lora_dataset_1000.csv")["model_file"])))

dataset = LoraDataset("data/lora/models", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=1)

text_list = []
model_url_list = []
for seed in range(100):
    text = dataset[0][0]["text"]
    print(text)
    loras, urls = retriever.retrieve(text, k=3)
    print(urls)
    image = pipe(
        prompt=text,
        seed=seed,
    )
    image.save(f"data/lora/lora_outputs/image_{seed}_top0.jpg")
    # Only the combination of the top-3 retrieved LoRAs is rendered here.
    for i in range(2, 3):
        image = pipe(
            prompt=text,
            lora_state_dicts=loras[:i+1],
            lora_patcher=lora_patcher,
            seed=seed,
        )
        image.save(f"data/lora/lora_outputs/image_{seed}_top{i+1}.jpg")

    text_list.append(text)
    model_url_list.append(urls)
df = pd.DataFrame()
df["text"] = text_list
df["models"] = [",".join(i) for i in model_url_list]
df.to_csv("data/lora/lora_outputs/metadata.csv", index=False, encoding="utf-8-sig")
lora/train_merger.py (Normal file, 119 lines)
@@ -0,0 +1,119 @@
from diffsynth import FluxImagePipeline, ModelManager
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm


class LoRAMergerTrainingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", model_id_list=["FLUX.1-dev"])
        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        self.lora_patcher = LoraPatcher()
        self.pipe.enable_auto_lora()
        self.freeze_parameters()
        self.switch_to_training_mode()
        self.use_gradient_checkpointing = True
        self.state_dict_converter = FluxLoRAConverter.align_to_diffsynth_format
        self.device = "cuda"

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def switch_to_training_mode(self):
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def freeze_parameters(self):
        # Only the LoraPatcher is trained; the diffusion pipeline stays frozen.
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()
        self.lora_patcher.requires_grad_(True)

    def forward(self, batch):
        # Data
        text, image = batch[0]["text"], batch[0]["image"].unsqueeze(0)
        num_lora = torch.randint(1, len(batch), (1,))[0]
        lora_state_dicts = [
            self.state_dict_converter(load_lora(batch[i]["model_file"], device=self.device)) for i in range(num_lora)
        ]
        lora_alphas = None

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.dit,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            lora_state_dicts=lora_state_dicts, lora_alphas=lora_alphas, lora_patcher=self.lora_patcher,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        return loss

    def trainable_modules(self):
        return self.lora_patcher.parameters()


class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt

    def on_step_end(self, loss):
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        # Save the LoraPatcher weights on the main process at the end of every epoch.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).lora_patcher.state_dict()
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)


if __name__ == '__main__':
    model = LoRAMergerTrainingModel()
    dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    model_logger = ModelLogger("models/lora_merger")
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch_id in range(1000000):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
        model_logger.on_epoch_end(accelerator, model, epoch_id)
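These training scripts are presumably launched with Accelerate from the repository root, e.g. `accelerate launch lora/train_merger.py` with the repository root on PYTHONPATH so that the `lora` package imports resolve; the same applies to lora/train_retriever.py below. The exact invocation and Accelerate config are assumptions, not part of this commit.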
lora/train_retriever.py (Normal file, 105 lines)
@@ -0,0 +1,105 @@
from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.retriever import TextEncoder, LoRAEncoder
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPModel


class LoRARetrieverTrainingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Frozen CLIP text encoder initialized from the FLUX.1-dev text encoder weights.
        self.text_encoder = TextEncoder().to(torch.bfloat16)
        state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
        self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()

        # The LoRA encoder is the only trainable component.
        self.lora_encoder = LoRAEncoder().to(torch.bfloat16)

        self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")

    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def forward(self, batch):
        text = [data["text"] for data in batch]
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True
        ).input_ids.to(self.device)
        text_emb = self.text_encoder(input_ids)
        text_emb = text_emb / text_emb.norm()

        lora_emb = []
        for data in batch:
            lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
            lora_emb.append(self.lora_encoder(lora))
        lora_emb = torch.concat(lora_emb)
        lora_emb = lora_emb / lora_emb.norm()

        # Symmetric contrastive loss over the text-LoRA similarity matrix.
        similarity = text_emb @ lora_emb.T
        print(similarity)
        loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
        loss = 10 * loss.mean()
        return loss

    def trainable_modules(self):
        return self.lora_encoder.parameters()


class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt

    def on_step_end(self, loss):
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        # Save the LoRA encoder weights on the main process at the end of every epoch.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).lora_encoder.state_dict()
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)


if __name__ == '__main__':
    model = LoRARetrieverTrainingModel()
    dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=100, loras_per_item=32)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    model_logger = ModelLogger("models/lora_retriever")
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch_id in range(1000000):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                print(loss)
        model_logger.on_epoch_end(accelerator, model, epoch_id)
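A note on the setup above: each step draws loras_per_item=32 (prompt, LoRA) pairs, embeds the prompts with the frozen text encoder and the LoRAs with the trainable LoRA encoder, and applies a symmetric cross-entropy over both directions of the 32x32 similarity matrix, so matching pairs on the diagonal are pulled together while mismatched pairs are pushed apart; the factor of 10 simply rescales the loss magnitude.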
lora/utils.py (Normal file, 12 lines)
@@ -0,0 +1,12 @@
from diffsynth import load_state_dict
import math, torch


def load_lora(file_path, device):
    # Load a LoRA checkpoint and fold its alpha/rank scaling into the weights.
    sd = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device)
    scale = math.sqrt(sd["lora_unet_single_blocks_9_modulation_lin.alpha"] / sd["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0])
    if scale != 1:
        sd = {i: sd[i] * scale for i in sd}
    return sd
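Note on the scaling: multiplying every tensor by sqrt(alpha / rank) means each lora_down / lora_up pair's product ends up scaled by alpha / rank overall, which folds the usual LoRA alpha convention directly into the weights; for example, alpha = 16 and rank = 64 give a per-tensor factor of sqrt(16 / 64) = 0.5 and an effective product scale of 0.25.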