mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
hunyuanvideo pipeline
This commit is contained in:
@@ -70,12 +70,17 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
raise TypeError(f"Unsupported prompt type: {type(text)}")
|
||||
|
||||
def encode_prompt_using_clip(self, prompt, max_length, device):
|
||||
input_ids = self.tokenizer_1(prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True).input_ids.to(device)
|
||||
return self.text_encoder_1(input_ids=input_ids)[0]
|
||||
tokenized_result = self.tokenizer_1(
|
||||
prompt,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True
|
||||
)
|
||||
input_ids = tokenized_result.input_ids.to(device)
|
||||
attention_mask = tokenized_result.attention_mask.to(device)
|
||||
return self.text_encoder_1(input_ids=input_ids, extra_mask=attention_mask)[0]
|
||||
|
||||
def encode_prompt_using_llm(self,
|
||||
prompt,
|
||||
@@ -110,7 +115,7 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
last_hidden_state = last_hidden_state[:, crop_start:]
|
||||
attention_mask = (attention_mask[:, crop_start:] if use_attention_mask else None)
|
||||
|
||||
return last_hidden_state
|
||||
return last_hidden_state, attention_mask
|
||||
|
||||
def encode_prompt(self,
|
||||
prompt,
|
||||
@@ -142,8 +147,8 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
pooled_prompt_emb = self.encode_prompt_using_clip(prompt, clip_sequence_length, device)
|
||||
|
||||
# LLM
|
||||
prompt_emb = self.encode_prompt_using_llm(
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, apply_final_norm, use_attention_mask)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
|
||||
Reference in New Issue
Block a user