From 6a999e112755cb05b6f8d8ecf444757246478e2d Mon Sep 17 00:00:00 2001 From: Artiprocher Date: Thu, 26 Dec 2024 10:13:46 +0800 Subject: [PATCH] hunyuanvideo step_processor --- diffsynth/pipelines/hunyuan_video.py | 15 +++++++++++++++ diffsynth/schedulers/flow_match.py | 7 ++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py index ad81ca4..2beba5a 100644 --- a/diffsynth/pipelines/hunyuan_video.py +++ b/diffsynth/pipelines/hunyuan_video.py @@ -96,6 +96,7 @@ class HunyuanVideoPipeline(BasePipeline): num_inference_steps=30, tile_size=(17, 30, 30), tile_stride=(12, 20, 20), + step_processor=None, progress_bar_cmd=lambda x: x, progress_bar_st=None, ): @@ -140,6 +141,20 @@ class HunyuanVideoPipeline(BasePipeline): else: noise_pred = noise_pred_posi + # (Experimental feature, may be removed in the future) + if step_processor is not None: + self.load_models_to_device(['vae_decoder']) + rendered_frames = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents, to_final=True) + rendered_frames = self.vae_decoder.decode_video(rendered_frames, **tiler_kwargs) + rendered_frames = self.tensor2video(rendered_frames[0]) + rendered_frames = step_processor(rendered_frames, original_frames=input_video) + self.load_models_to_device(['vae_encoder']) + rendered_frames = self.preprocess_images(rendered_frames) + rendered_frames = torch.stack(rendered_frames, dim=2) + target_latents = self.encode_video(rendered_frames).to(dtype=self.torch_dtype, device=self.device) + noise_pred = self.scheduler.return_to_timestep(self.scheduler.timesteps[progress_id], latents, target_latents) + self.load_models_to_device([] if self.vram_management else ["dit"]) + # Scheduler latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents) diff --git a/diffsynth/schedulers/flow_match.py b/diffsynth/schedulers/flow_match.py index 949bfd7..399539b 100644 --- 
a/diffsynth/schedulers/flow_match.py +++ b/diffsynth/schedulers/flow_match.py @@ -47,7 +47,12 @@ class FlowMatchScheduler(): def return_to_timestep(self, timestep, sample, sample_stablized): - # This scheduler doesn't support this function. - pass + # Invert one scheduler step: recover the model output that moves `sample` to `sample_stablized`. + if isinstance(timestep, torch.Tensor): + timestep = timestep.cpu() + timestep_id = torch.argmin((self.timesteps - timestep).abs()) + sigma = self.sigmas[timestep_id] + model_output = (sample - sample_stablized) / sigma + return model_output def add_noise(self, original_samples, noise, timestep):