from transformers import LlamaModel, LlamaConfig, DynamicCache, LlavaForConditionalGeneration
from copy import deepcopy
import torch


class HunyuanVideoLLMEncoder(LlamaModel):
    # LLaMA backbone used as the HunyuanVideo text encoder. Its forward pass
    # returns hidden states from an intermediate layer (a few layers before the
    # end) instead of the final layer output.

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        # When enabled, each submodule is deep-copied to the computation device
        # on demand, so the resident weights can stay off the GPU.
        self.auto_offload = True

    def forward(self, input_ids, attention_mask, hidden_state_skip_layer=2):
        # Token embedding (copied to the input device on demand when auto-offload is on).
        embed_tokens = deepcopy(self.embed_tokens).to(input_ids.device) if self.auto_offload else self.embed_tokens
        inputs_embeds = embed_tokens(input_ids)

        past_key_values = DynamicCache()

        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
        position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
        hidden_states = inputs_embeds

        # Create position embeddings to be shared across the decoder layers.
        rotary_emb = deepcopy(self.rotary_emb).to(input_ids.device) if self.auto_offload else self.rotary_emb
        position_embeddings = rotary_emb(hidden_states, position_ids)

        # Run the decoder layers, stopping `hidden_state_skip_layer` layers before the end.
        for layer_id, decoder_layer in enumerate(self.layers):
            if self.auto_offload:
                decoder_layer = deepcopy(decoder_layer).to(hidden_states.device)
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]
            if layer_id + hidden_state_skip_layer + 1 >= len(self.layers):
                break

        return hidden_states

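
# Illustrative usage sketch (not part of the original file): how the LLM
# encoder above might be driven. The checkpoint path, tokenizer, and prompt
# below are assumptions for illustration, not values from DiffSynth-Studio.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("path/to/llama-text-encoder")        # hypothetical path
#     encoder = HunyuanVideoLLMEncoder.from_pretrained("path/to/llama-text-encoder")
#     encoder.enable_auto_offload()  # copy each layer to the input device on demand
#     batch = tokenizer("a cat playing piano", return_tensors="pt")
#     prompt_emb = encoder(batch.input_ids, batch.attention_mask)  # intermediate hidden states
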
class HunyuanVideoMLLMEncoder(LlavaForConditionalGeneration):
    # Multimodal (image + text) encoder: a LLaVA model whose intermediate
    # hidden states are used as the prompt embedding.

    def __init__(self, config):
        super().__init__(config)
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        self.auto_offload = True

    # TODO: implement low-VRAM inference for the MLLM.
    def forward(self, input_ids, pixel_values, attention_mask, hidden_state_skip_layer=2):
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            pixel_values=pixel_values,
        )
        # Take the hidden states `hidden_state_skip_layer` layers before the last one.
        hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
        return hidden_state
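
# Illustrative usage sketch (not part of the original file): the multimodal
# encoder above is a plain LlavaForConditionalGeneration forward pass with
# output_hidden_states=True. The processor name, prompt, and image below are
# assumptions for illustration only.
#
#     from transformers import AutoProcessor
#     from PIL import Image
#     processor = AutoProcessor.from_pretrained("path/to/llava-mllm")                # hypothetical path
#     encoder = HunyuanVideoMLLMEncoder.from_pretrained("path/to/llava-mllm")
#     inputs = processor(text="<image> describe this image", images=Image.open("frame.png"), return_tensors="pt")
#     prompt_emb = encoder(inputs.input_ids, inputs.pixel_values, inputs.attention_mask)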