multi-node

This commit is contained in:
xuyixuan.xyx
2025-04-22 10:40:40 +08:00
parent 9ea29a769d
commit a10635818a
3 changed files with 258 additions and 31 deletions

View File

@@ -42,6 +42,8 @@ class LightningModel(LightningModelForT2ILoRA):
self.freeze_parameters()
self.pipe.reference_embedder.requires_grad_(True)
self.pipe.reference_embedder.train()
self.pipe.dit.requires_grad_(True)
self.pipe.dit.train()
# self.add_lora_to_model(
# self.pipe.denoising_model(),
# lora_rank=lora_rank,
@@ -192,27 +194,30 @@ if __name__ == '__main__':
dataset_list=[
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_change_add_remove.json",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_change_add_remove.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_zoomin_zoomout",
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_zoomin_zoomout.json",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_zoomin_zoomout.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_style_transfer.json",
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_style_transfer.json",
height=512, width=512,
),
SingleTaskDataset(
"/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250411_dataset_faceid.json",
keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250418_dataset_faceid.json",
height=512, width=512,
),
],
dataset_weight=(4, 2, 2, 1),
dataset_weight=(4, 1, 4, 1),
steps_per_epoch=args.steps_per_epoch,
)
train_loader = torch.utils.data.DataLoader(