# LoRA retriever training script, from a mirror of
# https://github.com/modelscope/DiffSynth-Studio.git
from diffsynth import load_state_dict
from diffsynth.models.lora import FluxLoRAConverter
from lora.dataset import LoraDataset
from lora.retriever import TextEncoder, LoRAEncoder
from lora.utils import load_lora
import os
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
from transformers import CLIPTokenizer


class LoRARetrieverTrainingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        # Frozen text encoder, initialized from the FLUX.1-dev text encoder weights.
        self.text_encoder = TextEncoder().to(torch.bfloat16)
        state_dict = load_state_dict("models/FLUX/FLUX.1-dev/text_encoder/model.safetensors")
        self.text_encoder.load_state_dict(TextEncoder.state_dict_converter().from_civitai(state_dict))
        self.text_encoder.requires_grad_(False)
        self.text_encoder.eval()

        # The LoRA encoder is the only trainable component (see `trainable_modules`).
        self.lora_encoder = LoRAEncoder().to(torch.bfloat16)

        self.tokenizer = CLIPTokenizer.from_pretrained("diffsynth/tokenizer_configs/stable_diffusion_3/tokenizer_1")

        # Defaults, overwritten by the first `to()` call (e.g. from `accelerator.prepare`),
        # so that `forward` can run even before the model is moved.
        self.device = torch.device("cpu")
        self.torch_dtype = torch.bfloat16

    def to(self, *args, **kwargs):
        # Track the target device/dtype ourselves, since `nn.Module` does not expose
        # them; `torch._C._nn._parse_to` is the same (private) helper `nn.Module.to` uses.
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def forward(self, batch):
        # Encode the text prompts with the frozen text encoder.
        text = [data["text"] for data in batch]
        input_ids = self.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            max_length=77,
            truncation=True,
        ).input_ids.to(self.device)
        text_emb = self.text_encoder(input_ids)
        # Normalize each embedding to unit length so the dot products below are
        # cosine similarities.
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)

        # Encode each LoRA checkpoint into the same embedding space.
        lora_emb = []
        for data in batch:
            lora = FluxLoRAConverter.align_to_diffsynth_format(load_lora(data["model_file"], device=self.device))
            lora_emb.append(self.lora_encoder(lora))
        lora_emb = torch.concat(lora_emb)
        lora_emb = lora_emb / lora_emb.norm(dim=-1, keepdim=True)

        # Symmetric contrastive (CLIP-style) objective: matching text/LoRA pairs
        # sit on the diagonal of the similarity matrix, and cross-entropy is
        # taken along both the text and the LoRA axes.
        similarity = text_emb @ lora_emb.T
        print(similarity)  # debug print; the diagonal should dominate as training progresses
        loss = -torch.log(torch.softmax(similarity, dim=0).diag()) - torch.log(torch.softmax(similarity, dim=1).diag())
        loss = 10 * loss.mean()  # fixed scale on the loss (akin to a 10x learning rate for this objective)
        return loss

    def trainable_modules(self):
        # Only the LoRA encoder is optimized; the text encoder stays frozen.
        return self.lora_encoder.parameters()
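

# Purely illustrative and never called during training: a minimal, self-contained
# sketch of the objective computed in `forward` above, on random toy embeddings.
# It shows that the -log-softmax-diagonal form equals row- and column-wise
# cross-entropy against diagonal targets (up to the fixed factor of 10).
def _contrastive_loss_demo(batch_size=4, dim=8):
    text_emb = torch.nn.functional.normalize(torch.randn(batch_size, dim), dim=-1)
    lora_emb = torch.nn.functional.normalize(torch.randn(batch_size, dim), dim=-1)
    similarity = text_emb @ lora_emb.T  # (batch_size, batch_size), cosine similarities
    targets = torch.arange(batch_size)  # matching pairs lie on the diagonal
    # dim=1 softmax + diagonal == cross-entropy over rows; dim=0 == over columns.
    loss = torch.nn.functional.cross_entropy(similarity, targets) \
        + torch.nn.functional.cross_entropy(similarity.T, targets)
    return loss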


class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None):
        self.output_path = output_path
        # Accepted for interface compatibility; not used in this script.
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt

    def on_step_end(self, loss):
        # No per-step logging; checkpoints are written once per epoch.
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        # Save only the LoRA encoder (the trainable part), from the main process.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).lora_encoder.state_dict()
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)
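

# Sketch of retrieval-time usage (not part of training): embed a query prompt with
# the frozen text encoder and rank candidate LoRA files by cosine similarity. It
# only reuses the interfaces already exercised above; the checkpoint path and
# epoch id are illustrative. E.g. restore a trained encoder first with
#   model.lora_encoder.load_state_dict(load_state_dict("models/lora_retriever/epoch-0.safetensors"))
def _retrieve_loras_demo(model, prompt, lora_files):
    model.eval()
    with torch.no_grad():
        input_ids = model.tokenizer(
            [prompt], return_tensors="pt", padding="max_length", max_length=77, truncation=True,
        ).input_ids.to(model.device)
        text_emb = model.text_encoder(input_ids)
        text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
        lora_emb = torch.concat([
            model.lora_encoder(FluxLoRAConverter.align_to_diffsynth_format(load_lora(file, device=model.device)))
            for file in lora_files
        ])
        lora_emb = lora_emb / lora_emb.norm(dim=-1, keepdim=True)
        scores = (text_emb @ lora_emb.T).flatten()
    # Candidate files sorted from best to worst match.
    return [lora_files[i] for i in scores.argsort(descending=True).tolist()]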


if __name__ == '__main__':
    model = LoRARetrieverTrainingModel()
    dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=100, loras_per_item=32)
    # batch_size=1: each dataset item is already a batch (a list of text/LoRA
    # records), so the collate function just unwraps it.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    model_logger = ModelLogger("models/lora_retriever")
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    # Train until interrupted; a checkpoint is written at the end of every epoch.
    for epoch_id in range(1000000):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                print(loss)  # debug print of the per-step loss
        model_logger.on_epoch_end(accelerator, model, epoch_id)
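
# Typically run under the Accelerate launcher, so multi-GPU (DDP) and mixed
# precision settings come from `accelerate config`, e.g.:
#     accelerate launch <this_script>.py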