hunyuanvideo text encoder offload

Artiprocher
2024-12-18 19:35:04 +08:00
parent e5099f4e74
commit ec7ac20def
7 changed files with 150 additions and 21 deletions

View File

@@ -111,7 +111,7 @@ huggingface_model_loader_configs = [
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "hunyuan_video_text_encoder_2", "LlamaModel")
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
]
patch_model_loader_configs = [
# These configs are provided for detecting model type automatically.
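For reference, a minimal sketch of how a loader tuple like the one swapped above can be resolved at runtime. The field meanings (HF architecture name, module path, internal model name, class name) are inferred from the entries themselves, and resolve_model_class is a hypothetical helper, not DiffSynth's actual loader:

import importlib

# Hypothetical resolver; the real ModelManager logic may differ.
def resolve_model_class(config_tuple):
    architecture, module_path, model_name, class_name = config_tuple
    module = importlib.import_module(module_path)
    return model_name, getattr(module, class_name)

name, cls = resolve_model_class(
    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder",
     "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder"))
# With the new config, cls is HunyuanVideoLLMEncoder rather than the stock
# LlamaModel, so the offload-aware forward defined below is what actually runs.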

View File

@@ -0,0 +1,55 @@
+from transformers import LlamaModel, LlamaConfig, DynamicCache
+from copy import deepcopy
+import torch
+
+
+class HunyuanVideoLLMEncoder(LlamaModel):
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.auto_offload = False
+
+    def enable_auto_offload(self, **kwargs):
+        # kwargs (e.g. dtype/device passed by the pipeline) are accepted for
+        # interface compatibility but unused here; only the flag is set.
+        self.auto_offload = True
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        hidden_state_skip_layer=2,
+    ):
+        # With auto offload enabled, weights stay on the CPU and each submodule
+        # is deep-copied to the input's device just for its own forward pass.
+        embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
+        inputs_embeds = embed_tokens(input_ids)
+        past_key_values = DynamicCache()
+        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+        position_ids = cache_position.unsqueeze(0)
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
+        hidden_states = inputs_embeds
+        # create position embeddings to be shared across the decoder layers
+        rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
+        position_embeddings = rotary_emb(hidden_states, position_ids)
+        # decoder layers
+        for layer_id, decoder_layer in enumerate(self.layers):
+            if self.auto_offload:
+                decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=False,
+                use_cache=True,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+            hidden_states = layer_outputs[0]
+            # Early exit: only the hidden state of layer
+            # len(self.layers) - hidden_state_skip_layer - 1 is needed.
+            if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
+                break
+        return hidden_states
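A toy sketch of the offload pattern used in forward above, independent of the HunyuanVideo code: weights stay on the CPU and each layer is deep-copied to the compute device only for the duration of its forward pass, so peak VRAM stays near one layer plus activations. Sizes and device choice here are illustrative:

import torch
from copy import deepcopy

device = "cuda" if torch.cuda.is_available() else "cpu"
layers = torch.nn.ModuleList(torch.nn.Linear(1024, 1024) for _ in range(8))  # CPU-resident weights
x = torch.randn(1, 1024, device=device)
for layer in layers:
    layer_on_device = deepcopy(layer).to(device)  # transient on-device copy
    x = layer_on_device(x)  # the copy is freed once it goes out of scope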

View File

@@ -1,10 +1,10 @@
 from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
 from ..models.hunyuan_video_dit import HunyuanVideoDiT
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
 from ..schedulers.flow_match import FlowMatchScheduler
 from .base import BasePipeline
 from ..prompters import HunyuanVideoPrompter
 import torch
-from transformers import LlamaModel
 from einops import rearrange
 import numpy as np
 from PIL import Image
@@ -18,7 +18,7 @@ class HunyuanVideoPipeline(BasePipeline):
         self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
         self.prompter = HunyuanVideoPrompter()
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None
         self.dit: HunyuanVideoDiT = None
         self.vae_decoder: HunyuanVideoVAEDecoder = None
         self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
@@ -28,6 +28,7 @@ class HunyuanVideoPipeline(BasePipeline):
     def enable_vram_management(self):
         self.vram_management = True
         self.enable_cpu_offload()
+        self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
         self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
@@ -91,7 +92,7 @@ class HunyuanVideoPipeline(BasePipeline):
         latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)

         # Encode prompts
-        self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
+        self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
         prompt_emb_posi = self.encode_prompt(prompt, positive=True)
         if cfg_scale != 1.0:
             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
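A hedged end-to-end usage sketch: the import path and from_model_manager call follow the package's usual pipeline examples, but the checkpoint paths and call arguments here are placeholders, not taken from this diff:

import torch
from diffsynth import ModelManager, HunyuanVideoPipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")  # keep weights on CPU
model_manager.load_models([
    "models/HunyuanVideo/text_encoder/model.safetensors",   # placeholder paths
    "models/HunyuanVideo/text_encoder_2",
    "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt",
    "models/HunyuanVideo/vae/pytorch_model.pt",
])
pipe = HunyuanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")
pipe.enable_vram_management()  # text_encoder_2 and dit now stream their layers to the GPU on demand
video = pipe(prompt="a cat is running on the grass", num_frames=129, seed=0)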

View File

@@ -1,6 +1,7 @@
 from .base_prompter import BasePrompter
 from ..models.sd3_text_encoder import SD3TextEncoder1
-from transformers import CLIPTokenizer, LlamaTokenizerFast, LlamaModel
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast
 import os, torch
PROMPT_TEMPLATE_ENCODE = (
@@ -50,12 +51,12 @@ class HunyuanVideoPrompter(BasePrompter):
         self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
         self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None
         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']

-    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: LlamaModel = None):
+    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
         self.text_encoder_1 = text_encoder_1
         self.text_encoder_2 = text_encoder_2
@@ -88,7 +89,6 @@ class HunyuanVideoPrompter(BasePrompter):
                                  device,
                                  crop_start,
                                  hidden_state_skip_layer=2,
-                                 apply_final_norm=False,
                                  use_attention_mask=True):
         max_length += crop_start
         inputs = self.tokenizer_2(prompt,
@@ -98,18 +98,8 @@ class HunyuanVideoPrompter(BasePrompter):
                                   truncation=True)
         input_ids = inputs.input_ids.to(device)
         attention_mask = inputs.attention_mask.to(device)
-        output_hidden_states = hidden_state_skip_layer is not None
-        outputs = self.text_encoder_2(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            output_hidden_states=output_hidden_states)
+        last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
-        if hidden_state_skip_layer is not None:
-            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
-            if hidden_state_skip_layer > 0 and apply_final_norm:
-                last_hidden_state = self.text_encoder_2.norm(last_hidden_state)
-        else:
-            last_hidden_state = outputs['last_hidden_state']

         # crop out
         if crop_start > 0:
             last_hidden_state = last_hidden_state[:, crop_start:]
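The custom forward replaces the HF output_hidden_states path removed above. A quick indexing check (my own sanity check, not part of the diff) confirms the early exit selects the same tensor as the old hidden_states[-(hidden_state_skip_layer + 1)] while skipping the final layers entirely:

# With N decoder layers, outputs.hidden_states has N + 1 entries (the embeddings
# plus each layer's output), so hidden_states[-(skip + 1)] is the output of
# layer index N - skip - 1. HunyuanVideoLLMEncoder.forward breaks once
# layer_id + skip + 1 >= N, i.e. right after that same layer.
N, skip = 32, 2  # Llama-style depth and the default hidden_state_skip_layer
first_break = next(i for i in range(N) if i + skip + 1 >= N)
assert first_break == N - skip - 1 == 29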
@@ -126,7 +116,6 @@ class HunyuanVideoPrompter(BasePrompter):
                       data_type='video',
                       use_template=True,
                       hidden_state_skip_layer=2,
-                      apply_final_norm=False,
                       use_attention_mask=True):
         prompt = self.process_prompt(prompt, positive=positive)
@@ -149,6 +138,6 @@ class HunyuanVideoPrompter(BasePrompter):
         # LLM
         prompt_emb, attention_mask = self.encode_prompt_using_llm(
             prompt_formated, llm_sequence_length, device, crop_start,
-            hidden_state_skip_layer, apply_final_norm, use_attention_mask)
+            hidden_state_skip_layer, use_attention_mask)

         return prompt_emb, pooled_prompt_emb, attention_mask