mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 00:38:11 +00:00
support kolors! (#106)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import torch, os
|
||||
import torch, os, json
|
||||
from safetensors import safe_open
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
@@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
preset_models_on_huggingface = {
|
||||
@@ -159,6 +160,20 @@ preset_models_on_modelscope = {
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
@@ -184,7 +199,8 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"IP-Adapter-SD",
|
||||
"IP-Adapter-SDXL",
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5"
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
]
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
@@ -272,8 +288,7 @@ class ModelManager:
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
||||
return param_name in state_dict or param_name_2 in state_dict
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff(self, state_dict):
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
@@ -343,6 +358,21 @@ class ModelManager:
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_kolors_text_encoder(self, file_path):
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" in file_list:
|
||||
try:
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if config.get("model_type") == "chatglm":
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def is_kolors_unet(self, state_dict):
|
||||
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
@@ -532,13 +562,13 @@ class ModelManager:
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
@@ -592,6 +622,21 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
|
||||
component = "kolors_text_encoder"
|
||||
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
|
||||
model = model.to(dtype=self.torch_dtype, device=self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_kolors_unet(self, state_dict, file_path=""):
|
||||
component = "kolors_unet"
|
||||
model = SDXLUNet(is_kolors=True)
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -607,7 +652,11 @@ class ModelManager:
|
||||
|
||||
# Load every textual inversion file
|
||||
for file_name in os.listdir(folder):
|
||||
if file_name.endswith(".txt"):
|
||||
if os.path.isdir(os.path.join(folder, file_name)) or \
|
||||
not (file_name.endswith(".bin") or \
|
||||
file_name.endswith(".safetensors") or \
|
||||
file_name.endswith(".pth") or \
|
||||
file_name.endswith(".pt")):
|
||||
continue
|
||||
keyword = os.path.splitext(file_name)[0]
|
||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
||||
@@ -620,6 +669,10 @@ class ModelManager:
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
if os.path.isdir(file_path):
|
||||
if self.is_kolors_text_encoder(file_path):
|
||||
self.load_kolors_text_encoder(file_path=file_path)
|
||||
return
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_stable_video_diffusion(state_dict):
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
@@ -663,6 +716,8 @@ class ModelManager:
|
||||
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion_3_t5(state_dict):
|
||||
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
|
||||
elif self.is_kolors_unet(state_dict):
|
||||
self.load_kolors_unet(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -1,251 +1,6 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List, Union
|
||||
import copy, uuid, requests, io, platform, pickle, os, urllib
|
||||
from requests.adapters import Retry
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _get_sep(path):
|
||||
if isinstance(path, bytes):
|
||||
return b'/'
|
||||
else:
|
||||
return '/'
|
||||
|
||||
|
||||
def expanduser(path):
|
||||
"""Expand ~ and ~user constructions. If user or $HOME is unknown,
|
||||
do nothing."""
|
||||
path = os.fspath(path)
|
||||
if isinstance(path, bytes):
|
||||
tilde = b'~'
|
||||
else:
|
||||
tilde = '~'
|
||||
if not path.startswith(tilde):
|
||||
return path
|
||||
sep = _get_sep(path)
|
||||
i = path.find(sep, 1)
|
||||
if i < 0:
|
||||
i = len(path)
|
||||
if i == 1:
|
||||
if 'HOME' not in os.environ:
|
||||
import pwd
|
||||
try:
|
||||
userhome = pwd.getpwuid(os.getuid()).pw_dir
|
||||
except KeyError:
|
||||
# bpo-10496: if the current user identifier doesn't exist in the
|
||||
# password database, return the path unchanged
|
||||
return path
|
||||
else:
|
||||
userhome = os.environ['HOME']
|
||||
else:
|
||||
import pwd
|
||||
name = path[1:i]
|
||||
if isinstance(name, bytes):
|
||||
name = str(name, 'ASCII')
|
||||
try:
|
||||
pwent = pwd.getpwnam(name)
|
||||
except KeyError:
|
||||
# bpo-10496: if the user name from the path doesn't exist in the
|
||||
# password database, return the path unchanged
|
||||
return path
|
||||
userhome = pwent.pw_dir
|
||||
if isinstance(path, bytes):
|
||||
userhome = os.fsencode(userhome)
|
||||
root = b'/'
|
||||
else:
|
||||
root = '/'
|
||||
userhome = userhome.rstrip(root)
|
||||
return (userhome + path[i:]) or root
|
||||
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
COOKIES_FILE_NAME = 'cookies'
|
||||
GIT_TOKEN_FILE_NAME = 'git_token'
|
||||
USER_INFO_FILE_NAME = 'user'
|
||||
USER_SESSION_ID_FILE_NAME = 'session'
|
||||
|
||||
@staticmethod
|
||||
def make_sure_credential_path_exist():
|
||||
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def get_user_session_id():
|
||||
session_path = os.path.join(ModelScopeConfig.path_credential,
|
||||
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
|
||||
session_id = ''
|
||||
if os.path.exists(session_path):
|
||||
with open(session_path, 'rb') as f:
|
||||
session_id = str(f.readline().strip(), encoding='utf-8')
|
||||
return session_id
|
||||
if session_id == '' or len(session_id) != 32:
|
||||
session_id = str(uuid.uuid4().hex)
|
||||
ModelScopeConfig.make_sure_credential_path_exist()
|
||||
with open(session_path, 'w+') as wf:
|
||||
wf.write(session_id)
|
||||
|
||||
return session_id
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
|
||||
"""Formats a user-agent string with basic info about a request.
|
||||
|
||||
Args:
|
||||
user_agent (`str`, `dict`, *optional*):
|
||||
The user agent info in the form of a dictionary or a single string.
|
||||
|
||||
Returns:
|
||||
The formatted user-agent string.
|
||||
"""
|
||||
|
||||
# include some more telemetrics when executing in dedicated
|
||||
# cloud containers
|
||||
MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
|
||||
MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
|
||||
env = 'custom'
|
||||
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
|
||||
env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
|
||||
user_name = 'unknown'
|
||||
if MODELSCOPE_CLOUD_USERNAME in os.environ:
|
||||
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
|
||||
|
||||
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
|
||||
"1.15.0",
|
||||
platform.python_version(),
|
||||
ModelScopeConfig.get_user_session_id(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
env,
|
||||
user_name,
|
||||
)
|
||||
if isinstance(user_agent, dict):
|
||||
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += '; ' + user_agent
|
||||
return ua
|
||||
|
||||
@staticmethod
|
||||
def get_cookies():
|
||||
cookies_path = os.path.join(ModelScopeConfig.path_credential,
|
||||
ModelScopeConfig.COOKIES_FILE_NAME)
|
||||
if os.path.exists(cookies_path):
|
||||
with open(cookies_path, 'rb') as f:
|
||||
cookies = pickle.load(f)
|
||||
return cookies
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def modelscope_http_get_model_file(
|
||||
url: str,
|
||||
local_dir: str,
|
||||
file_name: str,
|
||||
file_size: int,
|
||||
cookies: CookieJar,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Download remote file, will retry 5 times before giving up on errors.
|
||||
|
||||
Args:
|
||||
url(str):
|
||||
actual download url of the file
|
||||
local_dir(str):
|
||||
local directory where the downloaded file stores
|
||||
file_name(str):
|
||||
name of the file stored in `local_dir`
|
||||
file_size(int):
|
||||
The file size.
|
||||
cookies(CookieJar):
|
||||
cookies used to authentication the user, which is used for downloading private repos
|
||||
headers(Dict[str, str], optional):
|
||||
http headers to carry necessary info when requesting the remote file
|
||||
|
||||
Raises:
|
||||
FileDownloadError: File download failed.
|
||||
|
||||
"""
|
||||
get_headers = {} if headers is None else copy.deepcopy(headers)
|
||||
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
|
||||
temp_file_path = os.path.join(local_dir, file_name)
|
||||
# retry sleep 0.5s, 1s, 2s, 4s
|
||||
retry = Retry(
|
||||
total=5,
|
||||
backoff_factor=1,
|
||||
allowed_methods=['GET'])
|
||||
while True:
|
||||
try:
|
||||
progress = tqdm(
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=file_size,
|
||||
initial=0,
|
||||
desc='Downloading',
|
||||
)
|
||||
partial_length = 0
|
||||
if os.path.exists(
|
||||
temp_file_path): # download partial, continue download
|
||||
with open(temp_file_path, 'rb') as f:
|
||||
partial_length = f.seek(0, io.SEEK_END)
|
||||
progress.update(partial_length)
|
||||
if partial_length > file_size:
|
||||
break
|
||||
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
|
||||
file_size - 1)
|
||||
with open(temp_file_path, 'ab') as f:
|
||||
r = requests.get(
|
||||
url,
|
||||
stream=True,
|
||||
headers=get_headers,
|
||||
cookies=cookies,
|
||||
timeout=60)
|
||||
r.raise_for_status()
|
||||
for chunk in r.iter_content(
|
||||
chunk_size=1024 * 1024 * 1):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
f.write(chunk)
|
||||
progress.close()
|
||||
break
|
||||
except (Exception) as e: # no matter what happen, we will retry.
|
||||
retry = retry.increment('GET', url, error=e)
|
||||
retry.sleep()
|
||||
|
||||
|
||||
def get_endpoint():
|
||||
MODELSCOPE_URL_SCHEME = 'https://'
|
||||
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
|
||||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
|
||||
DEFAULT_MODELSCOPE_DOMAIN)
|
||||
return MODELSCOPE_URL_SCHEME + modelscope_domain
|
||||
|
||||
|
||||
def get_file_download_url(model_id: str, file_path: str, revision: str):
|
||||
"""Format file download url according to `model_id`, `revision` and `file_path`.
|
||||
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
|
||||
the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
|
||||
|
||||
Args:
|
||||
model_id (str): The model_id.
|
||||
file_path (str): File path
|
||||
revision (str): File revision.
|
||||
|
||||
Returns:
|
||||
str: The file url.
|
||||
"""
|
||||
file_path = urllib.parse.quote_plus(file_path)
|
||||
revision = urllib.parse.quote_plus(revision)
|
||||
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
|
||||
return download_url_template.format(
|
||||
endpoint=get_endpoint(),
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
file_path=file_path,
|
||||
)
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
@@ -255,17 +10,12 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
return
|
||||
else:
|
||||
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
|
||||
modelscope_http_get_model_file(
|
||||
url,
|
||||
local_dir,
|
||||
os.path.basename(origin_file_path),
|
||||
file_size=0,
|
||||
headers=headers,
|
||||
cookies=cookies
|
||||
)
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
|
||||
1363
diffsynth/models/kolors_text_encoder.py
Normal file
1363
diffsynth/models/kolors_text_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
||||
|
||||
|
||||
class SDXLUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, is_kolors=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
@@ -85,7 +86,9 @@ class SDXLUNet(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=8, **kwargs
|
||||
tiled=False, tile_size=64, tile_stride=8,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
@@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module):
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
|
||||
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -148,6 +162,8 @@ class SDXLUNetStateDictConverter:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif names[0] in ["encoder_hid_proj"]:
|
||||
names[0] = "text_intermediate_proj"
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
|
||||
@@ -5,3 +5,4 @@ from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .hunyuan_dit import HunyuanDiTImagePipeline
|
||||
from .stable_diffusion_3 import SD3ImagePipeline
|
||||
from .kwai_kolors import KolorsImagePipeline
|
||||
|
||||
@@ -147,7 +147,7 @@ def lets_dance_xl(
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
|
||||
@@ -216,7 +216,7 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
@@ -293,6 +293,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
|
||||
168
diffsynth/pipelines/kwai_kolors.py
Normal file
168
diffsynth/pipelines/kwai_kolors.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from ..models import ModelManager, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
from ..prompts import KolorsPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance_xl
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KolorsImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
||||
self.prompter = KolorsPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: ChatGLMModel = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.kolors_text_encoder
|
||||
self.unet = model_manager.kolors_unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||
if "ipadapter_xl" in model_manager.model:
|
||||
self.ipadapter = model_manager.ipadapter_xl
|
||||
if "ipadapter_xl_image_encoder" in model_manager.model:
|
||||
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager):
|
||||
pipe = KolorsImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_ipadapter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=2,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_use_instant_style=False,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip,
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
if ipadapter_use_instant_style:
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
|
||||
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -2,3 +2,4 @@ from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .sd3_prompter import SD3Prompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
from .kolors_prompter import KolorsPrompter
|
||||
|
||||
347
diffsynth/prompts/kolors_prompter.py
Normal file
347
diffsynth/prompts/kolors_prompter.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from .utils import Prompter
|
||||
import json, os, re
|
||||
from typing import List, Optional, Union, Dict
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.utils import PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
class SPTokenizer:
|
||||
def __init__(self, model_path: str):
|
||||
# reload tokenizer
|
||||
assert os.path.isfile(model_path), model_path
|
||||
self.sp_model = SentencePieceProcessor(model_file=model_path)
|
||||
|
||||
# BOS / EOS token IDs
|
||||
self.n_words: int = self.sp_model.vocab_size()
|
||||
self.bos_id: int = self.sp_model.bos_id()
|
||||
self.eos_id: int = self.sp_model.eos_id()
|
||||
self.pad_id: int = self.sp_model.unk_id()
|
||||
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
|
||||
|
||||
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
|
||||
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
|
||||
self.special_tokens = {}
|
||||
self.index_special_tokens = {}
|
||||
for token in special_tokens:
|
||||
self.special_tokens[token] = self.n_words
|
||||
self.index_special_tokens[self.n_words] = token
|
||||
self.n_words += 1
|
||||
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
|
||||
|
||||
def tokenize(self, s: str, encode_special_tokens=False):
|
||||
if encode_special_tokens:
|
||||
last_index = 0
|
||||
t = []
|
||||
for match in re.finditer(self.role_special_token_expression, s):
|
||||
if last_index < match.start():
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
|
||||
t.append(s[match.start():match.end()])
|
||||
last_index = match.end()
|
||||
if last_index < len(s):
|
||||
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
|
||||
return t
|
||||
else:
|
||||
return self.sp_model.EncodeAsPieces(s)
|
||||
|
||||
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
|
||||
assert type(s) is str
|
||||
t = self.sp_model.encode(s)
|
||||
if bos:
|
||||
t = [self.bos_id] + t
|
||||
if eos:
|
||||
t = t + [self.eos_id]
|
||||
return t
|
||||
|
||||
def decode(self, t: List[int]) -> str:
|
||||
text, buffer = "", []
|
||||
for token in t:
|
||||
if token in self.index_special_tokens:
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
buffer = []
|
||||
text += self.index_special_tokens[token]
|
||||
else:
|
||||
buffer.append(token)
|
||||
if buffer:
|
||||
text += self.sp_model.decode(buffer)
|
||||
return text
|
||||
|
||||
def decode_tokens(self, tokens: List[str]) -> str:
|
||||
text = self.sp_model.DecodePieces(tokens)
|
||||
return text
|
||||
|
||||
def convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
return self.sp_model.PieceToId(token)
|
||||
|
||||
def convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
if index in self.index_special_tokens:
|
||||
return self.index_special_tokens[index]
|
||||
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
|
||||
return ""
|
||||
return self.sp_model.IdToPiece(index)
|
||||
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
|
||||
vocab_files_names = {"vocab_file": "tokenizer.model"}
|
||||
|
||||
model_input_names = ["input_ids", "attention_mask", "position_ids"]
|
||||
|
||||
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
|
||||
**kwargs):
|
||||
self.name = "GLMTokenizer"
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.tokenizer = SPTokenizer(vocab_file)
|
||||
self.special_tokens = {
|
||||
"<bos>": self.tokenizer.bos_id,
|
||||
"<eos>": self.tokenizer.eos_id,
|
||||
"<pad>": self.tokenizer.pad_id
|
||||
}
|
||||
self.encode_special_tokens = encode_special_tokens
|
||||
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
encode_special_tokens=encode_special_tokens,
|
||||
**kwargs)
|
||||
|
||||
def get_command(self, token):
|
||||
if token in self.special_tokens:
|
||||
return self.special_tokens[token]
|
||||
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
|
||||
return self.tokenizer.special_tokens[token]
|
||||
|
||||
@property
|
||||
def unk_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token(self) -> str:
|
||||
return "<unk>"
|
||||
|
||||
@property
|
||||
def pad_token_id(self):
|
||||
return self.get_command("<pad>")
|
||||
|
||||
@property
|
||||
def eos_token(self) -> str:
|
||||
return "</s>"
|
||||
|
||||
@property
|
||||
def eos_token_id(self):
|
||||
return self.get_command("<eos>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.tokenizer.n_words
|
||||
|
||||
def get_vocab(self):
|
||||
""" Returns vocab as a dict """
|
||||
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text, **kwargs):
|
||||
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
""" Converts a token (str) in an id using the vocab. """
|
||||
return self.tokenizer.convert_token_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
return self.tokenizer.convert_id_to_token(index)
|
||||
|
||||
def convert_tokens_to_string(self, tokens: List[str]) -> str:
|
||||
return self.tokenizer.decode_tokens(tokens)
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix=None):
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
filename_prefix (`str`, *optional*):
|
||||
An optional prefix to add to the named of the saved files.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if os.path.isdir(save_directory):
|
||||
vocab_file = os.path.join(
|
||||
save_directory, self.vocab_files_names["vocab_file"]
|
||||
)
|
||||
else:
|
||||
vocab_file = save_directory
|
||||
|
||||
with open(self.vocab_file, 'rb') as fin:
|
||||
proto_str = fin.read()
|
||||
|
||||
with open(vocab_file, "wb") as writer:
|
||||
writer.write(proto_str)
|
||||
|
||||
return (vocab_file,)
|
||||
|
||||
def get_prefix_tokens(self):
|
||||
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
|
||||
return prefix_tokens
|
||||
|
||||
def build_single_message(self, role, metadata, message):
|
||||
assert role in ["system", "user", "assistant", "observation"], role
|
||||
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
|
||||
message_tokens = self.tokenizer.encode(message)
|
||||
tokens = role_tokens + message_tokens
|
||||
return tokens
|
||||
|
||||
def build_chat_input(self, query, history=None, role="user"):
|
||||
if history is None:
|
||||
history = []
|
||||
input_ids = []
|
||||
for item in history:
|
||||
content = item["content"]
|
||||
if item["role"] == "system" and "tools" in item:
|
||||
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
|
||||
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
|
||||
input_ids.extend(self.build_single_message(role, "", query))
|
||||
input_ids.extend([self.get_command("<|assistant|>")])
|
||||
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
|
||||
adding special tokens. A BERT sequence has the following format:
|
||||
|
||||
- single sequence: `[CLS] X [SEP]`
|
||||
- pair of sequences: `[CLS] A [SEP] B [SEP]`
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs to which the special tokens will be added.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
|
||||
"""
|
||||
prefix_tokens = self.get_prefix_tokens()
|
||||
token_ids_0 = prefix_tokens + token_ids_0
|
||||
if token_ids_1 is not None:
|
||||
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
|
||||
return token_ids_0
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
|
||||
|
||||
Args:
|
||||
encoded_inputs:
|
||||
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
|
||||
max_length: maximum length of the returned list and optionally padding length (see below).
|
||||
Will truncate by taking into account the special tokens.
|
||||
padding_strategy: PaddingStrategy to use for padding.
|
||||
|
||||
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
|
||||
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
|
||||
- PaddingStrategy.DO_NOT_PAD: Do not pad
|
||||
The tokenizer padding sides are defined in self.padding_side:
|
||||
|
||||
- 'left': pads on the left of the sequences
|
||||
- 'right': pads on the right of the sequences
|
||||
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
|
||||
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
|
||||
`>= 7.5` (Volta).
|
||||
return_attention_mask:
|
||||
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
|
||||
"""
|
||||
# Load from model defaults
|
||||
assert self.padding_side == "left"
|
||||
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
seq_length = len(required_input)
|
||||
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
|
||||
# Initialize attention mask if not present.
|
||||
if "attention_mask" not in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [1] * seq_length
|
||||
|
||||
if "position_ids" not in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = list(range(seq_length))
|
||||
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
|
||||
if "attention_mask" in encoded_inputs:
|
||||
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
|
||||
if "position_ids" in encoded_inputs:
|
||||
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
|
||||
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
|
||||
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
|
||||
class KolorsPrompter(Prompter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer_path=None
|
||||
):
|
||||
if tokenizer_path is None:
|
||||
base_path = os.path.dirname(os.path.dirname(__file__))
|
||||
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
|
||||
super().__init__()
|
||||
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
|
||||
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
).to(device)
|
||||
output = text_encoder(
|
||||
input_ids=text_inputs['input_ids'] ,
|
||||
attention_mask=text_inputs['attention_mask'],
|
||||
position_ids=text_inputs['position_ids'],
|
||||
output_hidden_states=True
|
||||
)
|
||||
prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
|
||||
pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
|
||||
return prompt_emb, pooled_prompt_emb
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
text_encoder: ChatGLMModel,
|
||||
prompt,
|
||||
clip_skip=2,
|
||||
positive=True,
|
||||
device="cuda"
|
||||
):
|
||||
prompt = self.process_prompt(prompt, positive=positive)
|
||||
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, text_encoder, self.tokenizer, 256, clip_skip, device)
|
||||
|
||||
return pooled_prompt_emb, prompt_emb
|
||||
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
Normal file
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
Normal file
Binary file not shown.
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name_or_path": "THUDM/chatglm3-6b-base",
|
||||
"remove_space": false,
|
||||
"do_lower_case": false,
|
||||
"tokenizer_class": "ChatGLMTokenizer",
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenization_chatglm.ChatGLMTokenizer",
|
||||
null
|
||||
]
|
||||
}
|
||||
}
|
||||
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
Normal file
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
Normal file
Binary file not shown.
Reference in New Issue
Block a user