support kolors! (#106)

This commit is contained in:
Zhongjie Duan
2024-07-11 21:43:45 +08:00
committed by GitHub
parent 2a4709e572
commit 9c6607f78d
20 changed files with 2510 additions and 281 deletions

View File

@@ -1,4 +1,4 @@
import torch, os
import torch, os, json
from safetensors import safe_open
from typing_extensions import Literal, TypeAlias
from typing import List
@@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .kolors_text_encoder import ChatGLMModel
preset_models_on_huggingface = {
@@ -159,6 +160,20 @@ preset_models_on_modelscope = {
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
@@ -184,7 +199,8 @@ Preset_model_id: TypeAlias = Literal[
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5"
"StableDiffusion3_without_T5",
"Kolors",
]
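A minimal usage sketch for the new preset, assuming the repository exposes a download_models helper at package level (the "Kolors" key resolves to the (model_id, file, local_dir) triples registered above):

from diffsynth import download_models  # assumed package-level export

download_models(["Kolors"])  # fetches the ChatGLM text encoder shards, UNet and VAE listed above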
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
@@ -272,8 +288,7 @@ class ModelManager:
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
return param_name in state_dict or param_name_2 in state_dict
return param_name in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
@@ -343,6 +358,21 @@ class ModelManager:
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_kolors_text_encoder(self, file_path):
file_list = os.listdir(file_path)
if "config.json" in file_list:
try:
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if config.get("model_type") == "chatglm":
return True
except (OSError, json.JSONDecodeError): # unreadable or malformed config: not a Kolors text encoder
pass
return False
def is_kolors_unet(self, state_dict):
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
@@ -532,13 +562,13 @@ class ModelManager:
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
@@ -592,6 +622,21 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
component = "kolors_text_encoder"
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
model = model.to(dtype=self.torch_dtype, device=self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_unet(self, state_dict, file_path=""):
component = "kolors_unet"
model = SDXLUNet(is_kolors=True)
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -607,7 +652,11 @@ class ModelManager:
# Load every textual inversion file
for file_name in os.listdir(folder):
if file_name.endswith(".txt"):
if os.path.isdir(os.path.join(folder, file_name)) or \
not (file_name.endswith(".bin") or \
file_name.endswith(".safetensors") or \
file_name.endswith(".pth") or \
file_name.endswith(".pt")):
continue
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
@@ -620,6 +669,10 @@ class ModelManager:
break
def load_model(self, file_path, components=None, lora_alphas=[]):
if os.path.isdir(file_path):
if self.is_kolors_text_encoder(file_path):
self.load_kolors_text_encoder(file_path=file_path)
return
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
@@ -663,6 +716,8 @@ class ModelManager:
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion_3_t5(state_dict):
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
elif self.is_kolors_unet(state_dict):
self.load_kolors_unet(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:
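Taken together, a hedged sketch of how the new hooks are exercised: a directory is routed through is_kolors_text_encoder, while the UNet checkpoint is recognized by its encoder_hid_proj key (preset paths from above; the ModelManager constructor arguments are assumptions):

import torch
from diffsynth import ModelManager  # assumed package-level export

model_manager = ModelManager(torch_dtype=torch.float16, device="cuda")
model_manager.load_models([
    "models/kolors/Kolors/text_encoder",                              # directory -> load_kolors_text_encoder
    "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",  # -> load_kolors_unet
    "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors",   # -> SDXL VAE, kept in float32
])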

View File

@@ -1,251 +1,6 @@
from huggingface_hub import hf_hub_download
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, Optional, List, Union
import copy, uuid, requests, io, platform, pickle, os, urllib
from requests.adapters import Retry
from tqdm import tqdm
def _get_sep(path):
if isinstance(path, bytes):
return b'/'
else:
return '/'
def expanduser(path):
"""Expand ~ and ~user constructions. If user or $HOME is unknown,
do nothing."""
path = os.fspath(path)
if isinstance(path, bytes):
tilde = b'~'
else:
tilde = '~'
if not path.startswith(tilde):
return path
sep = _get_sep(path)
i = path.find(sep, 1)
if i < 0:
i = len(path)
if i == 1:
if 'HOME' not in os.environ:
import pwd
try:
userhome = pwd.getpwuid(os.getuid()).pw_dir
except KeyError:
# bpo-10496: if the current user identifier doesn't exist in the
# password database, return the path unchanged
return path
else:
userhome = os.environ['HOME']
else:
import pwd
name = path[1:i]
if isinstance(name, bytes):
name = str(name, 'ASCII')
try:
pwent = pwd.getpwnam(name)
except KeyError:
# bpo-10496: if the user name from the path doesn't exist in the
# password database, return the path unchanged
return path
userhome = pwent.pw_dir
if isinstance(path, bytes):
userhome = os.fsencode(userhome)
root = b'/'
else:
root = '/'
userhome = userhome.rstrip(root)
return (userhome + path[i:]) or root
class ModelScopeConfig:
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
COOKIES_FILE_NAME = 'cookies'
GIT_TOKEN_FILE_NAME = 'git_token'
USER_INFO_FILE_NAME = 'user'
USER_SESSION_ID_FILE_NAME = 'session'
@staticmethod
def make_sure_credential_path_exist():
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
@staticmethod
def get_user_session_id():
session_path = os.path.join(ModelScopeConfig.path_credential,
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
session_id = ''
if os.path.exists(session_path):
with open(session_path, 'rb') as f:
session_id = str(f.readline().strip(), encoding='utf-8')
return session_id
if session_id == '' or len(session_id) != 32:
session_id = str(uuid.uuid4().hex)
ModelScopeConfig.make_sure_credential_path_exist()
with open(session_path, 'w+') as wf:
wf.write(session_id)
return session_id
@staticmethod
def get_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
"""Formats a user-agent string with basic info about a request.
Args:
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string.
Returns:
The formatted user-agent string.
"""
# include some more telemetry when executing in dedicated
# cloud containers
MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
env = 'custom'
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
user_name = 'unknown'
if MODELSCOPE_CLOUD_USERNAME in os.environ:
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
"1.15.0",
platform.python_version(),
ModelScopeConfig.get_user_session_id(),
platform.platform(),
platform.processor(),
env,
user_name,
)
if isinstance(user_agent, dict):
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += '; ' + user_agent
return ua
@staticmethod
def get_cookies():
cookies_path = os.path.join(ModelScopeConfig.path_credential,
ModelScopeConfig.COOKIES_FILE_NAME)
if os.path.exists(cookies_path):
with open(cookies_path, 'rb') as f:
cookies = pickle.load(f)
return cookies
return None
def modelscope_http_get_model_file(
url: str,
local_dir: str,
file_name: str,
file_size: int,
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
):
"""Download remote file, will retry 5 times before giving up on errors.
Args:
url(str):
actual download url of the file
local_dir(str):
local directory where the downloaded file is stored
file_name(str):
name of the file stored in `local_dir`
file_size(int):
The file size.
cookies(CookieJar):
cookies used to authenticate the user, required for downloading private repos
headers(Dict[str, str], optional):
http headers to carry necessary info when requesting the remote file
Raises:
FileDownloadError: File download failed.
"""
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
temp_file_path = os.path.join(local_dir, file_name)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
total=5,
backoff_factor=1,
allowed_methods=['GET'])
while True:
try:
progress = tqdm(
unit='B',
unit_scale=True,
unit_divisor=1024,
total=file_size,
initial=0,
desc='Downloading',
)
partial_length = 0
if os.path.exists(
temp_file_path): # a partial download exists, resume it
with open(temp_file_path, 'rb') as f:
partial_length = f.seek(0, io.SEEK_END)
progress.update(partial_length)
if partial_length > file_size:
break
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
file_size - 1)
with open(temp_file_path, 'ab') as f:
r = requests.get(
url,
stream=True,
headers=get_headers,
cookies=cookies,
timeout=60)
r.raise_for_status()
for chunk in r.iter_content(
chunk_size=1024 * 1024 * 1):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
f.write(chunk)
progress.close()
break
except Exception as e: # no matter what happens, we will retry.
retry = retry.increment('GET', url, error=e)
retry.sleep()
def get_endpoint():
MODELSCOPE_URL_SCHEME = 'https://'
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
DEFAULT_MODELSCOPE_DOMAIN)
return MODELSCOPE_URL_SCHEME + modelscope_domain
def get_file_download_url(model_id: str, file_path: str, revision: str):
"""Format file download url according to `model_id`, `revision` and `file_path`.
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
the resulting download URL is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
Args:
model_id (str): The model_id.
file_path (str): File path
revision (str): File revision.
Returns:
str: The file url.
"""
file_path = urllib.parse.quote_plus(file_path)
revision = urllib.parse.quote_plus(revision)
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
return download_url_template.format(
endpoint=get_endpoint(),
model_id=model_id,
revision=revision,
file_path=file_path,
)
from modelscope import snapshot_download
import os, shutil
def download_from_modelscope(model_id, origin_file_path, local_dir):
@@ -255,17 +10,12 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
return
else:
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
cookies = ModelScopeConfig.get_cookies()
url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
modelscope_http_get_model_file(
url,
local_dir,
os.path.basename(origin_file_path),
file_size=0,
headers=headers,
cookies=cookies
)
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
def download_from_huggingface(model_id, origin_file_path, local_dir):

File diff suppressed because it is too large

View File

@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
class SDXLUNet(torch.nn.Module):
def __init__(self):
def __init__(self, is_kolors=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
)
self.add_time_proj = Timesteps(256)
self.add_time_embedding = torch.nn.Sequential(
torch.nn.Linear(2816, 1280),
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
self.blocks = torch.nn.ModuleList([
# DownBlock2D
@@ -85,7 +86,9 @@ class SDXLUNet(torch.nn.Module):
def forward(
self,
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
tiled=False, tile_size=64, tile_stride=8, **kwargs
tiled=False, tile_size=64, tile_stride=8,
use_gradient_checkpointing=False,
**kwargs
):
# 1. time
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
@@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module):
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
text_emb = encoder_hidden_states
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
res_stack = [hidden_states]
# 3. blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, time_emb, text_emb, res_stack,
use_reentrant=False,
)
else:
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -148,6 +162,8 @@ class SDXLUNetStateDictConverter:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
pass
elif names[0] in ["encoder_hid_proj"]:
names[0] = "text_intermediate_proj"
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"

View File

@@ -5,3 +5,4 @@ from .stable_diffusion_xl_video import SDXLVideoPipeline
from .stable_video_diffusion import SVDVideoPipeline
from .hunyuan_dit import HunyuanDiTImagePipeline
from .stable_diffusion_3 import SD3ImagePipeline
from .kwai_kolors import KolorsImagePipeline

View File

@@ -147,7 +147,7 @@ def lets_dance_xl(
# 3. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
res_stack = [hidden_states]
# 4. blocks

View File

@@ -216,7 +216,7 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
# Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
@@ -293,6 +293,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image
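The same float32-VAE pattern recurs in the Kolors pipeline below: encode and decode in full precision for numerical stability, while the denoising loop stays in self.torch_dtype. In miniature (names taken from the surrounding code):

latents = self.vae_encoder(image.to(torch.float32)).to(self.torch_dtype)  # encode in fp32, denoise in fp16
image = self.decode_image(latents.to(torch.float32))                      # decode in fp32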

View File

@@ -0,0 +1,168 @@
from ..models import ModelManager, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.kolors_text_encoder import ChatGLMModel
from ..prompts import KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance_xl
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class KolorsImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
self.prompter = KolorsPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.kolors_text_encoder
self.unet = model_manager.kolors_unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter_xl" in model_manager.model:
self.ipadapter = model_manager.ipadapter_xl
if "ipadapter_xl_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager):
pipe = KolorsImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_ipadapter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
prompt,
clip_skip=clip_skip,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
negative_prompt,
clip_skip=clip_skip,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image
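An end-to-end sketch of the new pipeline, reusing a ModelManager loaded as in the earlier sketch (package-level exports assumed):

pipe = KolorsImagePipeline.from_model_manager(model_manager)
image = pipe(
    prompt="a full-body portrait with a poetic atmosphere",
    negative_prompt="",
    cfg_scale=7.5, num_inference_steps=20, height=1024, width=1024,
)
image.save("image_kolors.png")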

View File

@@ -2,3 +2,4 @@ from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter
from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter

View File

@@ -0,0 +1,347 @@
from .utils import Prompter
import json, os, re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from ..models.kolors_text_encoder import ChatGLMModel
class SPTokenizer:
def __init__(self, model_path: str):
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.unk_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
self.special_tokens = {}
self.index_special_tokens = {}
for token in special_tokens:
self.special_tokens[token] = self.n_words
self.index_special_tokens[self.n_words] = token
self.n_words += 1
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
def tokenize(self, s: str, encode_special_tokens=False):
if encode_special_tokens:
last_index = 0
t = []
for match in re.finditer(self.role_special_token_expression, s):
if last_index < match.start():
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
t.append(s[match.start():match.end()])
last_index = match.end()
if last_index < len(s):
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
return t
else:
return self.sp_model.EncodeAsPieces(s)
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
assert type(s) is str
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
text, buffer = "", []
for token in t:
if token in self.index_special_tokens:
if buffer:
text += self.sp_model.decode(buffer)
buffer = []
text += self.index_special_tokens[token]
else:
buffer.append(token)
if buffer:
text += self.sp_model.decode(buffer)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self.sp_model.DecodePieces(tokens)
return text
def convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
if token in self.special_tokens:
return self.special_tokens[token]
return self.sp_model.PieceToId(token)
def convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.index_special_tokens:
return self.index_special_tokens[index]
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
return ""
return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
**kwargs):
self.name = "GLMTokenizer"
self.vocab_file = vocab_file
self.tokenizer = SPTokenizer(vocab_file)
self.special_tokens = {
"<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id
}
self.encode_special_tokens = encode_special_tokens
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
encode_special_tokens=encode_special_tokens,
**kwargs)
def get_command(self, token):
if token in self.special_tokens:
return self.special_tokens[token]
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
return self.tokenizer.special_tokens[token]
@property
def unk_token(self) -> str:
return "<unk>"
@property
def pad_token(self) -> str:
return "<unk>"
@property
def pad_token_id(self):
return self.get_command("<pad>")
@property
def eos_token(self) -> str:
return "</s>"
@property
def eos_token_id(self):
return self.get_command("<eos>")
@property
def vocab_size(self):
return self.tokenizer.n_words
def get_vocab(self):
""" Returns vocab as a dict """
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.tokenizer.convert_id_to_token(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.tokenizer.decode_tokens(tokens)
def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the names of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
writer.write(proto_str)
return (vocab_file,)
def get_prefix_tokens(self):
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
return prefix_tokens
def build_single_message(self, role, metadata, message):
assert role in ["system", "user", "assistant", "observation"], role
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
message_tokens = self.tokenizer.encode(message)
tokens = role_tokens + message_tokens
return tokens
def build_chat_input(self, query, history=None, role="user"):
if history is None:
history = []
input_ids = []
for item in history:
content = item["content"]
if item["role"] == "system" and "tools" in item:
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
input_ids.extend(self.build_single_message(role, "", query))
input_ids.extend([self.get_command("<|assistant|>")])
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
prefix_tokens = self.get_prefix_tokens()
token_ids_0 = prefix_tokens + token_ids_0
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
return token_ids_0
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
Args:
encoded_inputs:
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
assert self.padding_side == "left"
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * seq_length
if "position_ids" not in encoded_inputs:
encoded_inputs["position_ids"] = list(range(seq_length))
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "position_ids" in encoded_inputs:
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
return encoded_inputs
class KolorsPrompter(Prompter):
def __init__(
self,
tokenizer_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
super().__init__()
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
).to(device)
output = text_encoder(
input_ids=text_inputs['input_ids'],
attention_mask=text_inputs['attention_mask'],
position_ids=text_inputs['position_ids'],
output_hidden_states=True
)
prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
return prompt_emb, pooled_prompt_emb
def encode_prompt(
self,
text_encoder: ChatGLMModel,
prompt,
clip_skip=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, text_encoder, self.tokenizer, 256, clip_skip, device)
return pooled_prompt_emb, prompt_emb
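For reference, a small sketch of the shapes encode_prompt produces with the settings above (max_length 256, clip_skip 2); a loaded ChatGLMModel and prompter are assumed:

add_emb, prompt_emb = prompter.encode_prompt(text_encoder, "a cat", clip_skip=2, device="cuda")
# add_emb:    (batch, 4096) pooled last-token state -> UNet add_text_embeds
# prompt_emb: (batch, 256, 4096) per-token states   -> cross-attention, via text_intermediate_proj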

View File

@@ -0,0 +1,12 @@
{
"name_or_path": "THUDM/chatglm3-6b-base",
"remove_space": false,
"do_lower_case": false,
"tokenizer_class": "ChatGLMTokenizer",
"auto_map": {
"AutoTokenizer": [
"tokenization_chatglm.ChatGLMTokenizer",
null
]
}
}

Binary file not shown.