support video-to-video-translation

This commit is contained in:
Artiprocher
2023-12-21 17:11:58 +08:00
parent f7f4c1038e
commit c1453281df
20 changed files with 1659 additions and 427 deletions

View File

@@ -3,9 +3,14 @@ import torch, math
class EnhancedDDIMScheduler():
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012):
def __init__(self, num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"):
self.num_train_timesteps = num_train_timesteps
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
if beta_schedule == "scaled_linear":
betas = torch.square(torch.linspace(math.sqrt(beta_start), math.sqrt(beta_end), num_train_timesteps, dtype=torch.float32))
elif beta_schedule == "linear":
betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
else:
raise NotImplementedError(f"{beta_schedule} is not implemented")
self.alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
self.set_timesteps(10)
@@ -34,14 +39,14 @@ class EnhancedDDIMScheduler():
return prev_sample
def step(self, model_output, timestep, sample):
def step(self, model_output, timestep, sample, to_final=False):
alpha_prod_t = self.alphas_cumprod[timestep]
timestep_id = self.timesteps.index(timestep)
if timestep_id + 1 < len(self.timesteps):
if to_final or timestep_id + 1 >= len(self.timesteps):
alpha_prod_t_prev = 1.0
else:
timestep_prev = self.timesteps[timestep_id + 1]
alpha_prod_t_prev = self.alphas_cumprod[timestep_prev]
else:
alpha_prod_t_prev = 1.0
return self.denoise(model_output, sample, alpha_prod_t, alpha_prod_t_prev)