mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 01:48:13 +00:00
support reference image
This commit is contained in:
@@ -5,6 +5,7 @@ import torch, os, argparse
|
|||||||
import lightning as pl
|
import lightning as pl
|
||||||
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
|
from diffsynth.data.image_pulse import SingleTaskDataset, MultiTaskDataset
|
||||||
from diffsynth.pipelines.flux_image import lets_dance_flux
|
from diffsynth.pipelines.flux_image import lets_dance_flux
|
||||||
|
from diffsynth.models.flux_reference_embedder import FluxReferenceEmbedder
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||||
|
|
||||||
|
|
||||||
@@ -30,6 +31,7 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
model_manager.load_lora(path)
|
model_manager.load_lora(path)
|
||||||
|
|
||||||
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
self.pipe = FluxImagePipeline.from_model_manager(model_manager)
|
||||||
|
self.pipe.reference_embedder = FluxReferenceEmbedder()
|
||||||
|
|
||||||
if quantize is not None:
|
if quantize is not None:
|
||||||
self.pipe.dit.quantize()
|
self.pipe.dit.quantize()
|
||||||
@@ -37,6 +39,8 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||||
|
|
||||||
self.freeze_parameters()
|
self.freeze_parameters()
|
||||||
|
self.pipe.reference_embedder.requires_grad_(True)
|
||||||
|
self.pipe.reference_embedder.train()
|
||||||
self.add_lora_to_model(
|
self.add_lora_to_model(
|
||||||
self.pipe.denoising_model(),
|
self.pipe.denoising_model(),
|
||||||
lora_rank=lora_rank,
|
lora_rank=lora_rank,
|
||||||
@@ -47,6 +51,7 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
|
state_dict_converter=FluxLoRAConverter.align_to_diffsynth_format
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def training_step(self, batch, batch_idx):
|
def training_step(self, batch, batch_idx):
|
||||||
# Data
|
# Data
|
||||||
text, image = batch["text"], batch["image_2"]
|
text, image = batch["text"], batch["image_2"]
|
||||||
@@ -84,6 +89,26 @@ class LightningModel(LightningModelForT2ILoRA):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def configure_optimizers(self):
    """Build the optimizer for every trainable parameter of the pipeline.

    Parameters are collected from ``self.pipe`` as a whole, so anything
    under the pipeline that has ``requires_grad=True`` is optimized, not
    only the parameters of a single sub-model.

    Returns:
        A ``torch.optim.AdamW`` instance using ``self.learning_rate``.
    """
    # Materialize the selection as a list: the optimizer consumes its
    # argument once, and a list is easier to inspect than a lazy filter.
    trainable_params = [
        param for param in self.pipe.parameters() if param.requires_grad
    ]
    return torch.optim.AdamW(trainable_params, lr=self.learning_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def on_save_checkpoint(self, checkpoint):
    """Replace the checkpoint contents with only the trainable weights.

    The incoming ``checkpoint`` mapping is emptied and then refilled, in
    place, with the entries of ``self.pipe.state_dict()`` whose names
    belong to parameters that currently have ``requires_grad=True``.
    If ``self.state_dict_converter`` is set, the selected entries are
    passed through it (with ``alpha=self.lora_alpha``) before being stored.

    Args:
        checkpoint: Mutable mapping to overwrite; presumably the checkpoint
            dict supplied by the Lightning trainer (TODO confirm).
    """
    # Drop everything already placed in the checkpoint; only the
    # trainable weights selected below are kept.
    checkpoint.clear()

    trainable_param_names = {
        name
        for name, param in self.pipe.named_parameters()
        if param.requires_grad
    }
    trainable_state_dict = {
        name: tensor
        for name, tensor in self.pipe.state_dict().items()
        if name in trainable_param_names
    }

    if self.state_dict_converter is not None:
        trainable_state_dict = self.state_dict_converter(
            trainable_state_dict, alpha=self.lora_alpha
        )
    checkpoint.update(trainable_state_dict)
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
Reference in New Issue
Block a user