This commit is contained in:
xuyixuan.xyx
2025-05-07 11:22:13 +08:00
parent 290ec469ca
commit f17558a4c4
4 changed files with 47 additions and 21 deletions

View File

@@ -89,8 +89,6 @@ class SingleTaskDataset(torch.utils.data.Dataset):
def load_image(self, image_path, skip_process=False):
image_path = os.path.join(self.base_path, image_path)
image = Image.open(image_path).convert("RGB")
if skip_process:
return image
width, height = image.size
scale = max(self.width / width, self.height / height)
image = torchvision.transforms.functional.resize(
@@ -98,6 +96,8 @@ class SingleTaskDataset(torch.utils.data.Dataset):
(round(height*scale), round(width*scale)),
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
)
if skip_process:
return image
image = self.image_process(image)
return image
@@ -254,6 +254,10 @@ class UnifiedModel(pl.LightningModule):
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder_1.requires_grad_(False)
self.pipe.train()
self.adapter.train()
self.qwenvl.train()
# self.qwenvl.model.model.gradient_checkpointing = True
self.pipe.scheduler.set_timesteps(1000, training=True)
@@ -289,7 +293,7 @@ class UnifiedModel(pl.LightningModule):
self.pipe.denoising_model(),
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
image_emb=emb,
use_gradient_checkpointing=True
use_gradient_checkpointing=False
)
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
loss = loss * self.pipe.scheduler.training_weight(timestep)
@@ -331,7 +335,7 @@ def parse_args():
parser.add_argument(
"--steps_per_epoch",
type=int,
default=100,
default=1000,
help="steps_per_epoch",
)
parser.add_argument(