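"""Train a LoRA merger for FLUX.1-dev.

The only trainable component is a LoraPatcher that learns to combine several
LoRA checkpoints on top of a frozen FLUX.1-dev pipeline. Each step applies a
random subset of the batch's LoRAs and minimizes the scheduler's
denoising-target MSE, with distributed training handled by Accelerate.
"""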
from diffsynth import FluxImagePipeline, ModelManager
from diffsynth.models.lora import FluxLoRAConverter
from diffsynth.pipelines.flux_image import lets_dance_flux
from lora.dataset import LoraDataset
from lora.merger import LoraPatcher
from lora.utils import load_lora
import os
import torch
from accelerate import Accelerator, DistributedDataParallelKwargs
from tqdm import tqdm


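# Wraps the frozen FLUX pipeline and the trainable LoraPatcher in a single
# torch.nn.Module so Accelerate can prepare and distribute them together.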
class LoRAMergerTrainingModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu", model_id_list=["FLUX.1-dev"])
        self.pipe = FluxImagePipeline.from_model_manager(model_manager)
        self.lora_patcher = LoraPatcher()
        self.pipe.enable_auto_lora()
        self.freeze_parameters()
        self.switch_to_training_mode()
        self.use_gradient_checkpointing = True
        self.state_dict_converter = FluxLoRAConverter.align_to_diffsynth_format
        self.device = "cuda"

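    # Override .to() so the tracked device/dtype stay in sync; forward() uses
    # self.device when loading LoRA weights and placing tensors.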
    def to(self, *args, **kwargs):
        device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
        if device is not None:
            self.device = device
        if dtype is not None:
            self.torch_dtype = dtype
        super().to(*args, **kwargs)
        return self

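    # Put the scheduler in training mode with the full 1000-step timestep table.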
    def switch_to_training_mode(self):
        self.pipe.scheduler.set_timesteps(1000, training=True)

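    # Freeze every pipeline weight; only the LoRA patcher receives gradients.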
    def freeze_parameters(self):
        self.pipe.requires_grad_(False)
        self.pipe.eval()
        self.pipe.denoising_model().train()
        self.lora_patcher.requires_grad_(True)

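    # One training step: encode the sample, add noise at a random timestep,
    # merge a random number of LoRAs via the patcher, and score the prediction
    # against the scheduler's training target.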
    def forward(self, batch):
        # Data
        text, image = batch[0]["text"], batch[0]["image"].unsqueeze(0)
        num_lora = torch.randint(1, len(batch), (1,))[0]  # merge 1..len(batch)-1 LoRAs (high is exclusive)
        lora_state_dicts = [
            self.state_dict_converter(load_lora(batch[i]["model_file"], device=self.device)) for i in range(num_lora)
        ]
        lora_alphas = None

        # Prepare input parameters
        self.pipe.device = self.device
        prompt_emb = self.pipe.encode_prompt(text, positive=True)
        latents = self.pipe.vae_encoder(image.to(dtype=self.pipe.torch_dtype, device=self.device))
        noise = torch.randn_like(latents)
        timestep_id = torch.randint(0, self.pipe.scheduler.num_train_timesteps, (1,))
        timestep = self.pipe.scheduler.timesteps[timestep_id].to(self.device)
        extra_input = self.pipe.prepare_extra_input(latents)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
        training_target = self.pipe.scheduler.training_target(latents, noise, timestep)

        # Compute loss
        noise_pred = lets_dance_flux(
            self.pipe.dit,
            hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
            lora_state_dicts=lora_state_dicts, lora_alphas=lora_alphas, lora_patcher=self.lora_patcher,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
        loss = loss * self.pipe.scheduler.training_weight(timestep)
        return loss

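    # Only the patcher's parameters are handed to the optimizer.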
    def trainable_modules(self):
        return self.lora_patcher.parameters()


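# Minimal logger that checkpoints the patcher weights at the end of each epoch.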
class ModelLogger:
    def __init__(self, output_path, remove_prefix_in_ckpt=None):
        self.output_path = output_path
        self.remove_prefix_in_ckpt = remove_prefix_in_ckpt

    def on_step_end(self, loss):
        pass

    def on_epoch_end(self, accelerator, model, epoch_id):
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            state_dict = accelerator.unwrap_model(model).lora_patcher.state_dict()
            os.makedirs(self.output_path, exist_ok=True)
            path = os.path.join(self.output_path, f"epoch-{epoch_id}.safetensors")
            accelerator.save(state_dict, path, safe_serialization=True)


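# Entry point: batch_size=1 with a collate_fn that unwraps the singleton batch,
# AdamW over the patcher parameters only, and Accelerate handling (optional) DDP.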
if __name__ == '__main__':
    model = LoRAMergerTrainingModel()
    dataset = LoraDataset("data/lora/models/", "data/lora/lora_dataset_1000.csv", steps_per_epoch=800, loras_per_item=4)
    dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=1, num_workers=1, collate_fn=lambda x: x[0])
    optimizer = torch.optim.AdamW(model.trainable_modules(), lr=1e-4)
    model_logger = ModelLogger("models/lora_merger")
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)])
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for epoch_id in range(1000000):  # effectively: train until manually stopped
        for data in tqdm(dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                loss = model(data)
                accelerator.backward(loss)
                optimizer.step()
        model_logger.on_epoch_end(accelerator, model, epoch_id)
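
# Reloading a trained patcher from a saved checkpoint -- a sketch, assuming a
# fresh LoraPatcher() has the same state-dict keys as the one saved above;
# `load_file` comes from the safetensors package:
#
#     from safetensors.torch import load_file
#     patcher = LoraPatcher()
#     patcher.load_state_dict(load_file("models/lora_merger/epoch-0.safetensors"))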