hunyuanvideo text encoder offload

This commit is contained in:
Artiprocher
2024-12-18 19:35:04 +08:00
parent e5099f4e74
commit ec7ac20def
7 changed files with 150 additions and 21 deletions

View File

@@ -1,10 +1,10 @@
from ..models import ModelManager, SD3TextEncoder1, HunyuanVideoVAEDecoder
from ..models.hunyuan_video_dit import HunyuanVideoDiT
from ..models.hunyuan_video_text_encoder import HunyuanVideoLLMEncoder
from ..schedulers.flow_match import FlowMatchScheduler
from .base import BasePipeline
from ..prompters import HunyuanVideoPrompter
import torch
from transformers import LlamaModel
from einops import rearrange
import numpy as np
from PIL import Image
@@ -18,7 +18,7 @@ class HunyuanVideoPipeline(BasePipeline):
self.scheduler = FlowMatchScheduler(shift=7.0, sigma_min=0.0, extra_one_step=True)
self.prompter = HunyuanVideoPrompter()
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: LlamaModel = None
self.text_encoder_2: HunyuanVideoLLMEncoder = None
self.dit: HunyuanVideoDiT = None
self.vae_decoder: HunyuanVideoVAEDecoder = None
self.model_names = ['text_encoder_1', 'text_encoder_2', 'dit', 'vae_decoder']
@@ -28,6 +28,7 @@ class HunyuanVideoPipeline(BasePipeline):
def enable_vram_management(self):
self.vram_management = True
self.enable_cpu_offload()
self.text_encoder_2.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
self.dit.enable_auto_offload(dtype=self.torch_dtype, device=self.device)
@@ -91,7 +92,7 @@ class HunyuanVideoPipeline(BasePipeline):
latents = self.generate_noise((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), seed=seed, device=self.device, dtype=self.torch_dtype)
# Encode prompts
self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
self.load_models_to_device(["text_encoder_1"] if self.vram_management else ["text_encoder_1", "text_encoder_2"])
prompt_emb_posi = self.encode_prompt(prompt, positive=True)
if cfg_scale != 1.0:
prompt_emb_nega = self.encode_prompt(negative_prompt, positive=False)