mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 08:40:47 +00:00
load hunyuani2v model
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast, CLIPImageProcessor
|
||||
import os, torch
|
||||
from typing import Union
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
@@ -18,6 +19,24 @@ PROMPT_TEMPLATE_ENCODE_VIDEO = (
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
|
||||
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
|
||||
"<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
|
||||
"1. The main content and theme of the video."
|
||||
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
|
||||
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
|
||||
"4. background environment, light, style and atmosphere."
|
||||
"5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
|
||||
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
|
||||
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
|
||||
PROMPT_TEMPLATE = {
|
||||
"dit-llm-encode": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE,
|
||||
@@ -27,6 +46,22 @@ PROMPT_TEMPLATE = {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
|
||||
"crop_start": 95,
|
||||
},
|
||||
"dit-llm-encode-i2v": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_I2V,
|
||||
"crop_start": 36,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 581,
|
||||
"image_emb_len": 576,
|
||||
"double_return_token_id": 271
|
||||
},
|
||||
"dit-llm-encode-video-i2v": {
|
||||
"template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
|
||||
"crop_start": 103,
|
||||
"image_emb_start": 5,
|
||||
"image_emb_end": 581,
|
||||
"image_emb_len": 576,
|
||||
"double_return_token_id": 271
|
||||
},
|
||||
}
|
||||
|
||||
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
|
||||
@@ -52,13 +87,27 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
self.i2v_mode = False
|
||||
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
||||
def fetch_models(self,
|
||||
text_encoder_1: SD3TextEncoder1 = None,
|
||||
text_encoder_2: Union[HunyuanVideoLLMEncoder, HunyuanVideoMLLMEncoder] = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
if isinstance(text_encoder_2, HunyuanVideoMLLMEncoder):
|
||||
# processor
|
||||
# TODO: may need to replace processor with local implementation
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_2_path = os.path.join(base_path, "tokenizer_configs/hunyuan_video/tokenizer_2")
|
||||
self.processor = CLIPImageProcessor.from_pretrained(tokenizer_2_path)
|
||||
# template
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode-i2v']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video-i2v']
|
||||
# mode setting
|
||||
self.i2v_mode = True
|
||||
|
||||
def apply_text_to_template(self, text, template):
|
||||
assert isinstance(template, str)
|
||||
@@ -107,8 +156,91 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt_using_mllm(self,
|
||||
prompt,
|
||||
images,
|
||||
max_length,
|
||||
device,
|
||||
crop_start,
|
||||
hidden_state_skip_layer=2,
|
||||
use_attention_mask=True,
|
||||
image_embed_interleave=2):
|
||||
image_outputs = self.processor(images, return_tensors="pt")[
|
||||
"pixel_values"
|
||||
].to(device)
|
||||
max_length += crop_start
|
||||
inputs = self.tokenizer_2(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True)
|
||||
input_ids = inputs.input_ids.to(device)
|
||||
attention_mask = inputs.attention_mask.to(device)
|
||||
last_hidden_state = self.text_encoder_2(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
hidden_state_skip_layer=hidden_state_skip_layer,
|
||||
pixel_values=image_outputs)
|
||||
|
||||
text_crop_start = (crop_start - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||
image_crop_start = self.prompt_template_video.get("image_emb_start", 5)
|
||||
image_crop_end = self.prompt_template_video.get("image_emb_end", 581)
|
||||
batch_indices, last_double_return_token_indices = torch.where(
|
||||
input_ids == self.prompt_template_video.get("double_return_token_id", 271))
|
||||
if last_double_return_token_indices.shape[0] == 3:
|
||||
# in case the prompt is too long
|
||||
last_double_return_token_indices = torch.cat((
|
||||
last_double_return_token_indices,
|
||||
torch.tensor([input_ids.shape[-1]]),
|
||||
))
|
||||
batch_indices = torch.cat((batch_indices, torch.tensor([0])))
|
||||
last_double_return_token_indices = (last_double_return_token_indices.reshape(input_ids.shape[0], -1)[:, -1])
|
||||
batch_indices = batch_indices.reshape(input_ids.shape[0], -1)[:, -1]
|
||||
assistant_crop_start = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576) - 4)
|
||||
assistant_crop_end = (last_double_return_token_indices - 1 + self.prompt_template_video.get("image_emb_len", 576))
|
||||
attention_mask_assistant_crop_start = (last_double_return_token_indices - 4)
|
||||
attention_mask_assistant_crop_end = last_double_return_token_indices
|
||||
text_last_hidden_state = []
|
||||
text_attention_mask = []
|
||||
image_last_hidden_state = []
|
||||
image_attention_mask = []
|
||||
for i in range(input_ids.shape[0]):
|
||||
text_last_hidden_state.append(
|
||||
torch.cat([
|
||||
last_hidden_state[i, text_crop_start:assistant_crop_start[i].item()],
|
||||
last_hidden_state[i, assistant_crop_end[i].item():],
|
||||
]))
|
||||
text_attention_mask.append(
|
||||
torch.cat([
|
||||
attention_mask[
|
||||
i,
|
||||
crop_start:attention_mask_assistant_crop_start[i].item(),
|
||||
],
|
||||
attention_mask[i, attention_mask_assistant_crop_end[i].item():],
|
||||
]) if use_attention_mask else None)
|
||||
image_last_hidden_state.append(last_hidden_state[i, image_crop_start:image_crop_end])
|
||||
image_attention_mask.append(
|
||||
torch.ones(image_last_hidden_state[-1].shape[0]).to(last_hidden_state.device).
|
||||
to(attention_mask.dtype) if use_attention_mask else None)
|
||||
|
||||
text_last_hidden_state = torch.stack(text_last_hidden_state)
|
||||
text_attention_mask = torch.stack(text_attention_mask)
|
||||
image_last_hidden_state = torch.stack(image_last_hidden_state)
|
||||
image_attention_mask = torch.stack(image_attention_mask)
|
||||
|
||||
image_last_hidden_state = image_last_hidden_state[:, ::image_embed_interleave, :]
|
||||
image_attention_mask = image_attention_mask[:, ::image_embed_interleave]
|
||||
|
||||
assert (text_last_hidden_state.shape[0] == text_attention_mask.shape[0] and
|
||||
image_last_hidden_state.shape[0] == image_attention_mask.shape[0])
|
||||
|
||||
last_hidden_state = torch.cat([image_last_hidden_state, text_last_hidden_state], dim=1)
|
||||
attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
|
||||
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
images=None,
|
||||
positive=True,
|
||||
device="cuda",
|
||||
clip_sequence_length=77,
|
||||
@@ -136,8 +268,11 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
if images is None:
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
else:
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_mllm(prompt_formated, images, llm_sequence_length, device,
|
||||
crop_start, hidden_state_skip_layer, use_attention_mask)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
|
||||
Reference in New Issue
Block a user