refactor scheduler

This commit is contained in:
Artiprocher
2025-11-30 15:22:39 +08:00
parent 1a6fd69e6b
commit 675ae5e91f
6 changed files with 150 additions and 127 deletions

View File

@@ -23,7 +23,7 @@ class Flux2ImagePipeline(BasePipeline):
device=device, torch_dtype=torch_dtype,
height_division_factor=16, width_division_factor=16,
)
self.scheduler = FlowMatchScheduler()
self.scheduler = FlowMatchScheduler("FLUX.2")
self.text_encoder: Flux2TextEncoder = None
self.dit: Flux2DiT = None
self.vae: Flux2VAE = None
@@ -86,6 +86,8 @@ class Flux2ImagePipeline(BasePipeline):
# Progress bar
progress_bar_cmd = tqdm,
):
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
# Parameters
inputs_posi = {
"prompt": prompt,
@@ -103,12 +105,6 @@ class Flux2ImagePipeline(BasePipeline):
for unit in self.units:
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
# using dynamic shift Scheduler
self.scheduler.exponential_shift = True
self.scheduler.sigma_min = 1 / num_inference_steps
mu = self.scheduler.compute_empirical_mu(inputs_shared["latents"].shape[1], num_inference_steps)
self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, exponential_shift_mu=mu)
# Denoise
self.load_models_to_device(self.in_iteration_models)
models = {name: getattr(self, name) for name in self.in_iteration_models}