mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
Merge pull request #436 from mi804/hunyuanvideo_i2v
support hunyuanvideo-i2v
This commit is contained in:
@@ -5,13 +5,13 @@ from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
|
||||
class HunyuanVideoPipeline(BasePipeline):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
@@ -53,10 +53,58 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
pipe.enable_vram_management()
|
||||
return pipe
|
||||
|
||||
def generate_crop_size_list(self, base_size=256, patch_size=32, max_ratio=4.0):
    """Enumerate patch-aligned crop sizes that fit a fixed patch budget.

    The budget is ``round((base_size / patch_size) ** 2)`` patches. Starting
    from a single row that uses the whole budget, the walk grows the second
    dimension while the product still fits the budget and otherwise shrinks
    the first one, recording every visited pair whose long/short ratio does
    not exceed ``max_ratio``.

    Args:
        base_size: Side length that defines the patch budget.
        patch_size: Side length of one patch; sizes are multiples of it.
        max_ratio: Largest allowed long-side / short-side ratio (>= 1.0).

    Returns:
        A list of ``(wp * patch_size, hp * patch_size)`` tuples in visiting
        order.
    """
    num_patches = round((base_size / patch_size) ** 2)
    assert max_ratio >= 1.0

    crop_size_list = []
    wp, hp = num_patches, 1
    while wp > 0:
        if max(wp, hp) / min(wp, hp) <= max_ratio:
            crop_size_list.append((wp * patch_size, hp * patch_size))
        if (hp + 1) * wp <= num_patches:
            # Still inside the budget: grow the second dimension.
            hp += 1
        else:
            # Budget exhausted for this column: shrink the first dimension.
            wp -= 1
    return crop_size_list
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256):
|
||||
|
||||
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
    """Pick the bucket whose aspect ratio is nearest to ``height / width``.

    Args:
        height: Source height.
        width: Source width; a zero width raises ``ZeroDivisionError``.
        ratios: Candidate aspect ratios, as a plain list or a NumPy array.
        buckets: Candidate sizes, index-aligned with ``ratios``.

    Returns:
        A ``(bucket, ratio)`` tuple: the element of ``buckets`` at the index
        of the nearest ratio, and that ratio as a Python ``float``. On a tie
        the first candidate wins.
    """
    aspect_ratio = float(height) / float(width)
    # asarray makes the subtraction below valid for plain lists as well as
    # ndarrays, matching the ``list`` annotation.
    ratio_array = np.asarray(ratios, dtype=float)
    closest_ratio_id = int(np.abs(ratio_array - aspect_ratio).argmin())
    # The argmin index already identifies the nearest ratio, so no second
    # scan over the candidates is needed.
    return buckets[closest_ratio_id], float(ratio_array[closest_ratio_id])
