Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-23 00:58:11 +00:00
rebuild base modules
@@ -22,10 +22,10 @@ class EnhancedDDIMScheduler():
         max_timestep = max(round(self.num_train_timesteps * denoising_strength) - 1, 0)
         num_inference_steps = min(num_inference_steps, max_timestep + 1)
         if num_inference_steps == 1:
-            self.timesteps = [max_timestep]
+            self.timesteps = torch.Tensor([max_timestep])
         else:
             step_length = max_timestep / (num_inference_steps - 1)
-            self.timesteps = [round(max_timestep - i*step_length) for i in range(num_inference_steps)]
+            self.timesteps = torch.Tensor([round(max_timestep - i*step_length) for i in range(num_inference_steps)])


     def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
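
Note: the hunk above swaps the plain Python list of timesteps for a torch.Tensor. A minimal standalone sketch of why that matters, with made-up values rather than the scheduler's real schedule: the nearest-timestep lookup that step() performs later needs elementwise tensor arithmetic, which a list cannot provide.

    import torch

    # Hypothetical illustration, not code from this commit.
    timesteps = torch.Tensor([999, 749, 499, 249, 0])  # tensor, as in the patch
    timestep = torch.tensor(501)

    # Elementwise subtraction and argmin require a tensor; with a plain
    # list the old code had to call .index(), which only accepts an exact
    # match and fails for float or tensor-valued timesteps.
    timestep_id = torch.argmin((timesteps - timestep).abs())
    print(int(timestep_id))  # 2 -> nearest stored timestep is 499
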
@@ -43,31 +43,37 @@ class EnhancedDDIMScheduler():

     def step(self, model_output, timestep, sample, to_final=False):
-        alpha_prod_t = self.alphas_cumprod[timestep]
-        timestep_id = self.timesteps.index(timestep)
+        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
+        timestep_id = torch.argmin((self.timesteps - timestep).abs())
         if to_final or timestep_id + 1 >= len(self.timesteps):
             alpha_prod_t_prev = 1.0
         else:
-            timestep_prev = self.timesteps[timestep_id + 1]
+            timestep_prev = int(self.timesteps[timestep_id + 1])
             alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]

         return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)


     def return_to_timestep(self, timestep, sample, sample_stablized):
-        alpha_prod_t = self.alphas_cumprod[timestep]
+        alpha_prod_t = self.alphas_cumprod[int(timestep.flatten().tolist()[0])]
         noise_pred = (sample - math.sqrt(alpha_prod_t) * sample_stablized) / math.sqrt(1 - alpha_prod_t)
         return noise_pred


     def add_noise(self, original_samples, noise, timestep):
-        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
-        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
+        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples


     def training_target(self, sample, noise, timestep):
-        sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
-        sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
-        target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
-        return target
+        if self.prediction_type == "epsilon":
+            return noise
+        else:
+            sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+            sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[int(timestep.flatten().tolist()[0])])
+            target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+            return target
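
Two patterns recur in this hunk. First, timestep may now arrive as a tensor, so scalar indexing into alphas_cumprod goes through int(timestep.flatten().tolist()[0]). Second, training_target() now branches on prediction_type: for "epsilon" the target is the added noise itself, while the other branch keeps the old sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample formula, a v-prediction-style target. A self-contained sketch of the reworked method, with a stand-in alphas_cumprod schedule (the real one comes from the scheduler's beta schedule):

    import math
    import torch

    # Stand-in noise schedule for illustration only.
    alphas_cumprod = torch.linspace(0.9999, 0.0064, 1000)

    def training_target(sample, noise, timestep, prediction_type="epsilon"):
        if prediction_type == "epsilon":
            return noise  # the model regresses the added noise directly
        else:
            # Tensor timestep -> plain int index, as in the patched code.
            t = int(timestep.flatten().tolist()[0])
            sqrt_alpha_prod = math.sqrt(alphas_cumprod[t])
            sqrt_one_minus_alpha_prod = math.sqrt(1 - alphas_cumprod[t])
            # v-prediction-style target.
            return sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample

    sample = torch.randn(1, 4, 8, 8)
    noise = torch.randn_like(sample)
    # Any value other than "epsilon" selects the second branch.
    target = training_target(sample, noise, torch.tensor([500]), "v_prediction")
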
@@ -20,6 +20,8 @@ class FlowMatchScheduler():


     def step(self, model_output, timestep, sample, to_final=False):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
         timestep_id = torch.argmin((self.timesteps - timestep).abs())
         sigma = self.sigmas[timestep_id]
         if to_final or timestep_id + 1 >= len(self.timesteps):
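
The two added lines guard against a device mismatch: self.timesteps is CPU-resident, so a timestep handed over as a CUDA tensor must be moved to the CPU before the elementwise comparison. A small sketch of the pattern (values are arbitrary):

    import torch

    timesteps = torch.Tensor([1000.0, 750.0, 500.0, 250.0])  # CPU-resident

    # Suppose the pipeline passes the current timestep as a tensor that
    # may live on CUDA; subtracting it from a CPU tensor would raise a
    # RuntimeError, so the patch normalizes the device first.
    timestep = torch.tensor(740.0)
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.cpu()  # no-op on CPU, fixes the CUDA case
    timestep_id = torch.argmin((timesteps - timestep).abs())
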
@@ -36,6 +38,8 @@ class FlowMatchScheduler():


     def add_noise(self, original_samples, noise, timestep):
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.cpu()
         timestep_id = torch.argmin((self.timesteps - timestep).abs())
         sigma = self.sigmas[timestep_id]
         sample = (1 - sigma) * original_samples + sigma * noise
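
For context, FlowMatchScheduler.add_noise implements the flow-matching interpolation x_t = (1 - sigma_t) * x_0 + sigma_t * noise. A usage sketch with made-up timestep and sigma schedules (the real scheduler derives both from its own configuration):

    import torch

    # Made-up schedules for illustration only.
    timesteps = torch.linspace(1000.0, 1.0, 30)
    sigmas = torch.linspace(1.0, 1.0 / 30, 30)

    x0 = torch.randn(1, 4, 64, 64)        # clean latent
    noise = torch.randn_like(x0)          # epsilon ~ N(0, I)
    timestep = torch.tensor(700.0).cpu()  # tensor timesteps are allowed now

    # Same lookup and interpolation as the patched add_noise above.
    timestep_id = torch.argmin((timesteps - timestep).abs())
    sigma = sigmas[timestep_id]
    x_t = (1 - sigma) * x0 + sigma * noise
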