This commit is contained in:
Artiprocher
2025-11-30 19:04:21 +08:00
parent b106458eac
commit 20cf2317e0
5 changed files with 12 additions and 11 deletions

View File

@@ -12,6 +12,7 @@ class FlowMatchScheduler():
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@staticmethod
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
@@ -111,11 +112,12 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
def set_training_weight(self, num_inference_steps):
def set_training_weight(self):
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
@@ -125,7 +127,7 @@ class FlowMatchScheduler():
**kwargs,
)
if training:
self.set_training_weight(num_inference_steps)
self.set_training_weight()
self.training = True
else:
self.training = False