mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support hunyuanvideo v2v
This commit is contained in:
@@ -72,16 +72,21 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
frames = [Image.fromarray(frame) for frame in frames]
|
||||
return frames
|
||||
|
||||
def encode_video(self, frames):
|
||||
# frames : (B, C, T, H, W)
|
||||
latents = self.vae_encoder(frames)
|
||||
|
||||
def encode_video(self, frames, tile_size=(17, 30, 30), tile_stride=(12, 20, 20)):
|
||||
tile_size = ((tile_size[0] - 1) * 4 + 1, tile_size[1] * 8, tile_size[2] * 8)
|
||||
tile_stride = (tile_stride[0] * 4, tile_stride[1] * 8, tile_stride[2] * 8)
|
||||
latents = self.vae_encoder.encode_video(frames, tile_size=tile_size, tile_stride=tile_stride)
|
||||
return latents
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
height=720,
|
||||
width=1280,
|
||||
@@ -94,8 +99,22 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
progress_bar_cmd=lambda x: x,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Initialize noise
|
||||
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
if input_video is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
input_video = self.preprocess_images(input_video)
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
@@ -106,9 +125,6 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
# Extra input
|
||||
extra_input = self.prepare_extra_input(latents, guidance=embedded_guidance)
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# Denoise
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
@@ -126,9 +142,6 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
|
||||
Reference in New Issue
Block a user