support HunyuanDiT

This commit is contained in:
Artiprocher
2024-06-05 13:00:39 +08:00
parent 83461d400c
commit 78f53a0754
23 changed files with 69988 additions and 185 deletions

View File

@@ -3,7 +3,7 @@ import torch, math
class EnhancedDDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", prediction_type="epsilon"):
self.num_train_timesteps = num_train_timesteps
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
@@ -13,6 +13,7 @@ class EnhancedDDIMScheduler():
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
self.set_timesteps(10)
self.prediction_type = prediction_type
def set_timesteps(self, num_inference_steps, denoising_strength=1.0):
@@ -28,9 +29,16 @@ class EnhancedDDIMScheduler():
def denoise(self, model_output, sample, alpha_prod_t, alpha_prod_t_prev):
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
if self.prediction_type == "epsilon":
weight_e = math.sqrt(1 - alpha_prod_t_prev) - math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t) / alpha_prod_t)
weight_x = math.sqrt(alpha_prod_t_prev / alpha_prod_t)
prev_sample = sample * weight_x + model_output * weight_e
elif self.prediction_type == "v_prediction":
weight_e = -math.sqrt(alpha_prod_t_prev * (1 - alpha_prod_t)) + math.sqrt(alpha_prod_t * (1 - alpha_prod_t_prev))
weight_x = math.sqrt(alpha_prod_t * alpha_prod_t_prev) + math.sqrt((1 - alpha_prod_t) * (1 - alpha_prod_t_prev))
prev_sample = sample * weight_x + model_output * weight_e
else:
raise NotImplementedError(f"{self.prediction_type} is not implemented")
return prev_sample
@@ -57,4 +65,9 @@ class EnhancedDDIMScheduler():
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
return noisy_samples
def training_target(self, sample, noise, timestep):
sqrt_alpha_prod = math.sqrt(self.alphas_cumprod[timestep])
sqrt_one_minus_alpha_prod = math.sqrt(1 - self.alphas_cumprod[timestep])
target = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
return target