This commit is contained in:
xuyixuan.xyx
2025-05-07 11:22:13 +08:00
parent 290ec469ca
commit f17558a4c4
4 changed files with 47 additions and 21 deletions

View File

@@ -1181,13 +1181,18 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
next_decoder_cache = None next_decoder_cache = None
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for decoder_layer in self.layers: for decoder_layer in self.layers:
if output_hidden_states: if output_hidden_states:
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func( layer_outputs = torch.utils.checkpoint.checkpoint(
decoder_layer.__call__, create_custom_forward(decoder_layer),
hidden_states, hidden_states,
causal_mask, causal_mask,
position_ids, position_ids,
@@ -1196,7 +1201,19 @@ class Qwen2_5_VLModel(Qwen2_5_VLPreTrainedModel):
use_cache, use_cache,
cache_position, cache_position,
position_embeddings, position_embeddings,
use_reentrant=False,
) )
# layer_outputs = self._gradient_checkpointing_func(
# decoder_layer.__call__,
# hidden_states,
# causal_mask,
# position_ids,
# past_key_values,
# output_attentions,
# use_cache,
# cache_position,
# position_embeddings,
# )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,

View File

@@ -1,7 +1,7 @@
torch>=2.0.0 torch>=2.0.0
torchvision torchvision
cupy-cuda12x cupy-cuda12x
transformers==4.46.2 transformers==4.49.0
controlnet-aux==0.0.7 controlnet-aux==0.0.7
imageio imageio
imageio[ffmpeg] imageio[ffmpeg]
@@ -11,3 +11,4 @@ sentencepiece
protobuf protobuf
modelscope modelscope
ftfy ftfy
qwen_vl_utils

32
test.py
View File

@@ -102,31 +102,35 @@ model_manager.load_models([
]) ])
pipe = FluxImagePipeline.from_model_manager(model_manager) pipe = FluxImagePipeline.from_model_manager(model_manager)
state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16) # state_dict = load_state_dict("models/DiffSynth-Studio/Nexus-Gen/decoder_81_512.bin", torch_dtype=torch.bfloat16)
pipe.dit.load_state_dict(state_dict, strict=False) # pipe.dit.load_state_dict(state_dict, strict=False)
adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda") adapter = torch.nn.Sequential(torch.nn.Linear(3584, 4096), torch.nn.LayerNorm(4096), torch.nn.ReLU(), torch.nn.Linear(4096, 4096), torch.nn.LayerNorm(4096)).to(dtype=torch.bfloat16, device="cuda")
adapter.load_state_dict(state_dict, strict=False) # adapter.load_state_dict(state_dict, strict=False)
qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda") qwenvl = NexusGenQwenVLEncoder.from_pretrained('models/DiffSynth-Studio/Nexus-Gen').to("cuda")
sd = {}
for i in range(1, 6):
print(i)
sd.update(load_state_dict(f"models/nexus_v1/epoch-8/model-0000{i}-of-00005.safetensors", torch_dtype=torch.bfloat16))
pipe.dit.load_state_dict({i.replace("pipe.dit.", ""): sd[i] for i in sd if i.startswith("pipe.dit.")})
qwenvl.load_state_dict({i.replace("qwenvl.", ""): sd[i] for i in sd if i.startswith("qwenvl.")})
adapter.load_state_dict({i.replace("adapter.", ""): sd[i] for i in sd if i.startswith("adapter.")})
for i in sd:
if (not i.startswith("pipe.dit")) and (not i.startswith("qwenvl.")) and (not i.startswith("adapter.")):
print(i)
with torch.no_grad(): with torch.no_grad():
instruction = "Generate an image according to the following description: a beautiful Asian girl" instruction = "Generate an image according to the following description: hyper-realistic and detailed 2010s movie still portrait of Josip Broz Tito, by Paolo Sorrentino, Leica SL2 50mm, clear color, high quality, high textured, dramatic light, cinematic"
emb = qwenvl(instruction, images=None) emb = qwenvl(instruction, images=None)
emb = adapter(emb) emb = adapter(emb)
image = pipe("", image_emb=emb) image = pipe("", image_emb=emb, height=512, width=512)
image.save("image_1.jpg") image.save("image_1.jpg")
with torch.no_grad(): with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Add sunglasses." instruction = "<|vision_start|><|image_pad|><|vision_end|> transform the image into a cartoon style with vibrant colors and a confident expression."
emb = qwenvl(instruction, images=[Image.open("image_1.jpg")]) emb = qwenvl(instruction, images=[Image.open("image_1.jpg")])
emb = adapter(emb) emb = adapter(emb)
image = pipe("", image_emb=emb) image = pipe("", image_emb=emb, height=512, width=512)
image.save("image_2.jpg") image.save("image_2.jpg")
with torch.no_grad():
instruction = "<|vision_start|><|image_pad|><|vision_end|> Let her smile."
emb = qwenvl(instruction, images=[Image.open("image_2.jpg")])
emb = adapter(emb)
image = pipe("", image_emb=emb)
image.save("image_3.jpg")

View File

@@ -89,8 +89,6 @@ class SingleTaskDataset(torch.utils.data.Dataset):
def load_image(self, image_path, skip_process=False): def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path) image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB") image = Image.open(image_path).convert("RGB")
if skip_process:
return image
width, height = image.size width, height = image.size
scale = max(self.width / width, self.height / height) scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize( image = torchvision.transforms.functional.resize(
@@ -98,6 +96,8 @@ class SingleTaskDataset(torch.utils.data.Dataset):
(round(height*scale), round(width*scale)), (round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR interpolation=torchvision.transforms.InterpolationMode.BILINEAR
) )
if skip_process:
return image
image = self.image_process(image) image = self.image_process(image)
return image return image
@@ -254,6 +254,10 @@ class UnifiedModel(pl.LightningModule):
self.pipe.vae_decoder.requires_grad_(False) self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False) self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder_1.requires_grad_(False) self.pipe.text_encoder_1.requires_grad_(False)
self.pipe.train()
self.adapter.train()
self.qwenvl.train()
# self.qwenvl.model.model.gradient_checkpointing = True
self.pipe.scheduler.set_timesteps(1000, training=True) self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -289,7 +293,7 @@ class UnifiedModel(pl.LightningModule):
self.pipe.denoising_model(), self.pipe.denoising_model(),
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input, hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
image_emb=emb, image_emb=emb,
use_gradient_checkpointing=True use_gradient_checkpointing=False
) )
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float()) loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep) loss = loss * self.pipe.scheduler.training_weight(timestep)
@@ -331,7 +335,7 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--steps_per_epoch", "--steps_per_epoch",
type=int, type=int,
default=100, default=1000,
help="steps_per_epoch", help="steps_per_epoch",
) )
parser.add_argument( parser.add_argument(