mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-18 22:08:13 +00:00
support ltx2 distilled pipeline
@@ -61,7 +61,6 @@ class LTX2AudioVideoPipeline(BasePipeline):
             LTX2AudioVideoUnit_InputVideoEmbedder(),
         ]
         self.model_fn = model_fn_ltx2
-        # self.lora_loader = LTX2LoRALoader

     @staticmethod
     def from_pretrained(
@@ -89,12 +88,12 @@ class LTX2AudioVideoPipeline(BasePipeline):
         pipe.video_vae_decoder = model_pool.fetch_model("ltx2_video_vae_decoder")
         pipe.audio_vae_decoder = model_pool.fetch_model("ltx2_audio_vae_decoder")
         pipe.audio_vocoder = model_pool.fetch_model("ltx2_audio_vocoder")
+        pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")

         # Stage 2
         if stage2_lora_config is not None:
             stage2_lora_config.download_if_necessary()
             pipe.stage2_lora_path = stage2_lora_config.path
-            pipe.upsampler = model_pool.fetch_model("ltx2_latent_upsampler")
         # Optional, currently not used
         # pipe.audio_vae_encoder = model_pool.fetch_model("ltx2_audio_vae_encoder")
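
With this hunk the latent upsampler is fetched unconditionally, so a distilled run (which skips the stage-2 LoRA but still upsamples) gets it too; only the LoRA path remains gated on stage2_lora_config. A minimal stand-in for the interface that gate relies on, inferred from the two calls above (the stub name and the local-path behavior are illustrative assumptions, not the repository's config class):

class StubLoRAConfig:
    def __init__(self, path: str):
        self.path = path  # from_pretrained copies this into pipe.stage2_lora_path

    def download_if_necessary(self) -> None:
        # A real config would fetch the checkpoint when `path` is remote;
        # a local path needs no action.
        pass

config = StubLoRAConfig("models/ltx2_stage2_lora.safetensors")
config.download_if_necessary()
print(config.path)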
@@ -122,8 +121,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
             noise_pred = noise_pred_posi
         return noise_pred

-    def stage2_denoise(self, cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline=True, progress_bar_cmd=tqdm):
-        if use_two_stage_pipeline:
+    def stage2_denoise(self, inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd=tqdm):
+        if inputs_shared["use_two_stage_pipeline"]:
             latent = self.video_vae_encoder.per_channel_statistics.un_normalize(inputs_shared["video_latents"])
             self.load_models_to_device(self.in_iteration_models + ('upsampler',))
             latent = self.upsampler(latent)
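
The flags consumed here now travel in inputs_shared rather than as positional arguments, which lets the checker unit further down normalize them before denoising starts. The `noise_pred = noise_pred_posi` line in the surrounding context is the tail of classifier-free guidance; a generic, runnable sketch of that pattern, which also shows why a cfg_scale of 1.0 amounts to skipping the negative pass entirely (plain NumPy stand-in, not the pipeline's actual cfg_guided_model_fn):

import numpy as np

def cfg_guided(model_fn, cfg_scale, inputs_posi, inputs_nega, **kwargs):
    # Classifier-free guidance: extrapolate from the negative (unconditional)
    # prediction toward the positive (conditional) one.
    noise_pred_posi = model_fn(**inputs_posi, **kwargs)
    if cfg_scale == 1.0:
        # Guidance disabled: the result equals the positive prediction,
        # so the negative forward pass can be skipped.
        return noise_pred_posi
    noise_pred_nega = model_fn(**inputs_nega, **kwargs)
    return noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)

# Toy model: "denoising" is just scaling the latent by the prompt embedding mean.
toy = lambda latent, prompt_emb: latent * prompt_emb.mean()
x = np.ones(4)
print(cfg_guided(toy, 1.0, {"prompt_emb": np.array([1.0])}, {"prompt_emb": np.array([0.0])}, latent=x))
print(cfg_guided(toy, 3.0, {"prompt_emb": np.array([1.0])}, {"prompt_emb": np.array([0.0])}, latent=x))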
@@ -135,12 +134,13 @@ class LTX2AudioVideoPipeline(BasePipeline):
             inputs_shared["audio_latents"] = self.audio_patchifier.patchify(inputs_shared["audio_latents"])
             inputs_shared["audio_latents"] = self.scheduler.sigmas[0] * inputs_shared["audio_noise"] + (1 - self.scheduler.sigmas[0]) * inputs_shared["audio_latents"]
         self.load_models_to_device(self.in_iteration_models)
-        self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
+        if not inputs_shared["use_distilled_pipeline"]:
+            self.load_lora(self.dit, self.stage2_lora_path, alpha=0.8)
         models = {name: getattr(self, name) for name in self.in_iteration_models}
         for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
             timestep = timestep.unsqueeze(0).to(dtype=self.torch_dtype, device=self.device)
             noise_pred_video, noise_pred_audio = self.cfg_guided_model_fn(
-                self.model_fn, cfg_scale, inputs_shared, inputs_posi, inputs_nega,
+                self.model_fn, 1.0, inputs_shared, inputs_posi, inputs_nega,
                 **models, timestep=timestep, progress_id=progress_id
             )
             inputs_shared["video_latents"] = self.step(self.scheduler, inputs_shared["video_latents"], progress_id=progress_id,
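
The sigma-mixing line above is the rectified-flow noising rule: stage 2 restarts denoising from a partially re-noised version of the upsampled stage-1 latent rather than from pure noise. A self-contained sketch of that rule and the matching Euler update (generic flow matching, not the pipeline's scheduler):

import numpy as np

def renoise(latent, noise, sigma):
    # Rectified-flow forward process: linear interpolation between clean
    # data (sigma = 0) and pure noise (sigma = 1), as in the hunk above.
    return sigma * noise + (1 - sigma) * latent

def euler_step(latent, velocity_pred, sigma, sigma_next):
    # Flow-matching Euler update: move along the predicted velocity for
    # the distance between consecutive sigmas.
    return latent + (sigma_next - sigma) * velocity_pred

rng = np.random.default_rng(0)
clean = np.zeros(8)
noise = rng.standard_normal(8)
sigmas = np.linspace(0.6, 0.0, 5)      # a short stage-2 schedule starting mid-way
x = renoise(clean, noise, sigmas[0])   # restart from a partially noised latent
for s, s_next in zip(sigmas[:-1], sigmas[1:]):
    velocity = noise - clean           # oracle velocity, so the demo converges exactly
    x = euler_step(x, velocity, s, s_next)
print(np.abs(x - clean).max())         # ~0.0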
@@ -185,7 +185,8 @@ class LTX2AudioVideoPipeline(BasePipeline):
         progress_bar_cmd=tqdm,
     ):
         # Scheduler
-        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength)
+        self.scheduler.set_timesteps(num_inference_steps, denoising_strength=denoising_strength,
+                                     special_case="ditilled_stage1" if use_distilled_pipeline else None)
         # self.load_lora(self.dit, self.stage2_lora_path)
         # Inputs
         inputs_posi = {
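
How set_timesteps interprets special_case is not part of this diff. A plausible reading, sketched as a hypothetical standalone function, is that the distilled stage 1 swaps the dense schedule for a handful of fixed sigmas; the function name, the step count, and the values below are illustrative assumptions, not the repository's implementation:

import numpy as np

def make_sigmas(num_inference_steps, denoising_strength=1.0, special_case=None):
    # Hypothetical sketch: a distilled stage 1 uses a short, fixed schedule
    # instead of one derived from num_inference_steps.
    if special_case == "ditilled_stage1":
        return np.array([1.0, 0.75, 0.5, 0.25, 0.0])  # illustrative values only
    sigmas = np.linspace(1.0, 0.0, num_inference_steps + 1)
    # denoising_strength < 1.0 starts part-way down the schedule (img2img-style).
    start = int(round((1 - denoising_strength) * num_inference_steps))
    return sigmas[start:]

print(make_sigmas(30, special_case="ditilled_stage1"))  # 4 steps regardless of the request
print(make_sigmas(30, denoising_strength=0.5).shape)    # (16,): half the schedule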
@@ -223,7 +224,7 @@ class LTX2AudioVideoPipeline(BasePipeline):
         inputs_shared["audio_latents"] = self.audio_patchifier.unpatchify(inputs_shared["audio_latents"], inputs_shared["audio_latent_shape"])

         # Denoise Stage 2
-        inputs_shared = self.stage2_denoise(cfg_scale, inputs_shared, inputs_posi, inputs_nega, use_two_stage_pipeline, progress_bar_cmd)
+        inputs_shared = self.stage2_denoise(inputs_shared, inputs_posi, inputs_nega, progress_bar_cmd)

         # Decode
         self.load_models_to_device(['video_vae_decoder'])
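
The unpatchify call in this hunk's context is why audio_latent_shape is kept in inputs_shared: audio latents travel through denoising in patchified (token) form and must be restored to their original layout before decoding. A toy round-trip showing the idea (the real audio_patchifier's layout is not in this diff; this 1-D patching is an illustrative assumption):

import numpy as np

def patchify(latents, patch):
    # Flatten (channels, time) into a sequence of patch tokens.
    c, t = latents.shape
    return latents.reshape(c, t // patch, patch).transpose(1, 0, 2).reshape(t // patch, c * patch)

def unpatchify(tokens, shape, patch):
    # Invert patchify given the original (channels, time) shape.
    c, t = shape
    return tokens.reshape(t // patch, c, patch).transpose(1, 0, 2).reshape(c, t)

x = np.arange(24, dtype=float).reshape(2, 12)  # (channels, time)
tokens = patchify(x, patch=4)
assert np.array_equal(unpatchify(tokens, x.shape, patch=4), x)
print(tokens.shape)  # (3, 8): 3 tokens of 2*4 values each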
@@ -241,11 +242,17 @@ class LTX2AudioVideoUnit_PipelineChecker(PipelineUnit):
         super().__init__(take_over=True)

     def process(self, pipe: LTX2AudioVideoPipeline, inputs_shared, inputs_posi, inputs_nega):
+        if inputs_shared.get("use_distilled_pipeline", False):
+            inputs_shared["use_two_stage_pipeline"] = True
+            inputs_shared["cfg_scale"] = 1.0
+            print("Distilled pipeline requested: setting use_two_stage_pipeline to True and disabling CFG by setting cfg_scale to 1.0.")
         if inputs_shared.get("use_two_stage_pipeline", False):
-            if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
-                raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
+            # The distilled pipeline also uses two stages, but it does not need the LoRA.
+            if not inputs_shared.get("use_distilled_pipeline", False):
+                if not (hasattr(pipe, "stage2_lora_path") and pipe.stage2_lora_path is not None):
+                    raise ValueError("Two-stage pipeline requested, but stage2_lora_path is not set in the pipeline.")
             if not (hasattr(pipe, "upsampler") and pipe.upsampler is not None):
                 raise ValueError("Two-stage pipeline requested, but upsampler model is not loaded in the pipeline.")
         return inputs_shared, inputs_posi, inputs_nega
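
The net behavior of the checker: distilled runs get the two-stage path enabled and CFG disabled automatically, and only non-distilled two-stage runs require the stage-2 LoRA. A runnable distillation of that logic with a stand-in pipe (FakePipe and the trimmed error messages are illustrative; the flag logic mirrors the hunk above):

class FakePipe:
    # Stand-in exposing just the attributes the checker inspects (names from the diff).
    stage2_lora_path = None
    upsampler = object()

def check(pipe, inputs_shared):
    if inputs_shared.get("use_distilled_pipeline", False):
        inputs_shared["use_two_stage_pipeline"] = True
        inputs_shared["cfg_scale"] = 1.0
    if inputs_shared.get("use_two_stage_pipeline", False):
        # Distilled runs skip the LoRA requirement; everyone needs the upsampler.
        if not inputs_shared.get("use_distilled_pipeline", False):
            if getattr(pipe, "stage2_lora_path", None) is None:
                raise ValueError("stage2_lora_path is not set")
        if getattr(pipe, "upsampler", None) is None:
            raise ValueError("upsampler is not loaded")
    return inputs_shared

print(check(FakePipe(), {"use_distilled_pipeline": True}))  # passes without a LoRA
try:
    check(FakePipe(), {"use_two_stage_pipeline": True})     # non-distilled needs the LoRA
except ValueError as e:
    print(e)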