mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
support reference image
This commit is contained in:
@@ -5,6 +5,7 @@ import torch, os, argparse
|
||||
import lightning as pl
|
||||
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
|
||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||
|
||||
|
||||
@@ -30,6 +31,7 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
model_manager.load_lora(path)
|
||||
|
||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||
self.pipe.reference_embedder = FluxReferenceEmbedder()
|
||||
|
||||
if quantize is not None:
|
||||
self.pipe.dit.quantize()
|
||||
@@ -37,6 +39,8 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
self.freeze_parameters()
|
||||
self.pipe.reference_embedder.requires_grad_(True)
|
||||
self.pipe.reference_embedder.train()
|
||||
self.add_lora_to_model(
|
||||
self.pipe.denoising_model(),
|
||||
lora_rank=lora_rank,
|
||||
@@ -47,6 +51,7 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
|
||||
)
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
text, image = batch["text"], batch["image_2"]
|
||||
@@ -82,6 +87,26 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
# Record log
|
||||
self.log("train_loss", loss, prog_bar=True)
|
||||
return loss
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
|
||||
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
checkpoint.clear()
|
||||
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.named_parameters()))
|
||||
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
|
||||
state_dict = self.pipe.state_dict()
|
||||
lora_state_dict = {}
|
||||
for name, param in state_dict.items():
|
||||
if name in trainable_param_names:
|
||||
lora_state_dict[name] = param
|
||||
if self.state_dict_converter is not None:
|
||||
lora_state_dict = self.state_dict_converter(lora_state_dict, alpha=self.lora_alpha)
|
||||
checkpoint.update(lora_state_dict)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
||||
Reference in New Issue
Block a user