mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-04-16 07:08:19 +00:00
Support ERNIE-Image (#1389)
* ernie-image pipeline * ernie-image inference and training * style fix * ernie docs * lowvram * final style fix * pr-review * pr-fix round2 * set uniform training weight * fix * update lowvram docs
This commit is contained in:
@@ -4,7 +4,7 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
@@ -13,6 +13,7 @@ class FlowMatchScheduler():
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@@ -129,6 +130,15 @@ class FlowMatchScheduler():
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_ernie_image(num_inference_steps=50, denoising_strength=1.0):
|
||||
"""ERNIE-Image scheduler: pure linear sigmas from 1.0 to 0.0, no shift."""
|
||||
num_train_timesteps = 1000
|
||||
sigma_start = denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps + 1)[:-1]
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
|
||||
sigma_min = 0.0
|
||||
@@ -175,6 +185,9 @@ class FlowMatchScheduler():
|
||||
return sigmas, timesteps
|
||||
|
||||
def set_training_weight(self):
|
||||
if self.set_timesteps_fn == FlowMatchScheduler.set_timesteps_ernie_image:
|
||||
self.set_uniform_training_weight()
|
||||
return
|
||||
steps = 1000
|
||||
x = self.timesteps
|
||||
y = torch.exp(-2 * ((x - steps / 2) / steps) ** 2)
|
||||
@@ -185,6 +198,13 @@ class FlowMatchScheduler():
|
||||
bsmntw_weighing = bsmntw_weighing * (len(self.timesteps) / steps)
|
||||
bsmntw_weighing = bsmntw_weighing + bsmntw_weighing[1]
|
||||
self.linear_timesteps_weights = bsmntw_weighing
|
||||
|
||||
def set_uniform_training_weight(self):
|
||||
"""Assign equal weight to every timestep, suitable for linear schedulers like ERNIE-Image."""
|
||||
steps = 1000
|
||||
num_steps = len(self.timesteps)
|
||||
uniform_weight = torch.full((num_steps,), steps / num_steps, dtype=self.timesteps.dtype)
|
||||
self.linear_timesteps_weights = uniform_weight
|
||||
|
||||
def set_timesteps(self, num_inference_steps=100, denoising_strength=1.0, training=False, **kwargs):
|
||||
self.sigmas, self.timesteps = self.set_timesteps_fn(
|
||||
|
||||
Reference in New Issue
Block a user