hunyuanvideo pipeline

This commit is contained in:
Artiprocher
2024-12-18 11:42:43 +08:00
parent 7a45b7efa7
commit b048f1b1de
5 changed files with 279 additions and 31 deletions

View File

@@ -70,12 +70,17 @@ class HunyuanVideoPrompter(BasePrompter):
raise TypeError(f"Unsupported prompt type: {type(text)}")
def encode_prompt_using_clip(self, prompt, max_length, device):
    """Encode ``prompt`` with the first tokenizer / text-encoder pair.

    The prompt is tokenized to exactly ``max_length`` tokens (padded or
    truncated as needed). Both the token ids and the tokenizer's attention
    mask are moved to ``device`` and handed to ``self.text_encoder_1``.

    Args:
        prompt: Text passed unchanged to ``self.tokenizer_1``.
        max_length: Token length used for both padding and truncation.
        device: Target device for the token ids and the attention mask.

    Returns:
        Element ``0`` of whatever ``self.text_encoder_1`` returns.
    """
    tokenized_result = self.tokenizer_1(
        prompt,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_attention_mask=True,
    )
    input_ids = tokenized_result.input_ids.to(device)
    attention_mask = tokenized_result.attention_mask.to(device)
    # The mask is forwarded under the keyword ``extra_mask``.
    # NOTE(review): presumably this lets the encoder ignore padding tokens;
    # confirm against the text encoder's forward() signature.
    return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
def encode_prompt_using_llm(self,
prompt,
@@ -110,7 +115,7 @@ class HunyuanVideoPrompter(BasePrompter):
last_hidden_state = last_hidden_state[:, crop_start:]
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
return last_hidden_state
return last_hidden_state, attention_mask
def encode_prompt(self,
prompt,
@@ -142,8 +147,8 @@ class HunyuanVideoPrompter(BasePrompter):
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
# LLM
prompt_emb = self.encode_prompt_using_llm(
prompt_emb, attention_mask = self.encode_prompt_using_llm(
prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, apply_final_norm, use_attention_mask)
return prompt_emb, pooled_prompt_emb
return prompt_emb, pooled_prompt_emb, attention_mask