diff --git a/diffsynth/configs/model_config.py b/diffsynth/configs/model_config.py
index b4f08ba..bc2b675 100644
--- a/diffsynth/configs/model_config.py
+++ b/diffsynth/configs/model_config.py
@@ -111,7 +111,7 @@ huggingface_model_loader_configs = [
     ("T5EncoderModel", "diffsynth.models.flux_text_encoder", "flux_text_encoder_2", "FluxTextEncoder2"),
     ("CogVideoXTransformer3DModel", "diffsynth.models.cog_dit", "cog_dit", "CogDiT"),
     ("SiglipModel", "transformers.models.siglip.modeling_siglip", "siglip_vision_model", "SiglipVisionModel"),
-    ("LlamaForCausalLM", "transformers.models.llama.modeling_llama", "hunyuan_video_text_encoder_2", "LlamaModel")
+    ("LlamaForCausalLM", "diffsynth.models.hunyuan_video_text_encoder", "hunyuan_video_text_encoder_2", "HunyuanVideoLLMEncoder")
 ]
 patch_model_loader_configs = [
     # These configs are provided for detecting model type automatically.
diff --git a/diffsynth/models/hunyuan_video_text_encoder.py b/diffsynth/models/hunyuan_video_text_encoder.py
new file mode 100644
index 0000000..df5755f
--- /dev/null
+++ b/diffsynth/models/hunyuan_video_text_encoder.py
@@ -0,0 +1,84 @@
+from transformers import LlamaModel, LlamaConfig, DynamicCache
+from copy import deepcopy
+import torch
+
+
+class HunyuanVideoLLMEncoder(LlamaModel):
+    """LlamaModel variant that returns the hidden state a few decoder layers before the end.
+
+    With auto offload enabled, every sub-module is deep-copied and the copy is moved to the
+    device of `input_ids` right before it is used, so the module's own weights are not moved.
+    """
+
+    def __init__(self, config: LlamaConfig):
+        super().__init__(config)
+        self.auto_offload = False
+
+
+    def enable_auto_offload(self, **kwargs):
+        """Turn on per-module offloading; keyword arguments are accepted but not used."""
+        self.auto_offload = True
+
+
+    def _on_device(self, module, device):
+        """Return `module`, or a temporary copy of it on `device` when auto offload is on."""
+        return deepcopy(module).to(device) if self.auto_offload else module
+
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        hidden_state_skip_layer=2
+    ):
+        """Encode tokens and return hidden states from before the last skipped layers.
+
+        Args:
+            input_ids: token ids; their device decides where offloaded copies are placed.
+            attention_mask: attention mask forwarded to `_update_causal_mask`.
+            hidden_state_skip_layer: number of trailing decoder layers to skip. `None`
+                or 0 runs every layer and applies the final norm, which is what
+                `LlamaModel` returns as its last hidden state.
+
+        Returns:
+            The hidden states produced by the last decoder layer that was run.
+        """
+        device = input_ids.device
+        inputs_embeds = self._on_device(self.embed_tokens, device)(input_ids)
+
+        past_key_values = DynamicCache()
+
+        cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+        position_ids = cache_position.unsqueeze(0)
+
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, False)
+        hidden_states = inputs_embeds
+
+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self._on_device(self.rotary_emb, device)(hidden_states, position_ids)
+
+        # None and 0 both select the model's final output, which LlamaModel normalizes.
+        use_final_output = not hidden_state_skip_layer
+        num_skipped = hidden_state_skip_layer or 0
+
+        # decoder layers
+        for layer_id, decoder_layer in enumerate(self.layers):
+            decoder_layer = self._on_device(decoder_layer, hidden_states.device)
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=False,
+                use_cache=True,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+            )
+            hidden_states = layer_outputs[0]
+            # stop early: the remaining layers would only produce states that get discarded
+            if layer_id + num_skipped + 1 >= len(self.layers):
+                break
+
+        if use_final_output:
+            hidden_states = self._on_device(self.norm, hidden_states.device)(hidden_states)
+        return hidden_states
diff --git a/diffsynth/pipelines/hunyuan_video.py b/diffsynth/pipelines/hunyuan_video.py
index 7a8b297..20778c5 100644
--- a/diffsynth/pipelines/hunyuan_video.py
+++ b/diffsynth/pipelines/hunyuan_video.py
@@ -1,10 +1,10 @@
 from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
 from ..models.hunyuan_video_dit import HunyuanVideoDiT
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
 from ..schedulers.flow_match import FlowMatchScheduler
 from .base import BasePipeline
 from ..prompters import HunyuanVideoPrompter
 import torch
