From f17558a4c490f22a0515b5dcc9a52787ec36d7ba Mon Sep 17 00:00:00 2001 From: "xuyixuan.xyx" Date: Wed, 7 May 2025 11:22:13 +0800 Subject: [PATCH] train --- modeling/ar/modeling_qwen2_5_vl.py | 21 ++++++++++++++++++-- requirements.txt | 3 ++- test.py | 32 +++++++++++++++++------------- train.py | 12 +++++++---- 4 files changed, 47 insertions(+), 21 deletions(-) diff --git a/modeling/ar/modeling_qwen2_5_vl.py b/modeling/ar/modeling_qwen2_5_vl.py index 3cf6c37..7187d09 100644 --- a/modeling/ar/modeling_qwen2_5_vl.py +++ b/modeling/ar/modeling_qwen2_5_vl.py @@ -1181,13 +1181,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): all_self_attns = () if output_attentions else None next_decoder_cache = None + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward + for decoder_layer in self.layers: if output_hidden_states: all_hidden_states += (hidden_states,) if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), hidden_states, causal_mask, position_ids, @@ -1196,7 +1201,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel): use_cache, cache_position, position_embeddings, + use_reentrant=False, ) + # layer_outputs = self._gradient_checkpointing_func( + # decoder_layer.__call__, + # hidden_states, + # causal_mask, + # position_ids, + # past_key_values, + # output_attentions, + # use_cache, + # cache_position, + # position_embeddings, + # ) else: layer_outputs = decoder_layer( hidden_states, diff --git a/requirements.txt b/requirements.txt index 63a871b..0beb6ed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ torch>=2.0.0 torchvision cupy-cuda12x -transformers==4.46.2 +transformers==4.49.0 controlnet-aux==0.0.7 imageio imageio[ffmpeg] @@ -11,3 +11,4 @@ sentencepiece protobuf modelscope ftfy +qwen_vl_utils diff --git a/test.py b/test.py index 5e0a073..4418e53 100644 --- a/test.py +++ b/test.py @@ -102,31 +102,35 @@ model_manager.load_models([ ]) pipe = FluxImagePipeline.from_model_manager(model_manager) -state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16) -pipe.dit.load_state_dict(state_dict, strict=False) +# state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16) +# pipe.dit.load_state_dict(state_dict, strict=False) adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda") -adapter.load_state_dict(state_dict, strict=False) +# adapter.load_state_dict(state_dict, strict=False) qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda") +sd = {} +for i in range(1, 6): + print(i) + sd.update(load_state_dict(f"models/nexus_v1/epoch-8/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16)) +pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")}) +qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")}) +adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")}) +for i in sd: + if (not i.startswith("pipe.dit")) and (not i.startswith("qwenvl.")) and (not i.startswith("adapter.")): + print(i) + with torch.no_grad(): - instruction = "Generate an image according to the following description: a beautiful Asian girl" + instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic" emb = qwenvl(instruction, images=None) emb = adapter(emb) - image = pipe("", image_emb=emb) + image = pipe("", image_emb=emb, height=512, width=512) image.save("image_1.jpg") with torch.no_grad(): - instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses." + instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression." emb = qwenvl(instruction, images=[Image.open("image_1.jpg")]) emb = adapter(emb) - image = pipe("", image_emb=emb) + image = pipe("", image_emb=emb, height=512, width=512) image.save("image_2.jpg") - -with torch.no_grad(): - instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile." - emb = qwenvl(instruction, images=[Image.open("image_2.jpg")]) - emb = adapter(emb) - image = pipe("", image_emb=emb) - image.save("image_3.jpg") diff --git a/train.py b/train.py index 5aa81d8..bec795b 100644 --- a/train.py +++ b/train.py @@ -89,8 +89,6 @@ class SingleTaskDataset(torch.utils.data.Dataset): def load_image(self, image_path, skip_process=False): image_path = os.path.join(self.base_path, image_path) image = Image.open(image_path).convert("RGB") - if skip_process: - return image width, height = image.size scale = max(self.width / width, self.height / height) image = torchvision.transforms.functional.resize( @@ -98,6 +96,8 @@ class SingleTaskDataset(torch.utils.data.Dataset): (round(height*scale), round(width*scale)), interpolation=torchvision.transforms.InterpolationMode.BILINEAR ) + if skip_process: + return image image = self.image_process(image) return image @@ -254,6 +254,10 @@ class UnifiedModel(pl.LightningModule): self.pipe.vae_decoder.requires_grad_(False) self.pipe.vae_encoder.requires_grad_(False) self.pipe.text_encoder_1.requires_grad_(False) + self.pipe.train() + self.adapter.train() + self.qwenvl.train() + # self.qwenvl.model.model.gradient_checkpointing = True self.pipe.scheduler.set_timesteps(1000, training=True) @@ -289,7 +293,7 @@ class UnifiedModel(pl.LightningModule): self.pipe.denoising_model(), hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, image_emb=emb, - use_gradient_checkpointing=True + use_gradient_checkpointing=False ) loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = loss * self.pipe.scheduler.training_weight(timestep) @@ -331,7 +335,7 @@ def parse_args(): parser.add_argument( "--steps_per_epoch", type=int, - default=100, + default=1000, help="steps_per_epoch", ) parser.add_argument(