mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 23:08:13 +00:00
z-image
This commit is contained in:
@@ -12,6 +12,7 @@ class FlowMatchScheduler():
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_flux(num_inference_steps=100, denoising_strength=1.0, shift=None):
|
||||
@@ -111,11 +112,12 @@ class FlowMatchScheduler():
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self, num_inference_steps):
|
||||
def set_training_weight(self):
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - num_inference_steps / 2) / num_inference_steps) ** 2)
|
||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||
y_shifted = y - y.min()
|
||||
bsmntw_weighing = y_shifted * (num_inference_steps / y_shifted.sum())
|
||||
bsmntw_weighing = y_shifted * (steps / y_shifted.sum())
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||
@@ -125,7 +127,7 @@ class FlowMatchScheduler():
|
||||
**kwargs,
|
||||
)
|
||||
if training:
|
||||
self.set_training_weight(num_inference_steps)
|
||||
self.set_training_weight()
|
||||
self.training = True
|
||||
else:
|
||||
self.training = False
|
||||
|
||||
@@ -3,8 +3,8 @@ import torch
|
||||
|
||||
|
||||
def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * pipe.scheduler.num_train_timesteps)
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * pipe.scheduler.num_train_timesteps)
|
||||
max_timestep_boundary = int(inputs.get("max_timestep_boundary", 1) * len(pipe.scheduler.timesteps))
|
||||
min_timestep_boundary = int(inputs.get("min_timestep_boundary", 0) * len(pipe.scheduler.timesteps))
|
||||
|
||||
timestep_id = torch.randint(min_timestep_boundary, max_timestep_boundary, (1,))
|
||||
timestep = pipe.scheduler.timesteps[timestep_id].to(dtype=pipe.torch_dtype, device=pipe.device)
|
||||
|
||||
Reference in New Issue
Block a user