model-code

This commit is contained in:
mi804
2026-04-17 17:06:26 +08:00
parent 079e51c9f3
commit 36c203da57
23 changed files with 4230 additions and 2 deletions

View File

@@ -4,7 +4,7 @@ from typing_extensions import Literal
class FlowMatchScheduler():
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image"] = "FLUX.1"):
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning", "ERNIE-Image", "ACE-Step"] = "FLUX.1"):
self.set_timesteps_fn = {
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
"Wan": FlowMatchScheduler.set_timesteps_wan,
@@ -14,6 +14,7 @@ class FlowMatchScheduler():
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
"ERNIE-Image": FlowMatchScheduler.set_timesteps_ernie_image,
"ACE-Step": FlowMatchScheduler.set_timesteps_ace_step,
}.get(template, FlowMatchScheduler.set_timesteps_flux)
self.num_train_timesteps = 1000
@@ -142,6 +143,26 @@ class FlowMatchScheduler():
timesteps = sigmas * num_train_timesteps
return sigmas, timesteps
@staticmethod
def set_timesteps_ace_step(num_inference_steps=8, denoising_strength=1.0, shift=3.0):
"""ACE-Step Flow Matching scheduler.
Timesteps range from 1.0 to 0.0 (not multiplied by 1000).
Shift transformation: t = shift * t / (1 + (shift - 1) * t)
Args:
num_inference_steps: Number of diffusion steps.
denoising_strength: Denoising strength (1.0 = full denoising).
shift: Timestep shift parameter (default 3.0 for turbo).
"""
num_train_timesteps = 1000
sigma_start = denoising_strength
sigmas = torch.linspace(sigma_start, 0.0, num_inference_steps)
if shift is not None and shift != 1.0:
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
timesteps = sigmas # ACE-Step uses [0, 1] range directly
return sigmas, timesteps
@staticmethod
def set_timesteps_z_image(num_inference_steps=100, denoising_strength=1.0, shift=None, target_timesteps=None):
sigma_min = 0.0