|
||||
|
||||
|
||||
def prepare_vae_images_inputs(self, semantic_images, i2v_resolution="720p"):
    """Resize and normalise reference images to the closest size bucket.

    The bucket list is built by ``self.generate_crop_size_list`` from a base
    size chosen by ``i2v_resolution``, and the bucket is selected by
    ``self.get_closest_ratio`` from the size of the first image only.

    Args:
        semantic_images: Sequence of images; the first one must expose
            ``.size`` (indexed here as ``size[0]`` = width, ``size[1]`` =
            height — presumably PIL images, confirm against callers).
        i2v_resolution: One of ``"720p"``, ``"540p"`` or ``"360p"``.

    Returns:
        ``(pixel_values, target_height, target_width)`` where ``pixel_values``
        is the concatenated transformed images with singleton axes inserted
        at positions 0 and 2, moved to ``self.device``.

    Raises:
        ValueError: If ``i2v_resolution`` is not a supported label.
    """
    base_size_table = (("720p", 960), ("540p", 720), ("360p", 480))
    bucket_hw_base_size = next(
        (size for label, size in base_size_table if i2v_resolution == label),
        None,
    )
    if bucket_hw_base_size is None:
        raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")

    origin_size = semantic_images[0].size

    crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
    aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
    closest_size, _ = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)

    # NOTE(review): Resize is given the full two-element size, so the
    # following CenterCrop to the same size looks like a no-op and the image
    # aspect ratio may be slightly stretched — confirm this is intended.
    ref_image_transform = transforms.Compose([
        transforms.Resize(closest_size),
        transforms.CenterCrop(closest_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    transformed = [ref_image_transform(image) for image in semantic_images]
    # Concatenate along the first axis, then add singleton axes at 0 and 2.
    pixel_values = torch.cat(transformed).unsqueeze(0).unsqueeze(2).to(self.device)

    target_height, target_width = closest_size
    return pixel_values, target_height, target_width
|
||||
|
||||
|
||||
def encode_prompt(self, prompt, positive=True, clip_sequence_length=77, llm_sequence_length=256, input_images=None):
    """Encode a prompt through ``self.prompter`` and return the results as a dict.

    Args:
        prompt: Prompt passed through to ``self.prompter.encode_prompt``.
        positive: Forwarded as ``positive``.
        clip_sequence_length: Forwarded as ``clip_sequence_length``.
        llm_sequence_length: Forwarded as ``llm_sequence_length``.
        input_images: Optional images forwarded to the prompter as ``images``;
            ``None`` when no image conditioning is used.

    Returns:
        A dict with keys ``"prompt_emb"``, ``"pooled_prompt_emb"`` and
        ``"text_mask"`` holding the three values returned by the prompter,
        in that order.
    """
    prompt_emb, pooled_prompt_emb, text_mask = self.prompter.encode_prompt(
        prompt,
        device=self.device,
        positive=positive,
        clip_sequence_length=clip_sequence_length,
        llm_sequence_length=llm_sequence_length,
        images=input_images,
    )
    return {"prompt_emb": prompt_emb, "pooled_prompt_emb": pooled_prompt_emb, "text_mask": text_mask}
|
||||
|
||||
@@ -87,6 +135,9 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
input_video=None,
|
||||
input_images=None,
|
||||
i2v_resolution="720p",
|
||||
i2v_stability=True,
|
||||
denoising_strength=1.0,
|
||||
seed=None,
|
||||
rand_device=None,
|
||||
@@ -105,10 +156,17 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
):
|
||||
# Tiler parameters
|
||||
tiler_kwargs = {"tile_size": tile_size, "tile_stride": tile_stride}
|
||||
|
||||
|
||||
# Scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# encoder input images
|
||||
if input_images is not None:
|
||||
self.load_models_to_device(['vae_encoder'])
|
||||
image_pixel_values, height, width = self.prepare_vae_images_inputs(input_images, i2v_resolution=i2v_resolution)
|
||||
with torch.autocast(device_type=self.device, dtype=torch.float16, enabled=True):
|
||||
image_latents = self.vae_encoder(image_pixel_values)
|
||||
|
||||
# Initialize noise
|
||||
rand_device = self.device if rand_device is None else rand_device
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=self.torch_dtype).to(self.device)
|
||||
@@ -118,12 +176,18 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
input_video = torch.stack(input_video, dim=2)
|
||||
latents = self.encode_video(input_video, **tiler_kwargs).to(dtype=self.torch_dtype, device=self.device)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
elif input_images is not None and i2v_stability:
|
||||
noise = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=rand_device, dtype=image_latents.dtype).to(self.device)
|
||||
t = torch.tensor([0.999]).to(device=self.device)
|
||||
latents = noise * t + image_latents.repeat(1, 1, (num_frames - 1) // 4 + 1, 1, 1) * (1 - t)
|
||||
latents = latents.to(dtype=image_latents.dtype)
|
||||
else:
|
||||
latents = noise
|
||||
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
# current mllm does not support vram_management
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management and input_images is None else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True, input_images=input_images)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
@@ -139,11 +203,16 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
timestep = timestep.unsqueeze(0).to(self.device)
|
||||
print(f"Step {progress_id + 1} / {len(self.scheduler.timesteps)}")
|
||||
|
||||
forward_func = lets_dance_hunyuan_video
|
||||
if input_images is not None:
|
||||
latents = torch.concat([image_latents, latents[:, :, 1:, :, :]], dim=2)
|
||||
forward_func = lets_dance_hunyuan_video_i2v
|
||||
|
||||
# Inference
|
||||
with torch.autocast(device_type=self.device, dtype=self.torch_dtype):
|
||||
noise_pred_posi = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
noise_pred_posi = forward_func(self.dit, latents, timestep, **prompt_emb_posi, **extra_input, **tea_cache_kwargs)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_hunyuan_video(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred_nega = forward_func(self.dit, latents, timestep, **prompt_emb_nega, **extra_input)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
@@ -163,7 +232,11 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.load_models_to_device([] if self.vram_management else ["dit"])
|
||||
|
||||
# Scheduler
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
if input_images is not None:
|
||||
latents = self.scheduler.step(noise_pred[:, :, 1:, :, :], self.scheduler.timesteps[progress_id], latents[:, :, 1:, :, :])
|
||||
latents = torch.concat([image_latents, latents], dim=2)
|
||||
else:
|
||||
latents = self.scheduler.step(noise_pred, self.scheduler.timesteps[progress_id], latents)
|
||||
|
||||
# Decode
|
||||
self.load_models_to_device(['vae_decoder'])
|
||||
@@ -194,7 +267,7 @@ class TeaCache:
|
||||
if self.step == 0 or self.step == self.num_inference_steps - 1:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
else:
|
||||
else:
|
||||
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
|
||||
rescale_func = np.poly1d(coefficients)
|
||||
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
|
||||
@@ -203,14 +276,14 @@ class TeaCache:
|
||||
else:
|
||||
should_calc = True
|
||||
self.accumulated_rel_l1_distance = 0
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.previous_modulated_input = modulated_inp
|
||||
self.step += 1
|
||||
if self.step == self.num_inference_steps:
|
||||
self.step = 0
|
||||
if should_calc:
|
||||
self.previous_hidden_states = img.clone()
|
||||
return not should_calc
|
||||
|
||||
|
||||
def store(self, hidden_states):
|
||||
self.previous_residual = hidden_states - self.previous_hidden_states
|
||||
self.previous_hidden_states = None
|
||||
@@ -250,13 +323,70 @@ def lets_dance_hunyuan_video(
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
split_token = int(text_mask.sum(dim=1))
|
||||
txt_len = int(txt.shape[1])
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin))
|
||||
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), split_token=split_token)
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin))
|
||||
img = x[:, :-256]
|
||||
x = block(x, vec, (freqs_cos, freqs_sin), txt_len=txt_len, split_token=split_token)
|
||||
img = x[:, :-txt_len]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
img = dit.final_layer(img, vec)
|
||||
img = dit.unpatchify(img, T=T//1, H=H//2, W=W//2)
|
||||
return img
|
||||
|
||||
|
||||
def lets_dance_hunyuan_video_i2v(
|
||||
dit: HunyuanVideoDiT,
|
||||
x: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
prompt_emb: torch.Tensor = None,
|
||||
text_mask: torch.Tensor = None,
|
||||
pooled_prompt_emb: torch.Tensor = None,
|
||||
freqs_cos: torch.Tensor = None,
|
||||
freqs_sin: torch.Tensor = None,
|
||||
guidance: torch.Tensor = None,
|
||||
tea_cache: TeaCache = None,
|
||||
**kwargs
|
||||
):
|
||||
B, C, T, H, W = x.shape
|
||||
# Uncomment below to keep same as official implementation
|
||||
# guidance = guidance.to(dtype=torch.float32).to(torch.bfloat16)
|
||||
vec = dit.time_in(t, dtype=torch.bfloat16)
|
||||
vec_2 = dit.vector_in(pooled_prompt_emb)
|
||||
vec = vec + vec_2
|
||||
vec = vec + dit.guidance_in(guidance * 1000., dtype=torch.bfloat16)
|
||||
|
||||
token_replace_vec = dit.time_in(torch.zeros_like(t), dtype=torch.bfloat16)
|
||||
tr_token = (H // 2) * (W // 2)
|
||||
token_replace_vec = token_replace_vec + vec_2
|
||||
|
||||
img = dit.img_in(x)
|
||||
txt = dit.txt_in(prompt_emb, t, text_mask)
|
||||
|
||||
# TeaCache
|
||||
if tea_cache is not None:
|
||||
tea_cache_update = tea_cache.check(dit, img, vec)
|
||||
else:
|
||||
tea_cache_update = False
|
||||
|
||||
if tea_cache_update:
|
||||
print("TeaCache skip forward.")
|
||||
img = tea_cache.update(img)
|
||||
else:
|
||||
split_token = int(text_mask.sum(dim=1))
|
||||
txt_len = int(txt.shape[1])
|
||||
for block in tqdm(dit.double_blocks, desc="Double stream blocks"):
|
||||
img, txt = block(img, txt, vec, (freqs_cos, freqs_sin), token_replace_vec, tr_token, split_token)
|
||||
|
||||
x = torch.concat([img, txt], dim=1)
|
||||
for block in tqdm(dit.single_blocks, desc="Single stream blocks"):
|
||||
x = block(x, vec, (freqs_cos, freqs_sin), txt_len, token_replace_vec, tr_token, split_token)
|
||||
img = x[:, :-txt_len]
|
||||
|
||||
if tea_cache is not None:
|
||||
tea_cache.store(img)
|
||||
|
||||
Reference in New Issue
Block a user