refine training

This commit is contained in:
xuyixuan.xyx
2025-05-12 14:19:00 +08:00
parent f17558a4c4
commit 91fbb24e17
3 changed files with 216 additions and 26 deletions

View File

@@ -254,9 +254,12 @@ class UnifiedModel(pl.LightningModule):
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder_1.requires_grad_(False)
self.qwenvl.requires_grad_(False)
self.qwenvl.model.visual.requires_grad_(False)
self.pipe.train()
self.adapter.train()
self.qwenvl.train()
self.qwenvl.model.visual.eval()
# self.qwenvl.model.model.gradient_checkpointing = True
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -302,12 +305,6 @@ class UnifiedModel(pl.LightningModule):
# Forward entry point: delegates to self.training_step with the given batch
# and a constant second argument of 0, and returns whatever training_step returns.
# NOTE(review): the 0 is presumably the batch index — confirm against the
# training_step signature, which is not visible here.
def forward(self, batch):
return self.training_step(batch, 0)
# Build and return the optimizer: AdamW over the parameters of self.pipe
# whose requires_grad flag is set, with lr taken from self.learning_rate.
# NOTE(review): only self.pipe.parameters() is scanned here; parameters of
# self.adapter and self.qwenvl are not handed to the optimizer by this
# method — confirm that is intended.
def configure_optimizers(self):
# Despite the name, this holds parameters (not modules). filter() yields a
# one-shot iterator, so it must not be iterated before it is passed to AdamW.
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
@@ -369,12 +366,25 @@ if __name__ == '__main__':
dataset = MultiTaskDataset(
dataset_list=[
SingleTaskDataset(
"data/example_dataset",
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
height=512, width=512,
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
height=1024, width=1024,
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
),
],
dataset_weight=(1,),
dataset_weight=(4, 2, 1,),
steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
)
train_loader = torch.utils.data.DataLoader(