mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
Merge branch 'main' into ltx-2
This commit is contained in:
@@ -296,6 +296,7 @@ class BasePipeline(torch.nn.Module):
|
||||
vram_config=vram_config,
|
||||
vram_limit=vram_limit,
|
||||
clear_parameters=model_config.clear_parameters,
|
||||
state_dict=model_config.state_dict,
|
||||
)
|
||||
return model_pool
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing_extensions import Literal
|
||||
|
||||
class FlowMatchScheduler():
|
||||
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2"] = "FLUX.1"):
|
||||
def __init__(self, template: Literal["FLUX.1", "Wan", "Qwen-Image", "FLUX.2", "Z-Image", "LTX-2", "Qwen-Image-Lightning"] = "FLUX.1"):
|
||||
self.set_timesteps_fn = {
|
||||
"FLUX.1": FlowMatchScheduler.set_timesteps_flux,
|
||||
"Wan": FlowMatchScheduler.set_timesteps_wan,
|
||||
@@ -12,6 +12,7 @@ class FlowMatchScheduler():
|
||||
"FLUX.2": FlowMatchScheduler.set_timesteps_flux2,
|
||||
"Z-Image": FlowMatchScheduler.set_timesteps_z_image,
|
||||
"LTX-2": FlowMatchScheduler.set_timesteps_ltx2,
|
||||
"Qwen-Image-Lightning": FlowMatchScheduler.set_timesteps_qwen_image_lightning,
|
||||
}.get(template, FlowMatchScheduler.set_timesteps_flux)
|
||||
self.num_train_timesteps = 1000
|
||||
|
||||
@@ -71,6 +72,28 @@ class FlowMatchScheduler():
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def set_timesteps_qwen_image_lightning(num_inference_steps=100, denoising_strength=1.0, exponential_shift_mu=None, dynamic_shift_len=None):
|
||||
sigma_min = 0.0
|
||||
sigma_max = 1.0
|
||||
num_train_timesteps = 1000
|
||||
base_shift = math.log(3)
|
||||
max_shift = math.log(3)
|
||||
# Sigmas
|
||||
sigma_start = sigma_min + (sigma_max - sigma_min) * denoising_strength
|
||||
sigmas = torch.linspace(sigma_start, sigma_min, num_inference_steps + 1)[:-1]
|
||||
# Mu
|
||||
if exponential_shift_mu is not None:
|
||||
mu = exponential_shift_mu
|
||||
elif dynamic_shift_len is not None:
|
||||
mu = FlowMatchScheduler._calculate_shift_qwen_image(dynamic_shift_len, base_shift=base_shift, max_shift=max_shift)
|
||||
else:
|
||||
mu = 0.8
|
||||
sigmas = math.exp(mu) / (math.exp(mu) + (1 / sigmas - 1))
|
||||
# Timesteps
|
||||
timesteps = sigmas * num_train_timesteps
|
||||
return sigmas, timesteps
|
||||
|
||||
@staticmethod
|
||||
def compute_empirical_mu(image_seq_len, num_steps):
|
||||
a1, b1 = 8.73809524e-05, 1.89833333
|
||||
|
||||
@@ -13,9 +13,16 @@ def FlowMatchSFTLoss(pipe: BasePipeline, **inputs):
|
||||
inputs["latents"] = pipe.scheduler.add_noise(inputs["input_latents"], noise, timestep)
|
||||
training_target = pipe.scheduler.training_target(inputs["input_latents"], noise, timestep)
|
||||
|
||||
if "first_frame_latents" in inputs:
|
||||
inputs["latents"][:, :, 0:1] = inputs["first_frame_latents"]
|
||||
|
||||
models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}
|
||||
noise_pred = pipe.model_fn(**models, **inputs, timestep=timestep)
|
||||
|
||||
if "first_frame_latents" in inputs:
|
||||
noise_pred = noise_pred[:, :, 1:]
|
||||
training_target = training_target[:, :, 1:]
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), training_target.float())
|
||||
loss = loss * pipe.scheduler.training_weight(timestep)
|
||||
return loss
|
||||
|
||||
Reference in New Issue
Block a user