load hunyuani2v model

This commit is contained in:
mi804
2025-03-07 17:43:30 +08:00
parent 6e316fd825
commit 945b43492e
6 changed files with 327 additions and 18 deletions

View File

@@ -112,6 +112,7 @@ model_loader_configs = [
(None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"), (None, "aeb82dce778a03dcb4d726cb03f3c43f", ["hunyuan_video_vae_decoder", "hunyuan_video_vae_encoder"], [HunyuanVideoVAEDecoder, HunyuanVideoVAEEncoder], "diffusers"),
(None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "b9588f02e78f5ccafc9d7c0294e46308", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"), (None, "84ef4bd4757f60e906b54aa6a7815dc6", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "ae3c22aaa28bfae6f3688f796c9814ae", ["hunyuan_video_dit"], [HunyuanVideoDiT], "civitai"),
(None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"), (None, "68beaf8429b7c11aa8ca05b1bd0058bd", ["stepvideo_vae"], [StepVideoVAE], "civitai"),
(None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"), (None, "5c0216a2132b082c10cb7a0e0377e681", ["stepvideo_dit"], [StepVideoModel], "civitai"),
(None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"), (None, "9269f8db9040a9d860eaca435be61814", ["wan_video_dit"], [WanModel], "civitai"),
@@ -135,6 +136,7 @@ huggingface_model_loader_configs = [
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"), ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"), ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"), ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"),
("LlavaForConditionalGeneration", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoMLLMEncoder"),
("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"), ("Step1Model", "diffsynth.models.stepvideo_text_encoder", "stepvideo_text_encoder_2", "STEP1TextEncoder"),
] ]
patch_model_loader_configs = [ patch_model_loader_configs = [
@@ -677,6 +679,25 @@ preset_models_on_modelscope = {
"models/HunyuanVideo/transformers/mp_rank_00_model_states.pt" "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
], ],
}, },
"HunyuanVideoI2V":{
"file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideoI2V/text_encoder"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00001-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00002-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00003-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model-00004-of-00004.safetensors", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "config.json", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/llava-llama-3-8b-v1_1-transformers", "model.safetensors.index.json", "models/HunyuanVideoI2V/text_encoder_2"),
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/vae/pytorch_model.pt", "models/HunyuanVideoI2V/vae"),
("AI-ModelScope/HunyuanVideo-I2V", "hunyuan-video-i2v-720p/transformers/mp_rank_00_model_states.pt", "models/HunyuanVideoI2V/transformers")
],
"load_path": [
"models/HunyuanVideoI2V/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
"models/HunyuanVideoI2V/vae/pytorch_model.pt",
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
},
"HunyuanVideo-fp8":{ "HunyuanVideo-fp8":{
"file_list": [ "file_list": [
("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"), ("AI-ModelScope/clip-vit-large-patch14", "model.safetensors", "models/HunyuanVideo/text_encoder"),
@@ -753,4 +774,5 @@ Preset_model_id: TypeAlias = Literal[
"StableDiffusion3.5-medium", "StableDiffusion3.5-medium",
"HunyuanVideo", "HunyuanVideo",
"HunyuanVideo-fp8", "HunyuanVideo-fp8",
"HunyuanVideoI2V",
] ]

View File

@@ -4,6 +4,7 @@ from .utils import init_weights_on_device
from einops import rearrange, repeat from einops import rearrange, repeat
from tqdm import tqdm from tqdm import tqdm
from typing import Union, Tuple, List from typing import Union, Tuple, List
from .utils import hash_state_dict_keys
def HunyuanVideoRope(latents): def HunyuanVideoRope(latents):
@@ -555,7 +556,7 @@ class FinalLayer(torch.nn.Module):
class HunyuanVideoDiT(torch.nn.Module): class HunyuanVideoDiT(torch.nn.Module):
def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40): def __init__(self, in_channels=16, hidden_size=3072, text_dim=4096, num_double_blocks=20, num_single_blocks=40, guidance_embed=True):
super().__init__() super().__init__()
self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size) self.img_in = PatchEmbed(in_channels=in_channels, embed_dim=hidden_size)
self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size) self.txt_in = SingleTokenRefiner(in_channels=text_dim, hidden_size=hidden_size)
@@ -565,7 +566,7 @@ class HunyuanVideoDiT(torch.nn.Module):
torch.nn.SiLU(), torch.nn.SiLU(),
torch.nn.Linear(hidden_size, hidden_size) torch.nn.Linear(hidden_size, hidden_size)
) )
self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") self.guidance_in = TimestepEmbeddings(256, hidden_size, computation_device="cpu") if guidance_embed else None
self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)]) self.double_blocks = torch.nn.ModuleList([MMDoubleStreamBlock(hidden_size) for _ in range(num_double_blocks)])
self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)]) self.single_blocks = torch.nn.ModuleList([MMSingleStreamBlock(hidden_size) for _ in range(num_single_blocks)])
self.final_layer = FinalLayer(hidden_size) self.final_layer = FinalLayer(hidden_size)
@@ -610,7 +611,9 @@ class HunyuanVideoDiT(torch.nn.Module):
): ):
B, C, T, H, W = x.shape B, C, T, H, W = x.shape
vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb) + self.guidance_in(guidance * 1000, dtype=torch.float32) vec = self.time_in(t, dtype=torch.float32) + self.vector_in(pooled_prompt_emb)
if self.guidance_in is not None:
vec += self.guidance_in(guidance * 1000, dtype=torch.float32)
img = self.img_in(x) img = self.img_in(x)
txt = self.txt_in(prompt_emb, t, text_mask) txt = self.txt_in(prompt_emb, t, text_mask)
@@ -783,6 +786,7 @@ class HunyuanVideoDiTStateDictConverter:
pass pass
def from_civitai(self, state_dict): def from_civitai(self, state_dict):
origin_hash_key = hash_state_dict_keys(state_dict, with_shape=True)
if "module" in state_dict: if "module" in state_dict:
state_dict = state_dict["module"] state_dict = state_dict["module"]
direct_dict = { direct_dict = {
@@ -882,4 +886,6 @@ class HunyuanVideoDiTStateDictConverter:
state_dict_[name_] = param state_dict_[name_] = param
else: else:
pass pass
if origin_hash_key == "ae3c22aaa28bfae6f3688f796c9814ae":
return state_dict_, {"in_channels": 33, "guidance_embed":False}
return state_dict_ return state_dict_

View File

@@ -1,24 +1,18 @@
from transformers import LlamaModel, LlamaConfig, DynamicCache from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
from copy import deepcopy from copy import deepcopy
import torch import torch
class HunyuanVideoLLMEncoder(LlamaModel): class HunyuanVideoLLMEncoder(LlamaModel):
def __init__(self, config: LlamaConfig): def __init__(self, config: LlamaConfig):
super().__init__(config) super().__init__(config)
self.auto_offload = False self.auto_offload = False
def enable_auto_offload(self, **kwargs): def enable_auto_offload(self, **kwargs):
self.auto_offload = True self.auto_offload = True
def forward( def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
self,
input_ids,
attention_mask,
hidden_state_skip_layer=2
):
embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
inputs_embeds = embed_tokens(input_ids) inputs_embeds = embed_tokens(input_ids)
@@ -53,3 +47,22 @@ class HunyuanVideoLLMEncoder(LlamaModel):
break break
return hidden_states return hidden_states
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
self.auto_offload = False
def enable_auto_offload(self, **kwargs):
self.auto_offload = True
# TODO: implement the low VRAM inference for MLLM.
def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
outputs = super().forward(input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
pixel_values=pixel_values)
hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
return hidden_state

View File

@@ -1,8 +1,9 @@
from .base_prompter import BasePrompter from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1 from ..models.sd3_text_encoder import SD3TextEncoder1
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
import os, torch import os, torch
from typing import Union
PROMPT_TEMPLATE_ENCODE = ( PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, " "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
"5. camera angles, movements, and transitions used in the video:<|eot_id|>" "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>") "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)
PROMPT_TEMPLATE = { PROMPT_TEMPLATE = {
"dit-llm-encode": { "dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE, "template": PROMPT_TEMPLATE_ENCODE,
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO, "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95, "crop_start": 95,
}, },
"dit-llm-encode-i2v": {
"template": PROMPT_TEMPLATE_ENCODE_I2V,
"crop_start": 36,
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271
},
"dit-llm-encode-video-i2v": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
"crop_start": 103,
"image_emb_start": 5,
"image_emb_end": 581,
"image_emb_len": 576,
"double_return_token_id": 271
},
} }
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion" NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
@@ -52,13 +87,27 @@ class HunyuanVideoPrompter(BasePrompter):
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right') self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
self.text_encoder_1: SD3TextEncoder1 = None self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.i2v_mode = False
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode'] self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video'] self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None): def fetch_models(self,
text_encoder_1: SD3TextEncoder1 = None,
text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
self.text_encoder_1 = text_encoder_1 self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2 self.text_encoder_2 = text_encoder_2
if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
# processor
# TODO: may need to replace processor with local implementation
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
# template
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
# mode setting
self.i2v_mode = True
def apply_text_to_template(self, text, template): def apply_text_to_template(self, text, template):
assert isinstance(template, str) assert isinstance(template, str)
@@ -107,8 +156,91 @@ class HunyuanVideoPrompter(BasePrompter):
return last_hidden_state, attention_mask return last_hidden_state, attention_mask
def encode_prompt_using_mllm(self,
prompt,
images,
max_length,
device,
crop_start,
hidden_state_skip_layer=2,
use_attention_mask=True,
image_embed_interleave=2):
image_outputs = self.processor(images, return_tensors="pt")[
"pixel_values"
].to(device)
max_length += crop_start
inputs = self.tokenizer_2(prompt,
return_tensors="pt",
padding="max_length",
max_length=max_length,
truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
last_hidden_state = self.text_encoder_2(input_ids=input_ids,
attention_mask=attention_mask,
hidden_state_skip_layer=hidden_state_skip_layer,
pixel_values=image_outputs)
text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
batch_indices, last_double_return_token_indices = torch.where(
input_ids == self.prompt_template_video.get("double_return_token_id", 271))
if last_double_return_token_indices.shape[0] == 3:
# in case the prompt is too long
last_double_return_token_indices = torch.cat((
last_double_return_token_indices,
torch.tensor([input_ids.shape[-1]]),
))
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
attention_mask_assistant_crop_end = last_double_return_token_indices
text_last_hidden_state = []
text_attention_mask = []
image_last_hidden_state = []
image_attention_mask = []
for i in range(input_ids.shape[0]):
text_last_hidden_state.append(
torch.cat([
last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
last_hidden_state[i, assistant_crop_end[i].item():],
]))
text_attention_mask.append(
torch.cat([
attention_mask[
i,
crop_start:attention_mask_assistant_crop_start[i].item(),
],
attention_mask[i, attention_mask_assistant_crop_end[i].item():],
]) if use_attention_mask else None)
image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
image_attention_mask.append(
torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
to(attention_mask.dtype) if use_attention_mask else None)
text_last_hidden_state = torch.stack(text_last_hidden_state)
text_attention_mask = torch.stack(text_attention_mask)
image_last_hidden_state = torch.stack(image_last_hidden_state)
image_attention_mask = torch.stack(image_attention_mask)
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
return last_hidden_state, attention_mask
def encode_prompt(self, def encode_prompt(self,
prompt, prompt,
images=None,
positive=True, positive=True,
device="cuda", device="cuda",
clip_sequence_length=77, clip_sequence_length=77,
@@ -136,8 +268,11 @@ class HunyuanVideoPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device) pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
# LLM # LLM
prompt_emb, attention_mask = self.encode_prompt_using_llm( if images is None:
prompt_formated, llm_sequence_length, device, crop_start, prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, use_attention_mask) hidden_state_skip_layer, use_attention_mask)
else:
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
crop_start, hidden_state_skip_layer, use_attention_mask)
return prompt_emb, pooled_prompt_emb, attention_mask return prompt_emb, pooled_prompt_emb, attention_mask

View File

@@ -0,0 +1,45 @@
{
"_valid_processor_keys": [
"images",
"do_resize",
"size",
"resample",
"do_center_crop",
"crop_size",
"do_rescale",
"rescale_factor",
"do_normalize",
"image_mean",
"image_std",
"do_convert_rgb",
"return_tensors",
"data_format",
"input_data_format"
],
"crop_size": {
"height": 336,
"width": 336
},
"do_center_crop": true,
"do_convert_rgb": true,
"do_normalize": true,
"do_rescale": true,
"do_resize": true,
"image_mean": [
0.48145466,
0.4578275,
0.40821073
],
"image_processor_type": "CLIPImageProcessor",
"image_std": [
0.26862954,
0.26130258,
0.27577711
],
"processor_class": "LlavaProcessor",
"resample": 3,
"rescale_factor": 0.00392156862745098,
"size": {
"shortest_edge": 336
}
}

View File

@@ -0,0 +1,88 @@
import torch
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
from diffsynth.prompters.hunyuan_video_prompter import HunyuanVideoPrompter
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size)**2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
return buckets[closest_ratio_id], float(closest_ratio)
def prepare_vae_inputs(semantic_images, i2v_resolution="720p"):
if i2v_resolution == "720p":
bucket_hw_base_size = 960
elif i2v_resolution == "540p":
bucket_hw_base_size = 720
elif i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
origin_size = semantic_images[0].size
crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
ref_image_transform = transforms.Compose([
transforms.Resize(closest_size),
transforms.CenterCrop(closest_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2)
return semantic_image_pixel_values
model_manager = ModelManager()
# The other modules are loaded in float16.
model_manager.load_models(
[
"models/HunyuanVideoI2V/transformers/mp_rank_00_model_states.pt"
],
torch_dtype=torch.bfloat16, # you can use torch_dtype=torch.float8_e4m3fn to enable quantization.
device="cuda"
)
model_manager.load_models(
[
"models/HunyuanVideo/text_encoder/model.safetensors",
"models/HunyuanVideoI2V/text_encoder_2",
'models/HunyuanVideoI2V/vae/pytorch_model.pt'
],
torch_dtype=torch.float16,
device="cuda"
)
# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
model_manager,
torch_dtype=torch.bfloat16,
device="cuda",
enable_vram_management=False
)
# Although you have enough VRAM, we still recommend you to enable offload.
pipe.enable_cpu_offload()
print()