support ernie-image-turbo (#1391)

* support ernie-image-turbo

* pr review fix

* fix modelname
This commit is contained in:
Hong Zhang
2026-04-14 11:35:43 +08:00
committed by GitHub
parent 960d8c62c0
commit b5d04ceb30
18 changed files with 142 additions and 82 deletions

View File

@@ -131,11 +131,14 @@ class FlowMatchScheduler():
return sigmas, timesteps
@staticmethod
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0):
"""ERNIE-Image scheduler: pure linear sigmas from 1.0 to 0.0, no shift."""
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0, shift=3.0):
sigma_min = 0.0
sigma_max = 1.0
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@@ -185,9 +188,6 @@ class FlowMatchScheduler():
return sigmas, timesteps
def set_training_weight(self):
if self.set_timesteps_fn == FlowMatchScheduler.set_timesteps_ernie_image:
self.set_uniform_training_weight()
return
steps = 1000
x = self.timesteps
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
@@ -199,13 +199,6 @@ class FlowMatchScheduler():
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
self.linear_timesteps_weights = bsmntw_weighing
def set_uniform_training_weight(self):
"""Assign equal weight to every timestep, suitable for linear schedulers like ERNIE-Image."""
steps = 1000
num_steps = len(self.timesteps)
uniform_weight = torch.full((num_steps,), steps / num_steps, dtype=self.timesteps.dtype)
self.linear_timesteps_weights = uniform_weight
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
self.sigmas, self.timesteps = self.set_timesteps_fn(
num_inference_steps=num_inference_steps,