z image distill

This commit is contained in:
Artiprocher
2025-12-03 11:20:49 +08:00
parent 5065c9ef6a
commit 4a80e9c179
8 changed files with 185 additions and 26 deletions

View File

@@ -101,7 +101,7 @@ class FlowMatchScheduler():
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None):
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0
sigma_max = 1.0
shift = 3 if shift is None else shift
@@ -110,6 +110,11 @@ class FlowMatchScheduler():
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
if target_timesteps is not None:
target_timesteps = target_timesteps.to(dtype=timesteps.dtype, device=timesteps.device)
for timestep in target_timesteps:
timestep_id = torch.argmin((timesteps - timestep).abs())
timesteps[timestep_id] = timestep
return sigmas, timesteps
def set_training_weight(self):
@@ -118,6 +123,10 @@ class FlowMatchScheduler():
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
if len(self.timesteps) != 1000:
# This is an empirical formula.
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):