This commit is contained in:
Artiprocher
2024-11-12 10:17:01 +08:00
parent a671070a28
commit d46b8b8fd7
6 changed files with 26 additions and 81 deletions

View File

@@ -1,5 +1,6 @@
from .base_prompter import BasePrompter
from ..models.flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from ..models.flux_text_encoder import FluxTextEncoder2
from ..models.sd3_text_encoder import SD3TextEncoder1
from transformers import CLIPTokenizer, T5TokenizerFast
import os, torch
@@ -19,11 +20,11 @@ class FluxPrompter(BasePrompter):
super().__init__()
self.tokenizer_1 = CLIPTokenizer.from_pretrained(tokenizer_1_path)
self.tokenizer_2 = T5TokenizerFast.from_pretrained(tokenizer_2_path)
self.text_encoder_1: FluxTextEncoder1 = None
self.text_encoder_1: SD3TextEncoder1 = None
self.text_encoder_2: FluxTextEncoder2 = None
def fetch_models(self, text_encoder_1: FluxTextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
def fetch_models(self, text_encoder_1: SD3TextEncoder1 = None, text_encoder_2: FluxTextEncoder2 = None):
self.text_encoder_1 = text_encoder_1
self.text_encoder_2 = text_encoder_2
@@ -36,7 +37,7 @@ class FluxPrompter(BasePrompter):
max_length=max_length,
truncation=True
).input_ids.to(device)
_, pooled_prompt_emb = text_encoder(input_ids)
pooled_prompt_emb, _ = text_encoder(input_ids)
return pooled_prompt_emb