training framework

This commit is contained in:
Artiprocher
2025-05-12 17:48:28 +08:00
parent dbef6122e9
commit 675eefa07e
20 changed files with 939 additions and 174 deletions

View File

@@ -35,6 +35,9 @@ class FlowMatchScheduler():
y_shifted = y - y.min()
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
self.linear_timesteps_weights = bsmntw_weighing
self.training = True
else:
self.training = False
def step(self, model_output, timestep, sample, to_final=False, **kwargs):