From 4e3a184199d8c164e975081757755635d0eb5b80 Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Mon, 14 Oct 2024 10:00:32 +0800 Subject: [PATCH] update flux training --- diffsynth/schedulers/flow_match.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index 6678dc5..ab965ee 100644 --- a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -16,17 +16,13 @@ class FlowMatchScheduler(): sigma_start = self.sigma_min + (self.sigma_max - self.sigma_min) * denoising_strength self.sigmas = torch.linspace(sigma_start, self.sigma_min, num_inference_steps) self.sigmas = self.shift * self.sigmas / (1 + (self.shift - 1) * self.sigmas) + self.timesteps = self.sigmas * self.num_train_timesteps if training: - self.timesteps = torch.linspace(1000, 0, num_inference_steps) - - # prepare timestep weights - x = torch.arange(num_inference_steps, dtype=torch.float32) + x = self.timesteps y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2) y_shifted = y - y.min() bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum()) self.linear_timesteps_weights = bsmntw_weighing - else: - self.timesteps = self.sigmas * self.num_train_timesteps def step(self, model_output, timestep, sample, to_final=False):