Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-23 00:58:11 +00:00)

Commit: hunyuanvideo text encoder offload
@@ -1,6 +1,7 @@
 from .base_prompter import BasePrompter
 from ..models.sd3_text_encoder import SD3TextEncoder1
-from transformers import CLIPTokenizer, LlamaTokenizerFast, LlamaModel
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast
 import os, torch
 
 PROMPT_TEMPLATE_ENCODE = (
@@ -50,12 +51,12 @@ class HunyuanVideoPrompter(BasePrompter):
         self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
         self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None
 
         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
 
-    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: LlamaModel = None):
+    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
         self.text_encoder_1 = text_encoder_1
         self.text_encoder_2 = text_encoder_2
 
@@ -88,7 +89,6 @@ class HunyuanVideoPrompter(BasePrompter):
                                 device,
                                 crop_start,
                                 hidden_state_skip_layer=2,
-                                apply_final_norm=False,
                                 use_attention_mask=True):
         max_length += crop_start
         inputs = self.tokenizer_2(prompt,
@@ -98,18 +98,8 @@ class HunyuanVideoPrompter(BasePrompter):
                                   truncation=True)
         input_ids = inputs.input_ids.to(device)
         attention_mask = inputs.attention_mask.to(device)
-        output_hidden_states = hidden_state_skip_layer is not None
-        outputs = self.text_encoder_2(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            output_hidden_states=output_hidden_states)
+        last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
 
-        if hidden_state_skip_layer is not None:
-            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
-            if hidden_state_skip_layer > 0 and apply_final_norm:
-                last_hidden_state = self.text_encoder_2.norm(last_hidden_state)
-        else:
-            last_hidden_state = outputs['last_hidden_state']
         # crop out
         if crop_start > 0:
             last_hidden_state = last_hidden_state[:, crop_start:]
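The new single call self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer) implies the hidden-state selection deleted above now lives inside HunyuanVideoLLMEncoder itself, which is what lets the LLM be managed (and offloaded) independently of the prompter. A minimal sketch of that interface, assuming the wrapper holds a transformers LlamaModel; the real class is defined in ..models.hunyuan_video_text_encoder and is not part of this diff:

import torch
from transformers import LlamaModel

class HunyuanVideoLLMEncoder(torch.nn.Module):
    # Hypothetical wrapper: only the call signature is taken from this diff.
    def __init__(self, llm: LlamaModel):
        super().__init__()
        self.llm = llm

    def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
        # Request every layer's hidden state so an intermediate one can be picked.
        outputs = self.llm(input_ids=input_ids,
                           attention_mask=attention_mask,
                           output_hidden_states=True)
        # Same selection the prompter previously did inline: take the state
        # `hidden_state_skip_layer` layers back from the final output.
        return outputs.hidden_states[-(hidden_state_skip_layer + 1)]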
@@ -126,7 +116,6 @@ class HunyuanVideoPrompter(BasePrompter):
                       data_type='video',
                       use_template=True,
                       hidden_state_skip_layer=2,
-                      apply_final_norm=False,
                       use_attention_mask=True):
 
         prompt = self.process_prompt(prompt, positive=positive)
@@ -149,6 +138,6 @@ class HunyuanVideoPrompter(BasePrompter):
         # LLM
         prompt_emb, attention_mask = self.encode_prompt_using_llm(
             prompt_formated, llm_sequence_length, device, crop_start,
-            hidden_state_skip_layer, apply_final_norm, use_attention_mask)
+            hidden_state_skip_layer, use_attention_mask)
 
         return prompt_emb, pooled_prompt_emb, attention_mask
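Downstream wiring would then look roughly like the sketch below. Only the fetch_models signature comes from this diff; the checkpoint path, the HunyuanVideoLLMEncoder constructor, and the pre-existing prompter and text_encoder_1 objects are all assumptions.

from transformers import LlamaModel

# Hypothetical setup: `prompter` is an existing HunyuanVideoPrompter and
# `text_encoder_1` an existing SD3TextEncoder1; path and constructor are assumed.
llm = LlamaModel.from_pretrained("path/to/text_encoder_llm")
text_encoder_2 = HunyuanVideoLLMEncoder(llm)
prompter.fetch_models(text_encoder_1=text_encoder_1, text_encoder_2=text_encoder_2)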