Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-20 15:48:20 +00:00)
refactor scheduler
This commit is contained in:
@@ -23,7 +23,7 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
-self.scheduler = FlowMatchScheduler()
+self.scheduler = FlowMatchScheduler("FLUX.2")
|
||||
self.text_encoder: Flux2TextEncoder = None
|
||||
self.dit: Flux2DiT = None
|
||||
self.vae: Flux2VAE = None
|
||||
@@ -86,6 +86,8 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
# Progress bar
|
||||
progress_bar_cmd = tqdm,
|
||||
):
|
||||
+self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, dynamic_shift_len=height//16*width//16)
|
||||
|
||||
# Parameters
|
||||
inputs_posi = {
|
||||
"prompt": prompt,
|
||||
@@ -103,12 +105,6 @@ class Flux2ImagePipeline(BasePipeline):
|
||||
for unit in self.units:
|
||||
inputs_shared, inputs_posi, inputs_nega = self.unit_runner(unit, self, inputs_shared, inputs_posi, inputs_nega)
|
||||
|
||||
-# using dynamic shift Scheduler
-self.scheduler.exponential_shift = True
-self.scheduler.sigma_min = 1 / num_inference_steps
-mu = self.scheduler.compute_empirical_mu(inputs_shared["latents"].shape[1], num_inference_steps)
-self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength, exponential_shift_mu=mu)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device(self.in_iteration_models)
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
|
||||
@@ -60,7 +60,7 @@ class FluxImagePipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
-self.scheduler = FlowMatchScheduler()
+self.scheduler = FlowMatchScheduler("FLUX.1")
|
||||
self.tokenizer_1: CLIPTokenizer = None
|
||||
self.tokenizer_2: T5TokenizerFast = None
|
||||
self.text_encoder_1: FluxTextEncoderClip = None
|
||||
|
||||
@@ -24,7 +24,7 @@ class QwenImagePipeline(BasePipeline):
|
||||
)
|
||||
from transformers import Qwen2Tokenizer, Qwen2VLProcessor
|
||||
|
||||
-self.scheduler = FlowMatchScheduler(sigma_min=0, sigma_max=1, extra_one_step=True, exponential_shift=True, exponential_shift_mu=0.8, shift_terminal=0.02)
+self.scheduler = FlowMatchScheduler("Qwen-Image")
|
||||
self.text_encoder: QwenImageTextEncoder = None
|
||||
self.dit: QwenImageDiT = None
|
||||
self.vae: QwenImageVAE = None
|
||||
|
||||
@@ -35,7 +35,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16, time_division_factor=4, time_division_remainder=1
|
||||
)
|
||||
-self.scheduler = FlowMatchScheduler(shift=5, sigma_min=0.0, extra_one_step=True)
+self.scheduler = FlowMatchScheduler("Wan")
|
||||
self.tokenizer: HuggingfaceTokenizer = None
|
||||
self.audio_processor: Wav2Vec2Processor = None
|
||||
self.text_encoder: WanTextEncoder = None
|
||||
@@ -283,7 +283,7 @@ class WanVideoPipeline(BasePipeline):
|
||||
models = {name: getattr(self, name) for name in self.in_iteration_models}
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
# Switch DiT if necessary
|
||||
-if timestep.item() < switch_DiT_boundary * self.scheduler.num_train_timesteps and self.dit2 is not None and not models["dit"] is self.dit2:
+if timestep.item() < switch_DiT_boundary * 1000 and self.dit2 is not None and not models["dit"] is self.dit2:
|
||||
self.load_models_to_device(self.in_iteration_models_2)
|
||||
models["dit"] = self.dit2
|
||||
models["vace"] = self.vace2
|
||||
|
||||
@@ -23,7 +23,7 @@ class ZImagePipeline(BasePipeline):
|
||||
device=device, torch_dtype=torch_dtype,
|
||||
height_division_factor=16, width_division_factor=16,
|
||||
)
|
||||
-self.scheduler = FlowMatchScheduler()
+self.scheduler = FlowMatchScheduler("Z-Image")
|
||||
self.text_encoder: ZImageTextEncoder = None
|
||||
self.dit: ZImageDiT = None
|
||||
self.vae_encoder: FluxVAEEncoder = None
|
||||
|
||||
Reference in New Issue
Block a user