-from transformers import LlamaModel
 from einops import rearrange
 import numpy as np
 from PIL import Image
@@ -18,7 +18,7 @@ class HunyuanVideoPipeline(BasePipeline):
         self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
         self.prompter = HunyuanVideoPrompter()
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None
         self.dit: HunyuanVideoDiT = None
         self.vae_decoder: HunyuanVideoVAEDecoder = None
         self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
@@ -28,6 +28,7 @@ class HunyuanVideoPipeline(BasePipeline):
     def enable_vram_management(self):
         self.vram_management = True
         self.enable_cpu_offload()
+        self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
         self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
 
 
@@ -91,7 +92,7 @@ class HunyuanVideoPipeline(BasePipeline):
         latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
 
         # Encode prompts
-        self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
+        self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
         prompt_emb_posi = self.encode_prompt(prompt, positive=True)
         if cfg_scale != 1.0:
             prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)
diff --git a/diffsynth/prompters/hunyuan_video_prompter.py b/diffsynth/prompters/hunyuan_video_prompter.py
index 70b035a..3b5a9fe 100644
--- a/diffsynth/prompters/hunyuan_video_prompter.py
+++ b/diffsynth/prompters/hunyuan_video_prompter.py
@@ -1,6 +1,7 @@
 from .base_prompter import BasePrompter
 from ..models.sd3_text_encoder import SD3TextEncoder1
-from transformers import CLIPTokenizer, LlamaTokenizerFast, LlamaModel
+from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
+from transformers import CLIPTokenizer, LlamaTokenizerFast
 import os, torch
 
 PROMPT_TEMPLATE_ENCODE = (
@@ -50,12 +51,12 @@ class HunyuanVideoPrompter(BasePrompter):
         self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
         self.tokenizer_2 = LlamaTokenizerFast.from_pretrained(tokenizer_2_path, padding_side='right')
         self.text_encoder_1: SD3TextEncoder1 = None
-        self.text_encoder_2: LlamaModel = None
+        self.text_encoder_2: HunyuanVideoLLMEncoder = None
 
         self.prompt_template = PROMPT_TEMPLATE['dit-llm-encode']
         self.prompt_template_video = PROMPT_TEMPLATE['dit-llm-encode-video']
 
-    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: LlamaModel = None):
+    def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: HunyuanVideoLLMEncoder = None):
         self.text_encoder_1 = text_encoder_1
         self.text_encoder_2 = text_encoder_2
 
@@ -88,7 +89,6 @@ class HunyuanVideoPrompter(BasePrompter):
                                 device,
                                 crop_start,
                                 hidden_state_skip_layer=2,
-                                apply_final_norm=False,
                                 use_attention_mask=True):
         max_length += crop_start
         inputs = self.tokenizer_2(prompt,
@@ -98,18 +98,8 @@ class HunyuanVideoPrompter(BasePrompter):
                                   truncation=True)
         input_ids = inputs.input_ids.to(device)
         attention_mask = inputs.attention_mask.to(device)
-        output_hidden_states = hidden_state_skip_layer is not None
-        outputs = self.text_encoder_2(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            output_hidden_states=output_hidden_states)
+        last_hidden_state = self.text_encoder_2(input_ids, attention_mask, hidden_state_skip_layer)
 
-        if hidden_state_skip_layer is not None:
-            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
-            if hidden_state_skip_layer > 0 and apply_final_norm:
-                last_hidden_state = self.text_encoder_2.norm(last_hidden_state)
-        else:
-            last_hidden_state = outputs['last_hidden_state']
         # crop out
         if crop_start > 0:
             last_hidden_state = last_hidden_state[:, crop_start:]
@@ -126,7 +116,6 @@ class HunyuanVideoPrompter(BasePrompter):
                       data_type='video',
                       use_template=True,
                       hidden_state_skip_layer=2,
-                      apply_final_norm=False,
                       use_attention_mask=True):
 
         prompt = self.process_prompt(prompt, positive=positive)
