mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
train
This commit is contained in:
12
train.py
12
train.py
@@ -89,8 +89,6 @@ class SingleTaskDataset(torch.utils.data.Dataset):
|
||||
def load_image(self, image_path, skip_process=False):
|
||||
image_path = os.path.join(self.base_path, image_path)
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
if skip_process:
|
||||
return image
|
||||
width, height = image.size
|
||||
scale = max(self.width / width, self.height / height)
|
||||
image = torchvision.transforms.functional.resize(
|
||||
@@ -98,6 +96,8 @@ class SingleTaskDataset(torch.utils.data.Dataset):
|
||||
(round(height*scale), round(width*scale)),
|
||||
interpolation=torchvision.transforms.InterpolationMode.BILINEAR
|
||||
)
|
||||
if skip_process:
|
||||
return image
|
||||
image = self.image_process(image)
|
||||
return image
|
||||
|
||||
@@ -254,6 +254,10 @@ class UnifiedModel(pl.LightningModule):
|
||||
self.pipe.vae_decoder.requires_grad_(False)
|
||||
self.pipe.vae_encoder.requires_grad_(False)
|
||||
self.pipe.text_encoder_1.requires_grad_(False)
|
||||
self.pipe.train()
|
||||
self.adapter.train()
|
||||
self.qwenvl.train()
|
||||
# self.qwenvl.model.model.gradient_checkpointing = True
|
||||
|
||||
self.pipe.scheduler.set_timesteps(1000, training=True)
|
||||
|
||||
@@ -289,7 +293,7 @@ class UnifiedModel(pl.LightningModule):
|
||||
self.pipe.denoising_model(),
|
||||
hidden_states=noisy_latents, timestep=timestep, **prompt_emb, **extra_input,
|
||||
image_emb=emb,
|
||||
use_gradient_checkpointing=True
|
||||
use_gradient_checkpointing=False
|
||||
)
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * self.pipe.scheduler.training_weight(timestep)
|
||||
@@ -331,7 +335,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--steps_per_epoch",
|
||||
type=int,
|
||||
default=100,
|
||||
default=1000,
|
||||
help="steps_per_epoch",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user