mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-24 06:46:13 +00:00
sd and sdxl training
This commit is contained in:
@@ -124,6 +124,8 @@ class DDIMScheduler:
|
||||
else:
|
||||
raise ValueError(f"Unsupported timestep_spacing: {self.timestep_spacing}")
|
||||
|
||||
# Clamp timesteps to valid range [0, num_train_timesteps - 1]
|
||||
timesteps = np.clip(timesteps, 0, self.num_train_timesteps - 1)
|
||||
self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.int64)
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
@@ -222,6 +224,8 @@ class DDIMScheduler:
|
||||
timestep = timestep[0].item()
|
||||
|
||||
timestep = int(timestep)
|
||||
# Defensive clamp: ensure timestep is within valid range
|
||||
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt()
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timestep]).sqrt()
|
||||
|
||||
@@ -238,6 +242,14 @@ class DDIMScheduler:
|
||||
|
||||
def training_target(self, sample, noise, timestep):
|
||||
"""Return the training target for the given prediction type."""
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.cpu()
|
||||
if timestep.dim() == 0:
|
||||
timestep = timestep.item()
|
||||
elif timestep.dim() == 1:
|
||||
timestep = timestep[0].item()
|
||||
timestep = int(timestep)
|
||||
timestep = max(0, min(timestep, self.num_train_timesteps - 1))
|
||||
if self.prediction_type == "epsilon":
|
||||
return noise
|
||||
elif self.prediction_type == "v_prediction":
|
||||
@@ -251,5 +263,7 @@ class DDIMScheduler:
|
||||
|
||||
def training_weight(self, timestep):
|
||||
"""Return training weight for the given timestep."""
|
||||
timestep_id = torch.argmin((self.timesteps - timestep.to(self.timesteps.device)).abs())
|
||||
timestep = max(0, min(int(timestep), self.num_train_timesteps - 1))
|
||||
timestep_tensor = torch.tensor(timestep, device=self.timesteps.device)
|
||||
timestep_id = torch.argmin((self.timesteps - timestep_tensor).abs())
|
||||
return self.linear_timesteps_weights[timestep_id]
|
||||
|
||||
Reference in New Issue
Block a user