mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
hunyuanvideo text encoder offload
This commit is contained in:
@@ -111,7 +111,7 @@ huggingface_model_loader_configs = [
|
||||
("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
|
||||
("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
|
||||
("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
|
||||
("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "hunyuan_video_text_encoder_2", "LlamaModel")
|
||||
("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
|
||||
]
|
||||
patch_model_loader_configs = [
|
||||
# These configs are provided for detecting model type automatically.
|
||||
|
||||
55
diffsynth/models/hunyuan_video_text_encoder.py
Normal file
55
diffsynth/models/hunyuan_video_text_encoder.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from transformers import LlamaModel, LlamaConfig, DynamicCache
|
||||
from copy import deepcopy
|
||||
import torch
|
||||
|
||||
|
||||
class HunyuanVideoLLMEncoder(LlamaModel):
    """Llama-based text encoder for HunyuanVideo with optional module-wise offload.

    When ``auto_offload`` is enabled, each submodule (token embedding, rotary
    embedding, every decoder layer) is deep-copied onto the input tensor's
    device immediately before it is used, so the whole model never needs to
    reside on the GPU at once. With offload disabled, the modules are used
    in place.
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        # Offloading is opt-in; see enable_auto_offload().
        self.auto_offload = False

    def enable_auto_offload(self, **kwargs):
        """Enable per-module offloading.

        Extra keyword arguments (e.g. dtype/device passed by the pipeline)
        are accepted for interface compatibility and ignored here.
        """
        self.auto_offload = True

    def forward(
        self,
        input_ids,
        attention_mask,
        hidden_state_skip_layer=2
    ):
        """Encode ``input_ids`` and return the hidden states taken
        ``hidden_state_skip_layer`` decoder layers before the last one
        (equivalent to ``hidden_states[-(skip + 1)]`` of a full forward pass).
        """
        # Token embedding — copied to the compute device only when offloading.
        if self.auto_offload:
            embed = deepcopy(self.embed_tokens).to(input_ids.device)
        else:
            embed = self.embed_tokens
        hidden_states = embed(input_ids)

        past_key_values = DynamicCache()

        seq_len = hidden_states.shape[1]
        cache_position = torch.arange(0, seq_len, device=hidden_states.device)
        position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position, None, False)

        # Rotary position embeddings, shared by all decoder layers.
        if self.auto_offload:
            rope = deepcopy(self.rotary_emb).to(input_ids.device)
        else:
            rope = self.rotary_emb
        position_embeddings = rope(hidden_states, position_ids)

        num_layers = len(self.layers)
        for layer_id, layer in enumerate(self.layers):
            if self.auto_offload:
                # Materialize a temporary on-device copy of just this layer.
                layer = deepcopy(layer).to(hidden_states.device)
            hidden_states = layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=False,
                use_cache=True,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )[0]
            # Early exit: the final `hidden_state_skip_layer` layers are never run.
            if layer_id + hidden_state_skip_layer + 1 >= num_layers:
                break

        return hidden_states
|
||||
@@ -1,10 +1,10 @@
|
||||
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
|
||||
from ..models.hunyuan_video_dit import HunyuanVideoDiT
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from ..schedulers.flow_match import FlowMatchScheduler
|
||||
from .base import BasePipeline
|
||||
from ..prompters import HunyuanVideoPrompter
|
||||
import torch
|
||||
from transformers import LlamaModel
|
||||
from einops import rearrange
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -18,7 +18,7 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
|
||||
self.prompter = HunyuanVideoPrompter()
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: LlamaModel = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
self.dit: HunyuanVideoDiT = None
|
||||
self.vae_decoder: HunyuanVideoVAEDecoder = None
|
||||
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
|
||||
@@ -28,6 +28,7 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
def enable_vram_management(self):
|
||||
self.vram_management = True
|
||||
self.enable_cpu_offload()
|
||||
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
|
||||
|
||||
|
||||
@@ -91,7 +92,7 @@ class HunyuanVideoPipeline(BasePipeline):
|
||||
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
|
||||
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
|
||||
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
|
||||
if cfg_scale != 1.0:
|
||||
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .base_prompter import BasePrompter
|
||||
from ..models.sd3_text_encoder import SD3TextEncoder1
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast, LlamaModel
|
||||
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
|
||||
from transformers import CLIPTokenizer, LlamaTokenizerFast
|
||||
import os, torch
|
||||
|
||||
PROMPT_TEMPLATE_ENCODE = (
|
||||
@@ -50,12 +51,12 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
|
||||
self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
|
||||
self.text_encoder_1: SD3TextEncoder1 = None
|
||||
self.text_encoder_2: LlamaModel = None
|
||||
self.text_encoder_2: HunyuanVideoLLMEncoder = None
|
||||
|
||||
self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
|
||||
self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
|
||||
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: LlamaModel = None):
|
||||
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
|
||||
self.text_encoder_1 = text_encoder_1
|
||||
self.text_encoder_2 = text_encoder_2
|
||||
|
||||
@@ -88,7 +89,6 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
device,
|
||||
crop_start,
|
||||
hidden_state_skip_layer=2,
|
||||
apply_final_norm=False,
|
||||
use_attention_mask=True):
|
||||
max_length += crop_start
|
||||
inputs = self.tokenizer_2(prompt,
|
||||
@@ -98,18 +98,8 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
truncation=True)
|
||||
input_ids = inputs.input_ids.to(device)
|
||||
attention_mask = inputs.attention_mask.to(device)
|
||||
output_hidden_states = hidden_state_skip_layer is not None
|
||||
outputs = self.text_encoder_2(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=output_hidden_states)
|
||||
last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
|
||||
|
||||
if hidden_state_skip_layer is not None:
|
||||
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
|
||||
if hidden_state_skip_layer > 0 and apply_final_norm:
|
||||
last_hidden_state = self.text_encoder_2.norm(last_hidden_state)
|
||||
else:
|
||||
last_hidden_state = outputs['last_hidden_state']
|
||||
# crop out
|
||||
if crop_start > 0:
|
||||
last_hidden_state = last_hidden_state[:, crop_start:]
|
||||
@@ -126,7 +116,6 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
data_type='video',
|
||||
use_template=True,
|
||||
hidden_state_skip_layer=2,
|
||||
apply_final_norm=False,
|
||||
use_attention_mask=True):
|
||||
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
@@ -149,6 +138,6 @@ class HunyuanVideoPrompter(BasePrompter):
|
||||
# LLM
|
||||
prompt_emb, attention_mask = self.encode_prompt_using_llm(
|
||||
prompt_formated, llm_sequence_length, device, crop_start,
|
||||
hidden_state_skip_layer, apply_final_norm, use_attention_mask)
|
||||
hidden_state_skip_layer, use_attention_mask)
|
||||
|
||||
return prompt_emb, pooled_prompt_emb, attention_mask
|
||||
|
||||
42
examples/HunyuanVideo/hunyuanvideo_16G.py
Normal file
42
examples/HunyuanVideo/hunyuanvideo_16G.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Example: generate a video with HunyuanVideo on a ~16 GB GPU.

Weights are kept on the CPU and streamed to the GPU on demand by the
pipeline's VRAM management, so the full model never has to fit in VRAM.
"""
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video


download_models(["HunyuanVideo"])
manager = ModelManager()

# The DiT weights are loaded in bfloat16.
manager.load_models(
    ["models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"],
    torch_dtype=torch.bfloat16,
    device="cpu",
)

# The remaining modules (text encoders, VAE) are loaded in float16.
manager.load_models(
    [
        "models/HunyuanVideo/text_encoder/model.safetensors",
        "models/HunyuanVideo/text_encoder_2",
        "models/HunyuanVideo/vae/pytorch_model.pt",
    ],
    torch_dtype=torch.float16,
    device="cpu",
)

# LoRA inference is supported — point the call below at your own LoRA file:
# manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)

# Computation runs on "cuda"; modules are moved there only when needed.
pipeline = HunyuanVideoPipeline.from_model_manager(
    manager,
    torch_dtype=torch.bfloat16,
    device="cuda",
)

# Enjoy!
prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
video = pipeline(prompt, seed=0, height=720, width=960)
save_video(video, "video.mp4", fps=30, quality=5)
|
||||
42
examples/HunyuanVideo/hunyuanvideo_8G.py
Normal file
42
examples/HunyuanVideo/hunyuanvideo_8G.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Example: generate a video with HunyuanVideo on a ~8 GB GPU.

Weights stay on the CPU and are streamed to the GPU on demand; tiled VAE
decoding keeps the decode step within the small VRAM budget.
"""
import os
import torch
torch.cuda.set_per_process_memory_fraction(1.0, 0)
from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video


download_models(["HunyuanVideo"])
model_manager = ModelManager()

# The DiT model is loaded in bfloat16.
model_manager.load_models(
    [
        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
    ],
    torch_dtype=torch.bfloat16,
    device="cpu"
)

# The other modules are loaded in float16.
model_manager.load_models(
    [
        "models/HunyuanVideo/text_encoder/model.safetensors",
        "models/HunyuanVideo/text_encoder_2",
        "models/HunyuanVideo/vae/pytorch_model.pt",
    ],
    torch_dtype=torch.float16,
    device="cpu"
)

# We support LoRA inference. Replace the path below with your own LoRA model.
# The existence check keeps the example runnable for users who don't have
# this particular LoRA file on disk (it previously crashed on load).
lora_path = "models/lora/Rem_hunyuan_video_v3.safetensors"
if os.path.exists(lora_path):
    model_manager.load_lora(lora_path, lora_alpha=1.0)

# The computation device is "cuda".
pipe = HunyuanVideoPipeline.from_model_manager(
    model_manager,
    torch_dtype=torch.bfloat16,
    device="cuda"
)

# Enjoy!
prompt = "a woman with blue hair wearing a white and black dress, sitting on a bed with a white wall in the background. she is wearing a re:zero starting life in another world rem cosplay costume, complete with a black and white dress, black gloves, and a black bow tie."
video = pipe(prompt, seed=0, height=512, width=512, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
save_video(video, "video.mp4", fps=30, quality=5)
|
||||
Reference in New Issue
Block a user