mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
support reference image
This commit is contained in:
@@ -54,7 +54,7 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# Data
|
||||
text, image = batch["text"], batch["image_2"]
|
||||
text, image = batch["instruction"], batch["image_2"]
|
||||
image_ref = batch["image_1"]
|
||||
|
||||
# Prepare input parameters
|
||||
@@ -77,8 +77,9 @@ class LightningModel(LightningModelForT2ILoRA):
|
||||
# Compute loss
|
||||
noise_pred = lets_dance_flux(
|
||||
self.pipe.denoising_model(),
|
||||
reference_embedder=self.pipe.reference_embedder,
|
||||
hidden_states_ref=hidden_states_ref,
|
||||
latents=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
use_gradient_checkpointing=self.use_gradient_checkpointing
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
@@ -191,19 +192,23 @@ if __name__ == '__main__':
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_change_add_remove.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_zoomin_zoomout.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
|
||||
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_style_transfer.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_faceid.json",
|
||||
height=512, width=512,
|
||||
),
|
||||
],
|
||||
dataset_weight=(4, 2, 2, 1),
|
||||
|
||||
Reference in New Issue
Block a user