@@ -149,6 +138,6 @@ class HunyuanVideoPrompter(BasePrompter):
         # LLM
         prompt_emb, attention_mask = self.encode_prompt_using_llm(
             prompt_formated, llm_sequence_length, device, crop_start,
-            hidden_state_skip_layer, apply_final_norm, use_attention_mask)
+            hidden_state_skip_layer, use_attention_mask)
 
         return prompt_emb, pooled_prompt_emb, attention_mask
diff --git a/examples/HunyuanVideo/hunyuanvideo_16G.py b/examples/HunyuanVideo/hunyuanvideo_16G.py
new file mode 100644
index 0000000..860d575
--- /dev/null
+++ b/examples/HunyuanVideo/hunyuanvideo_16G.py
@@ -0,0 +1,42 @@
+import torch
+torch.cuda.set_per_process_memory_fraction(1.0, 0)
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+
+
+download_models(["HunyuanVideo"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16,
+    device="cpu"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/text_encoder/model.safetensors",
+        "models/HunyuanVideo/text_encoder_2",
+        "models/HunyuanVideo/vae/pytorch_model.pt",
+    ],
+    torch_dtype=torch.float16,
+    device="cpu"
+)
+
+# We support LoRA inference. You can use the following code to load your LoRA model.
+# model_manager.load_lora("models/lora/xxx.safetensors", lora_alpha=1.0)
+
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(
+    model_manager,
+    torch_dtype=torch.bfloat16,
+    device="cuda"
+)
+
+# Enjoy!
+prompt = "CG, masterpiece, best quality, solo, long hair, wavy hair, silver hair, blue eyes, blue dress, medium breasts, dress, underwater, air bubble, floating hair, refraction, portrait. The girl's flowing silver hair shimmers with every color of the rainbow and cascades down, merging with the floating flora around her."
+video = pipe(prompt, seed=0, height=720, width=960)
+save_video(video, "video.mp4", fps=30, quality=5)
diff --git a/examples/video_synthesis/hunyuanvideo.py b/examples/HunyuanVideo/hunyuanvideo_24G.py
similarity index 100%
rename from examples/video_synthesis/hunyuanvideo.py
rename to examples/HunyuanVideo/hunyuanvideo_24G.py
diff --git a/examples/HunyuanVideo/hunyuanvideo_8G.py b/examples/HunyuanVideo/hunyuanvideo_8G.py
new file mode 100644
index 0000000..336034b
--- /dev/null
+++ b/examples/HunyuanVideo/hunyuanvideo_8G.py
@@ -0,0 +1,42 @@
+import torch
+torch.cuda.set_per_process_memory_fraction(1.0, 0)
+from diffsynth import ModelManager, HunyuanVideoPipeline, download_models, save_video
+
+
+download_models(["HunyuanVideo"])
+model_manager = ModelManager()
+
+# The DiT model is loaded in bfloat16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/transformers/mp_rank_00_model_states.pt"
+    ],
+    torch_dtype=torch.bfloat16,
+    device="cpu"
+)
+
+# The other modules are loaded in float16.
+model_manager.load_models(
+    [
+        "models/HunyuanVideo/text_encoder/model.safetensors",
+        "models/HunyuanVideo/text_encoder_2",
+        "models/HunyuanVideo/vae/pytorch_model.pt",
+    ],
+    torch_dtype=torch.float16,
+    device="cpu"
+)
+
+# We support LoRA inference. You can use the following code to load your LoRA model.
+model_manager.load_lora("models/lora/Rem_hunyuan_video_v3.safetensors", lora_alpha=1.0)
+
+# The computation device is "cuda".
+pipe = HunyuanVideoPipeline.from_model_manager(
+    model_manager,
+    torch_dtype=torch.bfloat16,
+    device="cuda"
+)
+
+# Enjoy!
+prompt = "a woman with blue hair wearing a white and black dress, sitting on a bed with a white wall in the background. she is wearing a re:zero starting life in another world rem cosplay costume, complete with a black and white dress, black gloves, and a black bow tie."
+video = pipe(prompt, seed=0, height=512, width=512, tile_size=(17, 16, 16), tile_stride=(12, 12, 12))
+save_video(video, "video.mp4", fps=30, quality=5)