mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
refine training
This commit is contained in:
28
train.py
28
train.py
@@ -254,9 +254,12 @@ class UnifiedModel(pl.LightningModule):
|
||||
self.pipe.vae_decoder.requires_grad_(False)
|
||||
self.pipe.vae_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder_1.requires_grad_(False)
|
||||
self.qwenvl.requires_grad_(False)
|
||||
self.qwenvl.model.visual.requires_grad_(False)
|
||||
self.pipe.train()
|
||||
self.adapter.train()
|
||||
self.qwenvl.train()
|
||||
self.qwenvl.model.visual.eval()
|
||||
# self.qwenvl.model.model.gradient_checkpointing = True
|
||||
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
@@ -302,12 +305,6 @@ class UnifiedModel(pl.LightningModule):
|
||||
|
||||
def forward(self, batch):
|
||||
return self.training_step(batch, 0)
|
||||
|
||||
|
||||
def configure_optimizers(self):
|
||||
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
|
||||
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
|
||||
return optimizer
|
||||
|
||||
|
||||
|
||||
@@ -369,12 +366,25 @@ if __name__ == '__main__':
|
||||
dataset = MultiTaskDataset(
|
||||
dataset_list=[
|
||||
SingleTaskDataset(
|
||||
"data/example_dataset",
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
|
||||
height=1024, width=1024,
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
|
||||
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
|
||||
height=512, width=512,
|
||||
height=1024, width=1024,
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
|
||||
),
|
||||
SingleTaskDataset(
|
||||
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
|
||||
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
|
||||
height=1024, width=1024,
|
||||
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
|
||||
),
|
||||
],
|
||||
dataset_weight=(1,),
|
||||
dataset_weight=(4, 2, 1,),
|
||||
steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
|
||||
)
|
||||
train_loader = torch.utils.data.DataLoader(
|
||||
|
||||
Reference in New Issue
Block a user