sd and sdxl training

This commit is contained in:
mi804
2026-04-24 10:19:58 +08:00
parent a8a0f082bb
commit 5cdab9ed01
17 changed files with 543 additions and 26 deletions

View File

@@ -124,6 +124,8 @@ class DDIMScheduler:
else:
raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")
# Clamp timesteps to valid range [0, num_train_timesteps - 1]
timesteps = np.clip(timesteps, 0, self.num_train_timesteps - 1)
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64)
self.num_inference_steps = num_inference_steps
@@ -222,6 +224,8 @@ class DDIMScheduler:
timestep = timestep[0].item()
timestep = int(timestep)
# Defensive clamp: ensure timestep is within valid range
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
@@ -238,6 +242,14 @@ class DDIMScheduler:
def training_target(self, sample, noise, timestep):
"""Return the training target for the given prediction type."""
if isinstance(timestep, torch.Tensor):
timestep = timestep.cpu()
if timestep.dim() == 0:
timestep = timestep.item()
elif timestep.dim() == 1:
timestep = timestep[0].item()
timestep = int(timestep)
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
if self.prediction_type == "epsilon":
return noise
elif self.prediction_type == "v_prediction":
@@ -251,5 +263,7 @@ class DDIMScheduler:
def training_weight(self, timestep):
"""Return training weight for the given timestep."""
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
timestep = max(0, min(int(timestep), self.num_train_timesteps - 1))
timestep_tensor = torch.tensor(timestep, device=self.timesteps.device)
timestep_id = torch.argmin((self.timesteps - timestep_tensor).abs())
return self.linear_timesteps_weights[timestep_id]