diff --git a/train_flux_reference.py b/train_flux_reference.py
index a6ec6c7..4cc779b 100644
--- a/train_flux_reference.py
+++ b/train_flux_reference.py
@@ -5,6 +5,7 @@ import torch, os, argparse
 import lightning as pl
 from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
 from diffsynth.pipelines.flux_image import lets_dance_flux
+from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
 
 os.environ["TOKENIZERS_PARALLELISM"] = "True"
 
@@ -30,6 +31,7 @@ class LightningModel(LightningModelForT2ILoRA):
             model_manager.load_lora(path)
 
         self.pipe = FluxImagePipeline.from_model_manager(model_manager)
+        self.pipe.reference_embedder = FluxReferenceEmbedder()
 
         if quantize is not None:
             self.pipe.dit.quantize()
@@ -37,6 +39,8 @@
         self.pipe.scheduler.set_timesteps(1000, training=True)
 
         self.freeze_parameters()
+        self.pipe.reference_embedder.requires_grad_(True)
+        self.pipe.reference_embedder.train()
         self.add_lora_to_model(
             self.pipe.denoising_model(),
             lora_rank=lora_rank,
@@ -47,6 +51,7 @@
             state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
         )
 
+
     def training_step(self, batch, batch_idx):
         # Data
         text, image = batch["text"], batch["image_2"]
@@ -82,6 +87,26 @@
         # Record log
         self.log("train_loss", loss, prog_bar=True)
         return loss
+
+
+    def configure_optimizers(self):
+        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
+        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
+        return optimizer
+
+
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint.clear()
+        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
+        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
+        state_dict = self.pipe.state_dict()
+        lora_state_dict = {}
+        for name, param in state_dict.items():
+            if name in trainable_param_names:
+                lora_state_dict[name] = param
+        if self.state_dict_converter is not None:
+            lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
+        checkpoint.update(lora_state_dict)
 
 
 def parse_args():
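
Note on FluxReferenceEmbedder: the module is imported and instantiated in the diff above, but its definition (diffsynth/models/flux_reference_embedder.py) is not part of this patch. The training script assumes nothing beyond it being a torch nn.Module that can be attached to the pipeline as pipe.reference_embedder, un-frozen, and put in train mode. A hypothetical minimal skeleton of that interface, with illustrative (not actual) layer shapes:

import torch.nn as nn

class ReferenceEmbedderSketch(nn.Module):
    # Hypothetical stand-in for FluxReferenceEmbedder: projects
    # reference-image features into a hidden size the DiT can consume.
    # The real input/output shapes in the repo will differ.
    def __init__(self, in_dim=64, hidden_dim=3072):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, reference_features):
        # (batch, tokens, in_dim) -> (batch, tokens, hidden_dim)
        return self.proj(reference_features)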
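
The two hooks appended at the bottom of LightningModel implement one idea: freeze_parameters() turns gradients off everywhere, the LoRA layers and the reference embedder are switched back on, and then both the optimizer and the saved checkpoint are derived from the requires_grad flag alone, so the .ckpt contains only the small trainable subset. A self-contained sketch of that pattern, with ToyPipe as a hypothetical stand-in for self.pipe (the FluxLoRAConverter step is omitted):

import torch
import torch.nn as nn

class ToyPipe(nn.Module):
    def __init__(self):
        super().__init__()
        self.dit = nn.Linear(8, 8)                  # frozen base weights
        self.lora_A = nn.Linear(8, 2, bias=False)   # trainable LoRA factor
        self.reference_embedder = nn.Linear(8, 8)   # trainable embedder

pipe = ToyPipe()
pipe.requires_grad_(False)                    # freeze_parameters()
pipe.lora_A.requires_grad_(True)              # add_lora_to_model(...)
pipe.reference_embedder.requires_grad_(True)  # as in the diff above

# configure_optimizers: hand AdamW only the tensors that still require grad.
optimizer = torch.optim.AdamW(
    (p for p in pipe.parameters() if p.requires_grad), lr=1e-4
)

# on_save_checkpoint: keep only the trainable subset of the state dict.
trainable = {n for n, p in pipe.named_parameters() if p.requires_grad}
small_ckpt = {n: t for n, t in pipe.state_dict().items() if n in trainable}
print(sorted(small_ckpt))  # lora_A and reference_embedder entries only

Because on_save_checkpoint clears the Lightning checkpoint dict before repopulating it, the saved file carries no optimizer state or base-model weights; it can plausibly be fed back through model_manager.load_lora(path), as in the constructor context above, assuming the converter emits a format load_lora understands.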