Mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-19 06:48:12 +00:00
hunyuanvideo text encoder offload
diffsynth/configs/model_config.py
@@ -111,7 +111,7 @@ huggingface_model_loader_configs = [
     ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
     ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
     ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
-    ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "hunyuan_video_text_encoder_2", "LlamaModel")
+    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
 ]
 patch_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
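Note: each entry in huggingface_model_loader_configs appears to map a Hugging Face architecture name to the module, internal model name, and class that DiffSynth should instantiate for it, so the one-line change above reroutes LlamaForCausalLM checkpoints from the stock transformers LlamaModel to the offload-aware HunyuanVideoLLMEncoder added below. A minimal sketch of how such a tuple can drive dynamic loading (load_model_class and the stdlib stand-in entry are illustrative, not part of the repository):

import importlib

def load_model_class(config_entry):
    # (architecture, module_path, model_name, class_name) -> class object
    architecture, module_path, model_name, class_name = config_entry
    module = importlib.import_module(module_path)
    return getattr(module, class_name)

# Stdlib stand-in entry so the sketch runs anywhere:
print(load_model_class(("JSONDecoder", "json.decoder", "json_decoder", "JSONDecoder")))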
diffsynth/models/hunyuan_video_text_encoder.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from transformers import LlamaModel, LlamaConfig, DynamicCache
+from copy import deepcopy
+import torch
+
+
+class HunyuanVideoLLMEncoder(LlamaModel):
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.auto_offload = False
+
+
+    def enable_auto_offload(self, **kwargs):
+        self.auto_offload = True
+
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        hidden_state_skip_layer=2
+    ):
+        embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
+        inputs_embeds = embed_tokens(input_ids)
+
+        past_key_values = DynamicCache()
+
+        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+        position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
+        position_embeddings = rotary_emb(hidden_states, position_ids)
+
+        # decoder layers
+        for layer_id, decoder_layer in enumerate(self.layers):
+            if self.auto_offload:
+                decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=False,
+                use_cache=True,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+            hidden_states = layer_outputs[0]
+            if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
+                break
+
+        return hidden_states
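The class above reuses LlamaModel's weights and config but rewrites forward: when auto_offload is on, embed_tokens, rotary_emb, and each decoder layer are deep-copied onto the input's device only for the duration of their own forward pass, so peak VRAM stays near a single decoder layer plus activations while the master weights stay where they were loaded (typically CPU). The loop also exits early once the remaining hidden_state_skip_layer layers can be skipped. A minimal, self-contained sketch of the same per-layer streaming pattern (offloaded_forward and the toy Linear stack are illustrative only):

import torch
from copy import deepcopy

def offloaded_forward(layers, hidden_states):
    # Copy each layer to the activation device just for its own forward pass;
    # the master weights never move, so peak device memory stays near a
    # single layer's parameters.
    for layer in layers:
        layer_on_device = deepcopy(layer).to(hidden_states.device)
        hidden_states = layer_on_device(hidden_states)
        del layer_on_device  # the temporary copy is freed immediately
    return hidden_states

layers = torch.nn.ModuleList([torch.nn.Linear(64, 64) for _ in range(4)])  # e.g. kept on CPU
x = torch.randn(2, 64)  # move to "cuda" to exercise a real offload
print(offloaded_forward(layers, x).shape)  # torch.Size([2, 64])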
diffsynth/pipelines/hunyuan_video.py
@@ -1,10 +1,10 @@
 from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
 from ..models.hunyuan_video_dit import HunyuanVideoDiT
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
 from ..schedulers.flow_match import FlowMatchScheduler
 from .base import BasePipeline
 from ..prompters import HunyuanVideoPrompter
 import torch
-from transformers import LlamaModel
 from einops import rearrange
 import numpy as np
 from PIL import Image
@@ -18,7 +18,7 @@ class HunyuanVideoPipeline(BasePipeline):
         self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
         self.prompter = HunyuanVideoPrompter()
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None
         self.dit: HunyuanVideoDiT = None
         self.vae_decoder: HunyuanVideoVAEDecoder = None
         self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
@@ -28,6 +28,7 @@ class HunyuanVideoPipeline(BasePipeline):
     def enable_vram_management(self):
         self.vram_management = True
         self.enable_cpu_offload()
+        self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
         self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
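With the added call, enable_vram_management() now streams both the DiT and the LLM text encoder layer by layer instead of holding them in VRAM. A plausible usage sketch; the checkpoint paths are placeholders and the loading calls follow DiffSynth-Studio's usual pipeline API, so treat the exact signatures as assumptions:

import torch
from diffsynth import ModelManager, HunyuanVideoPipeline

model_manager = ModelManager(torch_dtype=torch.float16, device="cpu")  # keep weights on CPU
model_manager.load_models([
    "path/to/hunyuan_video_dit",          # placeholder paths, not real files
    "path/to/text_encoder_1",
    "path/to/text_encoder_2",
    "path/to/hunyuan_video_vae_decoder",
])
pipe = HunyuanVideoPipeline.from_model_manager(model_manager, device="cuda")
pipe.enable_vram_management()  # CPU offload plus per-layer streaming for DiT and LLM encoder
video = pipe(prompt="a cat is running on the grass", num_frames=65, seed=0)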
@@ -91,7 +92,7 @@ class HunyuanVideoPipeline(BasePipeline):
         latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)

         # Encode prompts
-        self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
+        self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
         prompt_emb_posi = self.encode_prompt(prompt, positive=True)
         if cfg_scale != 1.0:
             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
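The conditional deliberately leaves text_encoder_2 off the device list when VRAM management is active: HunyuanVideoLLMEncoder.forward now handles its own placement layer by layer, and moving the whole multi-billion-parameter LLM to the GPU first would defeat the offload. A toy reproduction of the decision:

def models_to_load(vram_management):
    # text_encoder_2 streams its own layers when vram_management is on
    return ["text_encoder_1"] if vram_management else ["text_encoder_1", "text_encoder_2"]

assert models_to_load(True) == ["text_encoder_1"]
assert models_to_load(False) == ["text_encoder_1", "text_encoder_2"]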
diffsynth/prompters/hunyuan_video_prompter.py
@@ -1,6 +1,7 @@
 from .base_prompter import BasePrompter
 from ..models.sd3_text_encoder import SD3TextEncoder1
-from transformers import CLIPTokenizer, LlamaTokenizerFast, LlamaModel
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast
 import os, torch

 PROMPT_TEMPLATE_ENCODE = (
@@ -50,12 +51,12 @@ class HunyuanVideoPrompter(BasePrompter):
         self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
         self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None

         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']

-    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: LlamaModel = None):
+    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
         self.text_encoder_1 = text_encoder_1
         self.text_encoder_2 = text_encoder_2
@@ -88,7 +89,6 @@ class HunyuanVideoPrompter(BasePrompter):
                                 device,
                                 crop_start,
                                 hidden_state_skip_layer=2,
-                                apply_final_norm=False,
                                 use_attention_mask=True):
         max_length += crop_start
         inputs = self.tokenizer_2(prompt,
@@ -98,18 +98,8 @@ class HunyuanVideoPrompter(BasePrompter):
                                   truncation=True)
         input_ids = inputs.input_ids.to(device)
         attention_mask = inputs.attention_mask.to(device)
-        output_hidden_states = hidden_state_skip_layer is not None
-        outputs = self.text_encoder_2(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            output_hidden_states=output_hidden_states)
+        last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)

-        if hidden_state_skip_layer is not None:
-            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
-            if hidden_state_skip_layer > 0 and apply_final_norm:
-                last_hidden_state = self.text_encoder_2.norm(last_hidden_state)
-        else:
-            last_hidden_state = outputs['last_hidden_state']
         # crop out
         if crop_start > 0:
             last_hidden_state = last_hidden_state[:, crop_start:]
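The replacement folds the old output_hidden_states bookkeeping into the encoder's early exit, and drops the apply_final_norm branch (the parameter defaulted to False along the whole call chain). The two paths select the same layer: transformers returns num_layers + 1 hidden states with the embedding output first, so hidden_states[-(skip + 1)] is the output of layer num_layers - skip - 1, which is exactly where the new forward breaks. A quick arithmetic check (32 layers matches an 8B Llama encoder; the constants are illustrative):

num_layers = 32   # decoder layers in the text encoder
skip = 2          # hidden_state_skip_layer

# Old path: hidden_states has num_layers + 1 entries (index 0 = embeddings),
# so hidden_states[-(skip + 1)] is the output of this decoder layer:
old_layer = (num_layers + 1) - (skip + 1) - 1   # -> 29

# New path: forward breaks after the first layer_id with
# layer_id + skip + 1 >= num_layers:
new_layer = num_layers - skip - 1               # -> 29

assert old_layer == new_layer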
@@ -126,7 +116,6 @@ class HunyuanVideoPrompter(BasePrompter):
                       data_type='video',
                       use_template=True,
                       hidden_state_skip_layer=2,
-                      apply_final_norm=False,
                       use_attention_mask=True):

         prompt = self.process_prompt(prompt, positive=positive)
@@ -149,6 +138,6 @@ class HunyuanVideoPrompter(BasePrompter):
         # LLM
         prompt_emb, attention_mask = self.encode_prompt_using_llm(
             prompt_formated, llm_sequence_length, device, crop_start,
-            hidden_state_skip_layer, apply_final_norm, use_attention_mask)
+            hidden_state_skip_layer, use_attention_mask)

         return prompt_emb, pooled_prompt_emb, attention_mask