From 91fbb24e176673898dbe1bd027f8af29652c05e9 Mon Sep 17 00:00:00 2001
From: "xuyixuan.xyx"
Date: Mon, 12 May 2025 14:19:00 +0800
Subject: [PATCH] refine training

---
 run_single.sh |   4 +
 test.py       | 210 ++++++++++++++++++++++++++++++++++++++++++++++----
 train.py      |  28 ++++---
 3 files changed, 216 insertions(+), 26 deletions(-)
 create mode 100644 run_single.sh

diff --git a/run_single.sh b/run_single.sh
new file mode 100644
index 0000000..1378b9b
--- /dev/null
+++ b/run_single.sh
@@ -0,0 +1,4 @@
+accelerate launch \
+    train.py \
+    --output_path models/nexus_v3 \
+    --steps_per_epoch 4000
\ No newline at end of file
diff --git a/test.py b/test.py
index 4418e53..3f6f6f7 100644
--- a/test.py
+++ b/test.py
@@ -1,11 +1,133 @@
 from transformers import AutoConfig, AutoTokenizer
-import torch
+import torch, json, os, torchvision
 from modeling.ar.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from modeling.ar.processing_qwen2_5_vl import Qwen2_5_VLProcessor
 from diffsynth import ModelManager, FluxImagePipeline, load_state_dict, hash_state_dict_keys
 from qwen_vl_utils import smart_resize
 from PIL import Image
 import numpy as np
+from torchvision.transforms import v2
+
+
+
+class SingleTaskDataset(torch.utils.data.Dataset):
+    def __init__(
+        self,
+        base_path,
+        keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
+        height=1024, width=1024, random=True, steps_per_epoch=1000, metadata_path=None
+    ):
+        self.base_path = base_path
+        self.keys = keys
+        self.metadata = []
+        self.bad_data = []
+        self.height = height
+        self.width = width
+        self.random = random
+        self.steps_per_epoch = steps_per_epoch
+        self.image_process = v2.Compose([
+            v2.CenterCrop(size=(height, width)),
+            v2.ToImage(),
+            v2.ToDtype(torch.float32, scale=True),
+            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+        ])
+        if metadata_path is None:
+            self.search_for_data("", report_data_log=True)
+            self.report_data_log()
+        else:
+            with open(metadata_path, "r", encoding="utf-8-sig") as f:
+                self.metadata = json.load(f)
+
+
+    def report_data_log(self):
+        print(f"{len(self.metadata)} valid data, {len(self.bad_data)} invalid data.")
+
+
+    def dump_metadata(self, path):
+        with open(path, "w", encoding="utf-8") as f:
+            json.dump(self.metadata, f, ensure_ascii=False, indent=4)
+
+
+    def parse_json_file(self, absolute_path, relative_path):
+        data_list = []
+        with open(absolute_path, "r") as f:
+            metadata = json.load(f)
+        for image_1, image_2, instruction in self.keys:
+            image_1 = os.path.join(relative_path, metadata[image_1]) if image_1 is not None else None
+            image_2 = os.path.join(relative_path, metadata[image_2])
+            instruction = metadata[instruction]
+            data_list.append((image_1, image_2, instruction))
+        return data_list
+
+
+    def search_for_data(self, path, report_data_log=False):
+        now_path = os.path.join(self.base_path, path)
+        if os.path.isfile(now_path) and path.endswith(".json"):
+            try:
+                data_list = self.parse_json_file(now_path, os.path.dirname(path))
+                self.metadata.extend(data_list)
+            except:
+                self.bad_data.append(now_path)
+        elif os.path.isdir(now_path):
+            for sub_path in os.listdir(now_path):
+                self.search_for_data(os.path.join(path, sub_path))
+                if report_data_log and os.path.isdir(os.path.join(self.base_path, path, sub_path)):
+                    self.report_data_log()
+
+
+    def load_image(self, image_path, skip_process=False):
+        image_path = os.path.join(self.base_path, image_path)
+        image = Image.open(image_path).convert("RGB")
+        width, height = image.size
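
Note: MultiTaskDataset above draws one dataset per step with torch.multinomial over the unnormalized dataset_weight vector, then a uniform item within that dataset, so epoch length is governed by steps_per_epoch rather than dataset sizes. A minimal standalone sketch of that sampling scheme (the weights and sizes below are illustrative, not taken from the patch):

    import torch

    weights = torch.tensor([4.0, 2.0, 1.0])  # unnormalized; multinomial normalizes them
    sizes = [1000, 500, 250]                 # hypothetical per-dataset lengths

    torch.manual_seed(0)
    counts = torch.zeros(len(weights))
    for _ in range(7000):
        dataset_id = torch.multinomial(weights, 1).item()           # pick a dataset ~ weights
        item_id = torch.randint(0, sizes[dataset_id], (1,)).item()  # uniform item within it
        counts[dataset_id] += 1
    print(counts / counts.sum())  # tends toward 4/7, 2/7, 1/7
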
+        scale = max(self.width / width, self.height / height)
+        image = torchvision.transforms.functional.resize(
+            image,
+            (round(height*scale), round(width*scale)),
+            interpolation=torchvision.transforms.InterpolationMode.BILINEAR
+        )
+        if skip_process:
+            return image
+        image = self.image_process(image)
+        return image
+
+
+    def load_data(self, data_id):
+        image_1, image_2, instruction = self.metadata[data_id]
+        image_1 = self.load_image(image_1, skip_process=True) if image_1 is not None else None
+        image_2 = self.load_image(image_2)
+        return {"image_1": image_1, "image_2": image_2, "instruction": instruction}
+
+
+    def __getitem__(self, data_id):
+        if self.random:
+            data_id = (torch.randint(0, len(self.metadata), (1,))[0] + data_id) % len(self.metadata)
+            data = self.load_data(data_id)
+            return data
+        else:
+            return self.load_data(data_id)
+
+
+    def __len__(self):
+        return self.steps_per_epoch if self.random else len(self.metadata)
+
+
+
+class MultiTaskDataset(torch.utils.data.Dataset):
+    def __init__(self, dataset_list, dataset_weight, steps_per_epoch=1000):
+        self.dataset_list = dataset_list
+        self.dataset_weight = torch.tensor(dataset_weight, dtype=torch.float)
+        self.steps_per_epoch = steps_per_epoch
+
+
+    def __getitem__(self, data_id):
+        dataset_id = torch.multinomial(self.dataset_weight, 1).tolist()[0]
+        data_id = torch.randint(0, len(self.dataset_list[dataset_id]), (1,))[0]
+        data = self.dataset_list[dataset_id][data_id]
+        return data
+
+
+    def __len__(self):
+        return self.steps_per_epoch
@@ -113,24 +235,78 @@ qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Ge
 sd = {}
 for i in range(1, 6):
     print(i)
-    sd.update(load_state_dict(f"models/nexus_v1/epoch-8/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
+    sd.update(load_state_dict(f"models/nexus_v3/epoch-19/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
 pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
 qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
 adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
-for i in sd:
-    if (not i.startswith("pipe.dit")) and (not i.startswith("qwenvl.")) and (not i.startswith("adapter.")):
-        print(i)
-with torch.no_grad():
-    instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
-    emb = qwenvl(instruction, images=None)
-    emb = adapter(emb)
-    image = pipe("", image_emb=emb, height=512, width=512)
-    image.save("image_1.jpg")
+
+dataset = MultiTaskDataset(
+    dataset_list=[
+        SingleTaskDataset(
+            "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
+            keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
+            height=1024, width=1024,
+            metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
+        ),
+        SingleTaskDataset(
+            "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
+            keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
+            height=1024, width=1024,
+            metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
+        ),
+        SingleTaskDataset(
+            "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
+            keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
+            height=1024, width=1024,
+            metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
+        ),
+    ],
+    dataset_weight=(4, 2, 1,),
+    steps_per_epoch=100000
+)
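
Note: SingleTaskDataset.load_image scales by max(self.width / width, self.height / height), i.e. a "cover" resize: both sides reach at least the 1024x1024 target before v2.CenterCrop trims the excess, and ToDtype(scale=True) plus Normalize(mean=std=0.5) maps pixels into [-1, 1]. A small worked example of the geometry (the input size is illustrative):

    # cover-style resize as in SingleTaskDataset.load_image, for a 1024x1024 target
    width, height = 1920, 1080                               # hypothetical input size
    scale = max(1024 / width, 1024 / height)                 # 0.9481..., set by the short side
    resized = (round(width * scale), round(height * scale))  # (1820, 1024)
    assert min(resized) >= 1024                              # CenterCrop then trims 796 px of width
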
+
+
+torch.manual_seed(0)
+for data_id, data in enumerate(dataset):
+    image_1 = data["image_1"]
+    image_2 = data["image_2"].cpu().float().permute(1, 2, 0).numpy()
+    image_2 = Image.fromarray(((image_2 / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
+    instruction = data["instruction"]
+
+    print(instruction)
+    if image_1 is None:
+        with torch.no_grad():
+            instruction = f"Generate an image according to the following description: {instruction}"
+            emb = qwenvl(instruction, images=None)
+            emb = adapter(emb)
+            image_3 = pipe("", image_emb=emb)
+    else:
+        with torch.no_grad():
+            instruction = f"<|vision_start|><|image_pad|><|vision_end|> {instruction}"
+            emb = qwenvl(instruction, images=[image_1])
+            emb = adapter(emb)
+            image_3 = pipe("", image_emb=emb)
-with torch.no_grad():
-    instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
-    emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
-    emb = adapter(emb)
-    image = pipe("", image_emb=emb, height=512, width=512)
-    image.save("image_2.jpg")
+    if image_1 is not None:
+        image_1.save(f"data/output/{data_id}_1.jpg")
+    image_2.save(f"data/output/{data_id}_2.jpg")
+    image_3.save(f"data/output/{data_id}_3.jpg")
+    if data_id >= 100:
+        break
+
+
+
+# with torch.no_grad():
+#     instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
+#     emb = qwenvl(instruction, images=None)
+#     emb = adapter(emb)
+#     image = pipe("", image_emb=emb)
+#     image.save("image_1.jpg")
+
+# with torch.no_grad():
+#     instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
+#     emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
+#     emb = adapter(emb)
+#     image = pipe("", image_emb=emb)
+#     image.save("image_2.jpg")
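
Note: the checkpoint block above merges five safetensors shards into one flat dict, then routes keys to pipe.dit, qwenvl, and adapter by prefix. The same routing can be factored into a helper; a sketch assuming every relevant key carries exactly one of the three prefixes (split_by_prefix is a hypothetical name, not part of the patch):

    def split_by_prefix(sd, prefixes=("pipe.dit.", "qwenvl.", "adapter.")):
        # route each checkpoint key to its owning module, stripping the prefix
        routed = {p: {} for p in prefixes}
        for key, tensor in sd.items():
            for p in prefixes:
                if key.startswith(p):
                    routed[p][key[len(p):]] = tensor
                    break
        return routed

    # e.g. pipe.dit.load_state_dict(split_by_prefix(sd)["pipe.dit."])
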
diff --git a/train.py b/train.py
index bec795b..edaaeea 100644
--- a/train.py
+++ b/train.py
@@ -254,9 +254,12 @@ class UnifiedModel(pl.LightningModule):
         self.pipe.vae_decoder.requires_grad_(False)
         self.pipe.vae_encoder.requires_grad_(False)
         self.pipe.text_encoder_1.requires_grad_(False)
+        self.qwenvl.requires_grad_(False)
+        self.qwenvl.model.visual.requires_grad_(False)
         self.pipe.train()
         self.adapter.train()
         self.qwenvl.train()
+        self.qwenvl.model.visual.eval()
         # self.qwenvl.model.model.gradient_checkpointing = True
         self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -302,12 +305,6 @@ class UnifiedModel(pl.LightningModule):
     def forward(self, batch):
         return self.training_step(batch, 0)
-
-
-    def configure_optimizers(self):
-        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.parameters())
-        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
-        return optimizer
@@ -369,12 +366,25 @@ if __name__ == '__main__':
     dataset = MultiTaskDataset(
         dataset_list=[
             SingleTaskDataset(
-                "data/example_dataset",
+                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_change_add_remove",
+                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
+                height=1024, width=1024,
+                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_change_add_remove.json",
+            ),
+            SingleTaskDataset(
+                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_style_transfer",
                 keys=(("image_1", "image_4", "editing_instruction"), ("image_4", "image_1", "reverse_editing_instruction"), (None, "image_1", "prompt")),
-                height=512, width=512,
+                height=1024, width=1024,
+                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_style_transfer.json",
+            ),
+            SingleTaskDataset(
+                "/shark/zhongjie/data/image_pulse_datasets/task1/data/dataset_faceid",
+                keys=(("image_1", "image_2", "editing_instruction"), ("image_2", "image_1", "reverse_editing_instruction")),
+                height=1024, width=1024,
+                metadata_path="/shark/zhongjie/data/image_pulse_datasets/task1/data/metadata/20250507_dataset_faceid.json",
             ),
         ],
-        dataset_weight=(1,),
+        dataset_weight=(4, 2, 1,),
         steps_per_epoch=args.steps_per_epoch * accelerator.num_processes,
     )
     train_loader = torch.utils.data.DataLoader(
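
Note: in the train.py hunk, requires_grad_(False) and eval() do different jobs and the patch uses both: the former stops optimizer updates to qwenvl, while the latter keeps the vision tower's stochastic layers (e.g. dropout) in inference behavior even though the surrounding module is put in train mode. A minimal illustration with a stand-in module (nothing here comes from the actual Qwen2.5-VL visual encoder):

    import torch
    from torch import nn

    visual = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))  # stand-in "visual tower"
    visual.requires_grad_(False)  # parameters receive no gradient updates
    visual.eval()                 # dropout acts as the identity despite the train loop

    x = torch.randn(2, 8)
    assert torch.equal(visual(x), visual(x))                    # deterministic forward
    assert not any(p.requires_grad for p in visual.parameters())
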