Files
DiffSynth-Studio/lora/train_retriever.py
2025-06-23 17:34:30 +08:00

106 lines
4.1 KiB
Python

from diffsynth import FluxImagePipeline, ModelManager, load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.retriever import TextEncoder, LoRAEncoder
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPModel
class LoRARetrieverTrainingModel(torch.nn.Module):
    """Contrastive training wrapper: frozen text encoder vs. trainable LoRA encoder.

    Each training item pairs a text prompt with a LoRA checkpoint file. The
    text is embedded by a frozen ``TextEncoder`` and the LoRA weights by a
    trainable ``LoRAEncoder``; the loss pulls matching (text, LoRA) pairs
    together and pushes the other pairs in the batch apart.
    """

    # Fallbacks so ``forward`` works even if ``.to(device)`` was never called
    # with an explicit device/dtype; ``to`` overrides them per instance.
    # ``__init__`` builds both encoders in bfloat16, hence that dtype default.
    device = torch.device("cpu")
    torch_dtype = torch.bfloat16

    def __init__(self):
        """Build the frozen text encoder, the trainable LoRA encoder and the tokenizer."""
        super().__init__()
        # Text encoder: weights come from the FLUX.1-dev text_encoder checkpoint
        # and are converted to this project's key layout before loading.
        self.text_encoder = TextEncoder().to(torch.bfloat16)
        state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
        self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
        # Frozen: no gradients, eval mode.
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()
        # The only trainable component (see ``trainable_modules``).
        self.lora_encoder = LoRAEncoder().to(torch.bfloat16)
        self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")

    def to(self, *args, **kwargs):
        """Move/cast the module and remember the requested device and dtype.

        Accepts the same arguments as ``torch.nn.Module.to``. Only values that
        are explicitly given are recorded on ``self.device`` /
        ``self.torch_dtype``. Returns ``self``.
        """
        # Private torch helper that parses the flexible ``.to`` signature.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def forward(self, batch):
        """Compute the symmetric contrastive loss for one batch.

        Args:
            batch: iterable of dicts, each providing a ``"text"`` prompt and a
                ``"model_file"`` path to the LoRA checkpoint it describes.

        Returns:
            Scalar loss tensor (mean of the per-pair loss, scaled by 10).
        """
        text = [data["text"] for data in batch]
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True
        ).input_ids.to(self.device)
        text_emb = self.text_encoder(input_ids)
        # ``norm()`` without ``dim`` is one scalar over the whole tensor.
        # NOTE(review): per-row normalisation would be
        # ``norm(dim=-1, keepdim=True)`` -- confirm which is intended (and
        # what the inference code does) before changing it.
        text_emb = text_emb / text_emb.norm()

        # Each LoRA file is loaded, converted to the DiffSynth key layout and
        # encoded on its own; results are stacked along dim 0.
        # Assumes the encoder yields one row per LoRA -- TODO confirm.
        lora_emb = torch.concat([
            self.lora_encoder(
                FluxLoRAConverter.align_to_diffsynth_format(
                    load_lora(data["model_file"], device=self.device)
                )
            )
            for data in batch
        ])
        lora_emb = lora_emb / lora_emb.norm()

        similarity = text_emb @ lora_emb.T
        # Debug output: dumps the full similarity matrix on every step.
        print(similarity)
        # Item i's text matches item i's LoRA, so the diagonal holds the
        # positives; cross-entropy is taken in both directions
        # (softmax over dim 0 and over dim 1).
        loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
        loss = 10 * loss.mean()
        return loss

    def trainable_modules(self):
        """Return the parameters to optimise: those of the LoRA encoder only."""
        return self.lora_encoder.parameters()
class ModelLogger:
    """Training hook object that writes one LoRA-encoder checkpoint per epoch."""

    def __init__(self, output_path, remove_prefix_in_ckpt=None):
        """Remember where checkpoints go.

        Args:
            output_path: directory that receives the ``epoch-<id>.safetensors`` files.
            remove_prefix_in_ckpt: stored on the instance; not used by the
                methods of this class.
        """
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt

    def on_step_end(self, loss):
        """Per-step hook; intentionally does nothing."""
        return None

    def on_epoch_end(self, accelerator, model, epoch_id):
        """Save the LoRA encoder's weights for this epoch from the main process.

        All processes synchronise first; every process other than the main
        one then returns without touching the filesystem.
        """
        accelerator.wait_for_everyone()
        if not accelerator.is_main_process:
            return

        # Strip whatever wrapper accelerate added, then keep only the
        # trainable encoder's weights.
        unwrapped = accelerator.unwrap_model(model)
        encoder_weights = unwrapped.lora_encoder.state_dict()

        os.makedirs(self.output_path, exist_ok=True)
        checkpoint_file = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
        accelerator.save(encoder_weights, checkpoint_file, safe_serialization=True)
if __name__ == '__main__':
    # Entry point: train the LoRA encoder and save a checkpoint after every epoch.
    model = LoRARetrieverTrainingModel()
    dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=100, loras_per_item=32)
    # batch_size=1 plus a collate that unwraps the single element: the model's
    # forward iterates over what it receives, so presumably each dataset item
    # is already a complete list of samples -- verify against LoraDataset.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
    # Only the LoRA encoder's parameters are optimised.
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    model_logger = ModelLogger("models/lora_retriever")
    # find_unused_parameters=True: the wrapped model contains a frozen text
    # encoder whose parameters never receive gradients.
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    # Effectively "train until interrupted"; checkpoints are written per epoch.
    for epoch_id in range(1000000):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                # Zero gradients only after the step. Zeroing at the top of
                # the block would, with gradient accumulation enabled, wipe
                # the gradients gathered on earlier micro-steps right before
                # the synchronising step. With the default (no accumulation)
                # the two orders are equivalent.
                optimizer.zero_grad()
                print(loss)
        model_logger.on_epoch_end(accelerator, model, epoch_id)