Files
DiffSynth-Studio/lora/train_merger.py
2025-06-23 17:34:30 +08:00

120 lines
4.8 KiB
Python

from diffsynth import FluxImagePipeline, ModelManager
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora
import torch, os
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm
class LoRAMergerTrainingModel(torch.nn.Module):
    """Training wrapper that optimizes a ``LoraPatcher`` on a frozen FLUX pipeline.

    The pipeline is built from the ``FLUX.1-dev`` model id and frozen; only the
    parameters of ``self.lora_patcher`` are left trainable. ``forward`` returns
    a single weighted MSE loss for one training item.
    """

    def __init__(self):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", model_id_list=["FLUX.1-dev"])
        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        self.lora_patcher = LoraPatcher()
        self.pipe.enable_auto_lora()
        self.freeze_parameters()
        self.switch_to_training_mode()
        self.use_gradient_checkpointing = True
        self.state_dict_converter = FluxLoRAConverter.align_to_diffsynth_format
        # Default target device; overwritten by ``to()`` when a device is given.
        self.device = "cuda"
        # Initialised here so the attribute always exists; previously it was
        # only created inside ``to()`` when a dtype was passed explicitly.
        self.torch_dtype = torch.bfloat16

    def to(self, *args, **kwargs):
        """Move/cast the module and remember the requested device and dtype.

        Accepts the same arguments as ``torch.nn.Module.to`` and returns
        ``self``.
        """
        # Same argument parser that ``torch.nn.Module.to`` uses internally.
        device, dtype, _non_blocking, _convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

    def switch_to_training_mode(self):
        """Configure the pipeline scheduler with 1000 training timesteps."""
        self.pipe.scheduler.set_timesteps(1000, training=True)

    def freeze_parameters(self):
        """Freeze the whole pipeline and leave only the LoRA patcher trainable."""
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        # The denoising model is put back in train mode even though its
        # parameters stay frozen (requires_grad is still False).
        self.pipe.denoising_model().train()
        self.lora_patcher.requires_grad_(True)

    @staticmethod
    def _sample_num_lora(num_candidates):
        """Return a random LoRA count in ``[1, max(num_candidates, 2) - 1]``.

        ``torch.randint`` has an exclusive upper bound and raises when
        ``low >= high``; clamping the bound to 2 makes a single-item batch
        yield 1 instead of crashing. For two or more candidates the sampling
        is unchanged from ``torch.randint(1, num_candidates, (1,))``.
        """
        upper_bound = max(num_candidates, 2)
        return int(torch.randint(1, upper_bound, (1,)).item())

    def forward(self, batch):
        """Compute the training loss for one item.

        Args:
            batch: Sequence of dict-like entries. Entry 0 provides ``"text"``
                and ``"image"``; the first ``num_lora`` entries provide
                ``"model_file"`` (path of a LoRA to load).
                NOTE(review): ``"image"`` is presumably an un-batched image
                tensor, since a leading dimension is added here - confirm
                against ``LoraDataset``.

        Returns:
            The MSE between the model prediction and the scheduler's training
            target, multiplied by the scheduler's weight for the sampled
            timestep.
        """
        # Data
        text, image = batch[0]["text"], batch[0]["image"].unsqueeze(0)
        num_lora = self._sample_num_lora(len(batch))
        lora_state_dicts = [
            self.state_dict_converter(load_lora(item["model_file"], device=self.device))
            for item in batch[:num_lora]
        ]
        lora_alphas = None

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.dit,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            lora_state_dicts=lora_state_dicts, lora_alphas=lora_alphas, lora_patcher=self.lora_patcher,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        # Loss is computed in float32 regardless of the model dtype.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        return loss

    def trainable_modules(self):
        """Return an iterator over the LoRA patcher's parameters (for the optimizer)."""
        return self.lora_patcher.parameters()
class ModelLogger:
    """Writes the LoRA patcher weights to disk at the end of each epoch.

    Args:
        output_path: Directory that receives ``epoch-<id>.safetensors`` files.
        remove_prefix_in_ckpt: Optional key prefix stripped from state-dict
            entries before saving. ``None`` (default) saves keys unchanged.
    """

    def __init__(self, output_path, remove_prefix_in_ckpt=None):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt

    def on_step_end(self, loss):
        """Per-step hook; intentionally does nothing."""
        pass

    def _export_state_dict(self, state_dict):
        """Return ``state_dict`` with ``remove_prefix_in_ckpt`` stripped from matching keys.

        Keys that do not start with the prefix are kept as they are. When no
        prefix is configured the input mapping is returned untouched.
        """
        prefix = self.remove_prefix_in_ckpt
        if not prefix:
            return state_dict
        return {
            (name[len(prefix):] if name.startswith(prefix) else name): value
            for name, value in state_dict.items()
        }

    def on_epoch_end(self, accelerator, model, epoch_id):
        """Save the unwrapped model's ``lora_patcher`` state dict for this epoch.

        All processes synchronise first; only the main process writes the file
        ``<output_path>/epoch-<epoch_id>.safetensors``.
        """
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).lora_patcher.state_dict()
            # Previously the configured prefix was stored but never applied.
            state_dict = self._export_state_dict(state_dict)
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)
if __name__ == '__main__':
    # Script entry point: build the model, data pipeline and optimizer, then
    # train the LoRA patcher, saving a checkpoint after every epoch.
    model = LoRAMergerTrainingModel()
    dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
    # batch_size=1 with a collate_fn that unwraps the single element, so the
    # model receives exactly what the dataset returned for one index.
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    model_logger = ModelLogger("models/lora_merger")
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    # Effectively "train until interrupted"; a checkpoint is written per epoch.
    for epoch_id in range(1000000):
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
                # Clear gradients AFTER the step: zeroing before the forward
                # pass would discard accumulated gradients on the sync step
                # if gradient accumulation were ever enabled.
                optimizer.zero_grad()
            model_logger.on_step_end(loss)
        model_logger.on_epoch_end(accelerator, model, epoch_id)