mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-14 21:58:17 +00:00
support ernie-image-turbo (#1391)
* support ernie-image-turbo * pr review fix * fix modelname
This commit is contained in:
@@ -131,11 +131,14 @@ class FlowMatchScheduler():
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0):
|
||||
"""ERNIE-Image scheduler: pure linear sigmas from 1.0 to 0.0, no shift."""
|
||||
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
if shift is not None and shift != 1.0:
|
||||
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@@ -185,9 +188,6 @@ class FlowMatchScheduler():
|
||||
return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self):
|
||||
if self.set_timesteps_fn == FlowMatchScheduler.set_timesteps_ernie_image:
|
||||
self.set_uniform_training_weight()
|
||||
return
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||
@@ -199,13 +199,6 @@ class FlowMatchScheduler():
|
||||
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
def set_uniform_training_weight(self):
|
||||
"""Assign equal weight to every timestep, suitable for linear schedulers like ERNIE-Image."""
|
||||
steps = 1000
|
||||
num_steps = len(self.timesteps)
|
||||
uniform_weight = torch.full((num_steps,), steps / num_steps, dtype=self.timesteps.dtype)
|
||||
self.linear_timesteps_weights = uniform_weight
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||
num_inference_steps=num_inference_steps,
|
||||
|
||||
Reference in New Issue
Block a user