hunyuanvideo text encoder offload

This commit is contained in:
Artiprocher
2024-12-18 19:35:04 +08:00
parent e5099f4e74
commit ec7ac20def
7 changed files with 150 additions and 21 deletions

View File

@@ -1,6 +1,7 @@
from .base_prompter import BasePrompter
from ..models.sd3_text_encoder import SD3TextEncoder1
from transformers import CLIPTokenizer, LlamaTokenizerFast, LlamaModel
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from transformers import CLIPTokenizer, LlamaTokenizerFast
import os, torch
PROMPT_TEMPLATE_ENCODE = (
@@ -50,12 +51,12 @@ class HunyuanVideoPrompter(BasePrompter):
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: LlamaModel = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: LlamaModel = None):
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
@@ -88,7 +89,6 @@ class HunyuanVideoPrompter(BasePrompter):
device,
crop_start,
hidden_state_skip_layer=2,
apply_final_norm=False,
use_attention_mask=True):
max_length += crop_start
inputs = self.tokenizer_2(prompt,
@@ -98,18 +98,8 @@ class HunyuanVideoPrompter(BasePrompter):
truncation=True)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
output_hidden_states = hidden_state_skip_layer is not None
outputs = self.text_encoder_2(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=output_hidden_states)
last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
if hidden_state_skip_layer is not None:
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
if hidden_state_skip_layer > 0 and apply_final_norm:
last_hidden_state = self.text_encoder_2.norm(last_hidden_state)
else:
last_hidden_state = outputs['last_hidden_state']
# crop out
if crop_start > 0:
last_hidden_state = last_hidden_state[:, crop_start:]
@@ -126,7 +116,6 @@ class HunyuanVideoPrompter(BasePrompter):
data_type='video',
use_template=True,
hidden_state_skip_layer=2,
apply_final_norm=False,
use_attention_mask=True):
prompt = self.process_prompt(prompt, positive=positive)
@@ -149,6 +138,6 @@ class HunyuanVideoPrompter(BasePrompter):
# LLM
prompt_emb, attention_mask = self.encode_prompt_using_llm(
prompt_formated, llm_sequence_length, device, crop_start,
hidden_state_skip_layer, apply_final_norm, use_attention_mask)
hidden_state_skip_layer, use_attention_mask)
return prompt_emb, pooled_prompt_emb, attention_mask