support reference image

Author: Artiprocher
Date:   2025-04-11 15:36:51 +08:00
Parent: d21676b4dc
Commit: 9e78bf5e89


@@ -5,6 +5,7 @@ import torch, os, argparse
 import lightning as pl
 from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
 from diffsynth.pipelines.flux_image import lets_dance_flux
+from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
 os.environ["TOKENIZERS_PARALLELISM"] = "True"
@@ -30,6 +31,7 @@ class LightningModel(LightningModelForT2ILoRA):
         model_manager.load_lora(path)
         self.pipe = FluxImagePipeline.from_model_manager(model_manager)
+        self.pipe.reference_embedder = FluxReferenceEmbedder()
         if quantize is not None:
             self.pipe.dit.quantize()
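This hunk instantiates the new embedder and hangs it on the pipeline as a plain attribute, before the DiT is (optionally) quantized. The diff does not show FluxReferenceEmbedder itself; as a rough mental model only, a module in this position is usually a small trainable torch.nn.Module that turns reference-image features into conditioning the DiT can consume. A hypothetical stand-in, with assumed names and dimensions that are not DiffSynth's API:

import torch

class ToyReferenceEmbedder(torch.nn.Module):
    # Hypothetical stand-in, NOT the real FluxReferenceEmbedder: it simply
    # projects flattened reference latents to an assumed DiT hidden size.
    def __init__(self, latent_dim=16, hidden_dim=3072):
        super().__init__()
        self.proj = torch.nn.Linear(latent_dim, hidden_dim)

    def forward(self, reference_latents):
        # (batch, tokens, latent_dim) -> (batch, tokens, hidden_dim)
        return self.proj(reference_latents)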
@@ -37,6 +39,8 @@ class LightningModel(LightningModelForT2ILoRA):
         self.pipe.scheduler.set_timesteps(1000, training=True)
         self.freeze_parameters()
+        self.pipe.reference_embedder.requires_grad_(True)
+        self.pipe.reference_embedder.train()
         self.add_lora_to_model(
             self.pipe.denoising_model(),
             lora_rank=lora_rank,
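The two added lines complete a freeze-then-unfreeze pattern: freeze_parameters() first disables gradients for the whole pipeline, then gradients and train mode are re-enabled only for the reference embedder (the LoRA layers injected by add_lora_to_model below get the same treatment). A self-contained sketch of the pattern, using plain Linear layers as stand-ins:

import torch

backbone = torch.nn.Linear(8, 8)   # stands in for the frozen DiT
embedder = torch.nn.Linear(8, 8)   # stands in for the reference embedder

for p in backbone.parameters():
    p.requires_grad_(False)        # freeze everything first ...
embedder.requires_grad_(True)      # ... then re-enable only the new module
embedder.train()

trainable = [p for p in list(backbone.parameters()) + list(embedder.parameters()) if p.requires_grad]
print(len(trainable))  # 2: only the embedder's weight and bias survive the freeze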
@@ -47,6 +51,7 @@ class LightningModel(LightningModelForT2ILoRA):
             state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
         )

     def training_step(self, batch, batch_idx):
         # Data
+        text, image = batch["text"], batch["image_2"]
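training_step now reads batch["image_2"] rather than a single image key, which suggests the image_pulse datasets yield more than one image per sample (presumably a reference image alongside the training target). The actual keys produced by SingleTaskDataset / MultiTaskDataset are not visible in this diff; a hypothetical batch of that shape could look like:

import torch

# Assumed layout; the real dataset keys are not shown in this commit.
batch = {
    "text": ["a photo of a cat"],
    "image_1": torch.zeros(1, 3, 1024, 1024),  # assumed: reference image
    "image_2": torch.zeros(1, 3, 1024, 1024),  # assumed: training target
}
text, image = batch["text"], batch["image_2"]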
@@ -82,6 +87,26 @@ class LightningModel(LightningModelForT2ILoRA):
         # Record log
         self.log("train_loss", loss, prog_bar=True)
         return loss

+    def configure_optimizers(self):
+        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
+        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
+        return optimizer
+
+    def on_save_checkpoint(self, checkpoint):
+        checkpoint.clear()
+        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
+        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
+        state_dict = self.pipe.state_dict()
+        lora_state_dict = {}
+        for name, param in state_dict.items():
+            if name in trainable_param_names:
+                lora_state_dict[name] = param
+        if self.state_dict_converter is not None:
+            lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
+        checkpoint.update(lora_state_dict)

 def parse_args():
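The two new methods keep checkpoints small: configure_optimizers feeds AdamW only the parameters that still require gradients, and on_save_checkpoint empties the default Lightning checkpoint dict and refills it with just those trainable tensors (the LoRA layers plus the reference embedder), optionally run through a converter such as FluxLoRAConverter.align_to_diffsynth_format. Reading such a file back is then an ordinary torch.load; a hedged sketch, with an illustrative file name:

import torch

# The saved file holds only the trainable tensors, not a full
# Lightning checkpoint, so it loads as a plain state dict.
state_dict = torch.load("checkpoint.ckpt", map_location="cpu")  # name is illustrative
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))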