diff --git a/README.md b/README.md index 071e2a4..fbe42e1 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu Until now, DiffSynth Studio has supported the following models: * [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1) +* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors) * [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) * [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt) * [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT) @@ -85,11 +86,13 @@ Generate high-resolution images, by breaking the limitation of diffusion models! LoRA fine-tuning is supported in [`examples/train`](./examples/train/). -|Stable Diffusion|Stable Diffusion XL| +|Model|Example| |-|-| -|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)| -|Stable Diffusion 3|Hunyuan-DiT| -|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)| +|Stable Diffusion|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)| +|Stable Diffusion XL|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)| +|Stable Diffusion 3|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)| +|Kolors|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)| 
+|Hunyuan-DiT|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)| ### Toon Shading diff --git a/diffsynth/models/__init__.py b/diffsynth/models/__init__.py index ea5e9da..057ed89 100644 --- a/diffsynth/models/__init__.py +++ b/diffsynth/models/__init__.py @@ -1,4 +1,4 @@ -import torch, os +import torch, os, json from safetensors import safe_open from typing_extensions import Literal, TypeAlias from typing import List @@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder from .hunyuan_dit import HunyuanDiT +from .kolors_text_encoder import ChatGLMModel preset_models_on_huggingface = { @@ -159,6 +160,20 @@ preset_models_on_modelscope = { ("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"), ("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"), ], + # Kolors + "Kolors": [ + ("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", 
"text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"), + ("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"), + ("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"), + ], } Preset_model_id: TypeAlias = Literal[ "HunyuanDiT", @@ -184,7 +199,8 @@ Preset_model_id: TypeAlias = Literal[ "IP-Adapter-SD", "IP-Adapter-SDXL", "StableDiffusion3", - "StableDiffusion3_without_T5" + "StableDiffusion3_without_T5", + "Kolors", ] Preset_model_website: TypeAlias = Literal[ "HuggingFace", @@ -272,8 +288,7 @@ class ModelManager: def is_controlnet(self, state_dict): param_name = "control_model.time_embed.0.weight" - param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format - return param_name in state_dict or param_name_2 in state_dict + return param_name in state_dict def is_animatediff(self, state_dict): param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight" @@ -343,6 +358,21 @@ class ModelManager: param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight" return param_name in state_dict + def is_kolors_text_encoder(self, file_path): + file_list = os.listdir(file_path) + if "config.json" in file_list: + try: + with open(os.path.join(file_path, "config.json"), "r") as f: + config = json.load(f) + if config.get("model_type") == "chatglm": + return True + except: + pass + return False + + def is_kolors_unet(self, state_dict): + return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict + def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None): component_dict = { "image_encoder": SVDImageEncoder, @@ -532,13 +562,13 @@ class ModelManager: component = "vae_encoder" model = SDXLVAEEncoder() model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) - 
model.to(self.torch_dtype).to(self.device) + model.to(torch.float32).to(self.device) self.model[component] = model self.model_path[component] = file_path component = "vae_decoder" model = SDXLVAEDecoder() model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) - model.to(self.torch_dtype).to(self.device) + model.to(torch.float32).to(self.device) self.model[component] = model self.model_path[component] = file_path @@ -592,6 +622,21 @@ class ModelManager: self.model[component] = model self.model_path[component] = file_path + def load_kolors_text_encoder(self, state_dict=None, file_path=""): + component = "kolors_text_encoder" + model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype) + model = model.to(dtype=self.torch_dtype, device=self.device) + self.model[component] = model + self.model_path[component] = file_path + + def load_kolors_unet(self, state_dict, file_path=""): + component = "kolors_unet" + model = SDXLUNet(is_kolors=True) + model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict)) + model.to(self.torch_dtype).to(self.device) + self.model[component] = model + self.model_path[component] = file_path + def search_for_embeddings(self, state_dict): embeddings = [] for k in state_dict: @@ -607,7 +652,11 @@ class ModelManager: # Load every textual inversion file for file_name in os.listdir(folder): - if file_name.endswith(".txt"): + if os.path.isdir(os.path.join(folder, file_name)) or \ + not (file_name.endswith(".bin") or \ + file_name.endswith(".safetensors") or \ + file_name.endswith(".pth") or \ + file_name.endswith(".pt")): continue keyword = os.path.splitext(file_name)[0] state_dict = load_state_dict(os.path.join(folder, file_name)) @@ -620,6 +669,10 @@ class ModelManager: break def load_model(self, file_path, components=None, lora_alphas=[]): + if os.path.isdir(file_path): + if self.is_kolors_text_encoder(file_path): + self.load_kolors_text_encoder(file_path=file_path) + return state_dict = 
load_state_dict(file_path, torch_dtype=self.torch_dtype) if self.is_stable_video_diffusion(state_dict): self.load_stable_video_diffusion(state_dict, file_path=file_path) @@ -663,6 +716,8 @@ class ModelManager: self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path) elif self.is_stable_diffusion_3_t5(state_dict): self.load_stable_diffusion_3_t5(state_dict, file_path=file_path) + elif self.is_kolors_unet(state_dict): + self.load_kolors_unet(state_dict, file_path=file_path) def load_models(self, file_path_list, lora_alphas=[]): for file_path in file_path_list: diff --git a/diffsynth/models/downloader.py b/diffsynth/models/downloader.py index 603dea0..33d3da7 100644 --- a/diffsynth/models/downloader.py +++ b/diffsynth/models/downloader.py @@ -1,251 +1,6 @@ from huggingface_hub import hf_hub_download -from http.cookiejar import CookieJar -from pathlib import Path -from typing import Dict, Optional, List, Union -import copy, uuid, requests, io, platform, pickle, os, urllib -from requests.adapters import Retry -from tqdm import tqdm - - -def _get_sep(path): - if isinstance(path, bytes): - return b'/' - else: - return '/' - - -def expanduser(path): - """Expand ~ and ~user constructions. 
If user or $HOME is unknown, - do nothing.""" - path = os.fspath(path) - if isinstance(path, bytes): - tilde = b'~' - else: - tilde = '~' - if not path.startswith(tilde): - return path - sep = _get_sep(path) - i = path.find(sep, 1) - if i < 0: - i = len(path) - if i == 1: - if 'HOME' not in os.environ: - import pwd - try: - userhome = pwd.getpwuid(os.getuid()).pw_dir - except KeyError: - # bpo-10496: if the current user identifier doesn't exist in the - # password database, return the path unchanged - return path - else: - userhome = os.environ['HOME'] - else: - import pwd - name = path[1:i] - if isinstance(name, bytes): - name = str(name, 'ASCII') - try: - pwent = pwd.getpwnam(name) - except KeyError: - # bpo-10496: if the user name from the path doesn't exist in the - # password database, return the path unchanged - return path - userhome = pwent.pw_dir - if isinstance(path, bytes): - userhome = os.fsencode(userhome) - root = b'/' - else: - root = '/' - userhome = userhome.rstrip(root) - return (userhome + path[i:]) or root - - - -class ModelScopeConfig: - DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials') - path_credential = expanduser(DEFAULT_CREDENTIALS_PATH) - COOKIES_FILE_NAME = 'cookies' - GIT_TOKEN_FILE_NAME = 'git_token' - USER_INFO_FILE_NAME = 'user' - USER_SESSION_ID_FILE_NAME = 'session' - - @staticmethod - def make_sure_credential_path_exist(): - os.makedirs(ModelScopeConfig.path_credential, exist_ok=True) - - @staticmethod - def get_user_session_id(): - session_path = os.path.join(ModelScopeConfig.path_credential, - ModelScopeConfig.USER_SESSION_ID_FILE_NAME) - session_id = '' - if os.path.exists(session_path): - with open(session_path, 'rb') as f: - session_id = str(f.readline().strip(), encoding='utf-8') - return session_id - if session_id == '' or len(session_id) != 32: - session_id = str(uuid.uuid4().hex) - ModelScopeConfig.make_sure_credential_path_exist() - with open(session_path, 'w+') as wf: - wf.write(session_id) - 
- return session_id - - @staticmethod - def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str: - """Formats a user-agent string with basic info about a request. - - Args: - user_agent (`str`, `dict`, *optional*): - The user agent info in the form of a dictionary or a single string. - - Returns: - The formatted user-agent string. - """ - - # include some more telemetrics when executing in dedicated - # cloud containers - MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT' - MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME' - env = 'custom' - if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ: - env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT] - user_name = 'unknown' - if MODELSCOPE_CLOUD_USERNAME in os.environ: - user_name = os.environ[MODELSCOPE_CLOUD_USERNAME] - - ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % ( - "1.15.0", - platform.python_version(), - ModelScopeConfig.get_user_session_id(), - platform.platform(), - platform.processor(), - env, - user_name, - ) - if isinstance(user_agent, dict): - ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items()) - elif isinstance(user_agent, str): - ua += '; ' + user_agent - return ua - - @staticmethod - def get_cookies(): - cookies_path = os.path.join(ModelScopeConfig.path_credential, - ModelScopeConfig.COOKIES_FILE_NAME) - if os.path.exists(cookies_path): - with open(cookies_path, 'rb') as f: - cookies = pickle.load(f) - return cookies - return None - - - -def modelscope_http_get_model_file( - url: str, - local_dir: str, - file_name: str, - file_size: int, - cookies: CookieJar, - headers: Optional[Dict[str, str]] = None, -): - """Download remote file, will retry 5 times before giving up on errors. - - Args: - url(str): - actual download url of the file - local_dir(str): - local directory where the downloaded file stores - file_name(str): - name of the file stored in `local_dir` - file_size(int): - The file size. 
- cookies(CookieJar): - cookies used to authentication the user, which is used for downloading private repos - headers(Dict[str, str], optional): - http headers to carry necessary info when requesting the remote file - - Raises: - FileDownloadError: File download failed. - - """ - get_headers = {} if headers is None else copy.deepcopy(headers) - get_headers['X-Request-ID'] = str(uuid.uuid4().hex) - temp_file_path = os.path.join(local_dir, file_name) - # retry sleep 0.5s, 1s, 2s, 4s - retry = Retry( - total=5, - backoff_factor=1, - allowed_methods=['GET']) - while True: - try: - progress = tqdm( - unit='B', - unit_scale=True, - unit_divisor=1024, - total=file_size, - initial=0, - desc='Downloading', - ) - partial_length = 0 - if os.path.exists( - temp_file_path): # download partial, continue download - with open(temp_file_path, 'rb') as f: - partial_length = f.seek(0, io.SEEK_END) - progress.update(partial_length) - if partial_length > file_size: - break - get_headers['Range'] = 'bytes=%s-%s' % (partial_length, - file_size - 1) - with open(temp_file_path, 'ab') as f: - r = requests.get( - url, - stream=True, - headers=get_headers, - cookies=cookies, - timeout=60) - r.raise_for_status() - for chunk in r.iter_content( - chunk_size=1024 * 1024 * 1): - if chunk: # filter out keep-alive new chunks - progress.update(len(chunk)) - f.write(chunk) - progress.close() - break - except (Exception) as e: # no matter what happen, we will retry. - retry = retry.increment('GET', url, error=e) - retry.sleep() - - -def get_endpoint(): - MODELSCOPE_URL_SCHEME = 'https://' - DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn' - modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', - DEFAULT_MODELSCOPE_DOMAIN) - return MODELSCOPE_URL_SCHEME + modelscope_domain - - -def get_file_download_url(model_id: str, file_path: str, revision: str): - """Format file download url according to `model_id`, `revision` and `file_path`. 
- e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`, - the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md - - Args: - model_id (str): The model_id. - file_path (str): File path - revision (str): File revision. - - Returns: - str: The file url. - """ - file_path = urllib.parse.quote_plus(file_path) - revision = urllib.parse.quote_plus(revision) - download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}' - return download_url_template.format( - endpoint=get_endpoint(), - model_id=model_id, - revision=revision, - file_path=file_path, - ) +from modelscope import snapshot_download +import os, shutil def download_from_modelscope(model_id, origin_file_path, local_dir): @@ -255,17 +10,12 @@ def download_from_modelscope(model_id, origin_file_path, local_dir): return else: print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}") - headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)} - cookies = ModelScopeConfig.get_cookies() - url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master") - modelscope_http_get_model_file( - url, - local_dir, - os.path.basename(origin_file_path), - file_size=0, - headers=headers, - cookies=cookies - ) + snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir) + downloaded_file_path = os.path.join(local_dir, origin_file_path) + target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1]) + if downloaded_file_path != target_file_path: + shutil.move(downloaded_file_path, target_file_path) + shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0])) def download_from_huggingface(model_id, origin_file_path, local_dir): diff --git a/diffsynth/models/kolors_text_encoder.py b/diffsynth/models/kolors_text_encoder.py new file mode 100644 index 0000000..3dd3213 --- 
/dev/null +++ b/diffsynth/models/kolors_text_encoder.py @@ -0,0 +1,1363 @@ +""" +This model is copied from https://github.com/Kwai-Kolors/Kolors/tree/master/kolors/models. +We didn't modify this model. +The tensor operation is performed in the prompter. +""" + + +""" PyTorch ChatGLM model. """ + +import math +import copy +import warnings +import re +import sys + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss, LayerNorm +from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss +from torch.nn.utils import skip_init +from typing import Optional, Tuple, Union, List, Callable, Dict, Any +from copy import deepcopy + +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import logging +from transformers.generation.logits_process import LogitsProcessor +from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput +from transformers import PretrainedConfig + + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + classifier_dropout=None, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + 
self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.classifier_dropout = classifier_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + super().__init__(**kwargs) + + + +# flags required to enable jit fusion kernels + +if sys.platform != 'darwin': + torch._C._jit_set_profiling_mode(False) + torch._C._jit_set_profiling_executor(False) + torch._C._jit_override_can_fuse_on_cpu(True) + torch._C._jit_override_can_fuse_on_gpu(True) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM" +_CONFIG_FOR_DOC = "ChatGLM6BConfig" + +CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "THUDM/chatglm3-6b-base", + # See all ChatGLM models at https://huggingface.co/models?filter=chatglm +] + + +def default_init(cls, *args, **kwargs): + return cls(*args, **kwargs) + + +class InvalidScoreLogitsProcessor(LogitsProcessor): + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + if torch.isnan(scores).any() or 
torch.isinf(scores).any(): + scores.zero_() + scores[..., 5] = 5e4 + return scores + + +class PrefixEncoder(torch.nn.Module): + """ + The torch.nn model to encode the prefix + Input shape: (batch-size, prefix-length) + Output shape: (batch-size, prefix-length, 2*layers*hidden) + """ + + def __init__(self, config: ChatGLMConfig): + super().__init__() + self.prefix_projection = config.prefix_projection + if self.prefix_projection: + # Use a two-layer MLP to encode the prefix + kv_size = config.num_layers * config.kv_channels * config.multi_query_group_num * 2 + self.embedding = torch.nn.Embedding(config.pre_seq_len, kv_size) + self.trans = torch.nn.Sequential( + torch.nn.Linear(kv_size, config.hidden_size), + torch.nn.Tanh(), + torch.nn.Linear(config.hidden_size, kv_size) + ) + else: + self.embedding = torch.nn.Embedding(config.pre_seq_len, + config.num_layers * config.kv_channels * config.multi_query_group_num * 2) + + def forward(self, prefix: torch.Tensor): + if self.prefix_projection: + prefix_tokens = self.embedding(prefix) + past_key_values = self.trans(prefix_tokens) + else: + past_key_values = self.embedding(prefix) + return past_key_values + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> List[torch.Tensor]: + """Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = tensor.size()[last_dim] // num_partitions + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # Note: torch.split does not create contiguous tensors by default. 
+ if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list + + +class RotaryEmbedding(nn.Module): + def __init__(self, dim, original_impl=False, device=None, dtype=None): + super().__init__() + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim)) + self.register_buffer("inv_freq", inv_freq) + self.dim = dim + self.original_impl = original_impl + + def forward_impl( + self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 + ): + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=torch.float, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half() + return cache + + def forward(self, max_seq_len, offset=0): + return self.forward_impl( + max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device + ) + + +@torch.jit.script +def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor: + # x: [sq, b, np, hn] + sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3) + rot_dim = rope_cache.shape[-2] * 2 + 
x, x_pass = x[..., :rot_dim], x[..., rot_dim:] + # truncate to support variable sizes + rope_cache = rope_cache[:sq] + xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2) + rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + x_out2 = x_out2.flatten(3) + return torch.cat((x_out2, x_pass), dim=-1) + + +class RMSNorm(torch.nn.Module): + def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs): + super().__init__() + self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype)) + self.eps = eps + + def forward(self, hidden_states: torch.Tensor): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + return (self.weight * hidden_states).to(input_dtype) + + +class CoreAttention(torch.nn.Module): + def __init__(self, config: ChatGLMConfig, layer_number): + super(CoreAttention, self).__init__() + + self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 + if self.apply_query_key_layer_scaling: + self.attention_softmax_in_fp32 = True + self.layer_number = max(1, layer_number) + + projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. 
+ self.hidden_size_per_partition = projection_size + self.hidden_size_per_attention_head = projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + coeff = None + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + if self.apply_query_key_layer_scaling: + coeff = self.layer_number + self.norm_factor *= coeff + self.coeff = coeff + + self.attention_dropout = torch.nn.Dropout(config.attention_dropout) + + def forward(self, query_layer, key_layer, value_layer, attention_mask): + pytorch_major_version = int(torch.__version__.split('.')[0]) + if pytorch_major_version >= 2: + query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]] + if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]: + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + is_causal=True) + else: + if attention_mask is not None: + attention_mask = ~attention_mask + context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, + attention_mask) + context_layer = context_layer.permute(2, 0, 1, 3) + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.reshape(*new_context_layer_shape) + else: + # Raw attention scores + + # [b, np, sq, sk] + output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) + + # preallocting input tensor: [b * np, sq, sk] + matmul_input_buffer = torch.empty( + output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype, + device=query_layer.device + ) + + # Raw 
attention scores. [b * np, sq, sk] + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer.transpose(0, 1), # [b * np, sq, hn] + key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] + beta=0.0, + alpha=(1.0 / self.norm_factor), + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(*output_size) + + # =========================== + # Attention probs and dropout + # =========================== + + # attention scores and attention mask [b, np, sq, sk] + if self.attention_softmax_in_fp32: + attention_scores = attention_scores.float() + if self.coeff is not None: + attention_scores = attention_scores * self.coeff + if attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]: + attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3], + device=attention_scores.device, dtype=torch.bool) + attention_mask.tril_() + attention_mask = ~attention_mask + if attention_mask is not None: + attention_scores = attention_scores.masked_fill(attention_mask, float("-inf")) + attention_probs = F.softmax(attention_scores, dim=-1) + attention_probs = attention_probs.type_as(value_layer) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.attention_dropout(attention_probs) + # ========================= + # Context layer. [sq, b, hp] + # ========================= + + # value_layer -> context layer. 
+ # [sk, b, np, hn] --> [b, np, sq, hn] + + # context layer shape: [b, np, sq, hn] + output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) + # change view [sk, b * np, hn] + value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) + # change view [b * np, sq, sk] + attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) + # change view [b, np, sq, hn] + context_layer = context_layer.view(*output_size) + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + return context_layer + + +class SelfAttention(torch.nn.Module): + """Parallel self-attention layer abstract class. + + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(SelfAttention, self).__init__() + self.layer_number = max(1, layer_number) + + self.projection_size = config.kv_channels * config.num_attention_heads + + # Per attention head and per partition values. 
+ self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads + self.num_attention_heads_per_partition = config.num_attention_heads + + self.multi_query_attention = config.multi_query_attention + self.qkv_hidden_size = 3 * self.projection_size + if self.multi_query_attention: + self.num_multi_query_groups_per_partition = config.multi_query_group_num + self.qkv_hidden_size = ( + self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num + ) + self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size, + bias=config.add_bias_linear or config.add_qkv_bias, + device=device, **_config_to_kwargs(config) + ) + + self.core_attention = CoreAttention(config, self.layer_number) + + # Output. + self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, + device=device, **_config_to_kwargs(config) + ) + + def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None): + if self.multi_query_attention: + num_attention_heads = self.num_multi_query_groups_per_partition + else: + num_attention_heads = self.num_attention_heads_per_partition + return torch.empty( + inference_max_sequence_len, + batch_size, + num_attention_heads, + self.hidden_size_per_attention_head, + dtype=dtype, + device=device, + ) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True + ): + # hidden_states: [sq, b, h] + + # ================================================= + # Pre-allocate memory for key-values for inference. 
+ # ================================================= + # ===================== + # Query, Key, and Value + # ===================== + + # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)] + mixed_x_layer = self.query_key_value(hidden_states) + + if self.multi_query_attention: + (query_layer, key_layer, value_layer) = mixed_x_layer.split( + [ + self.num_attention_heads_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head, + ], + dim=-1, + ) + query_layer = query_layer.view( + query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + key_layer = key_layer.view( + key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.view( + value_layer.size()[:-1] + + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head) + ) + else: + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3) + + # apply relative positional encoding (rotary embedding) + if rotary_pos_emb is not None: + query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb) + key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb) + + # adjust key and value for inference + if kv_cache is not None: + cache_k, cache_v = kv_cache + key_layer = torch.cat((cache_k, key_layer), dim=0) + value_layer = torch.cat((cache_v, value_layer), dim=0) + if use_cache: + kv_cache = (key_layer, value_layer) + else: + kv_cache = None + + if self.multi_query_attention: + key_layer = key_layer.unsqueeze(-2) + key_layer = key_layer.expand( + 
-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + key_layer = key_layer.contiguous().view( + key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + value_layer = value_layer.unsqueeze(-2) + value_layer = value_layer.expand( + -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1 + ) + value_layer = value_layer.contiguous().view( + value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head) + ) + + # ================================== + # core attention computation + # ================================== + + context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) + + # ================= + # Output. [sq, b, h] + # ================= + + output = self.dense(context_layer) + + return output, kv_cache + + +def _config_to_kwargs(args): + common_kwargs = { + "dtype": args.torch_dtype, + } + return common_kwargs + + +class MLP(torch.nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__(self, config: ChatGLMConfig, device=None): + super(MLP, self).__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf + self.dense_h_to_4h = nn.Linear( + config.hidden_size, + config.ffn_hidden_size * 2, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def swiglu(x): + x = torch.chunk(x, 2, dim=-1) + return F.silu(x[0]) * x[1] + + self.activation_func = swiglu + + # Project back to h. 
+ self.dense_4h_to_h = nn.Linear( + config.ffn_hidden_size, + config.hidden_size, + bias=self.add_bias, + device=device, + **_config_to_kwargs(config) + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(torch.nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__(self, config: ChatGLMConfig, layer_number, device=None): + super(GLMBlock, self).__init__() + self.layer_number = layer_number + + self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm + + self.fp32_residual_connection = config.fp32_residual_connection + + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # Self attention. + self.self_attention = SelfAttention(config, layer_number, device=device) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + # MLP + self.mlp = MLP(config, device=device) + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True, + ): + # hidden_states: [s, b, h] + + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output, kv_cache = self.self_attention( + layernorm_output, + attention_mask, + rotary_pos_emb, + kv_cache=kv_cache, + use_cache=use_cache + ) + + # Residual connection. 
+ if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training) + layernorm_input = residual + layernorm_input + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # MLP. + mlp_output = self.mlp(layernorm_output) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training) + output = residual + output + + return output, kv_cache + + +class GLMTransformer(torch.nn.Module): + """Transformer class.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(GLMTransformer, self).__init__() + + self.fp32_residual_connection = config.fp32_residual_connection + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + def build_layer(layer_number): + return GLMBlock(config, layer_number, device=device) + + self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)]) + + if self.post_layer_norm: + LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. 
+ self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon, device=device, + dtype=config.torch_dtype) + + self.gradient_checkpointing = False + + def _get_layer(self, layer_number): + return self.layers[layer_number] + + def forward( + self, hidden_states, attention_mask, rotary_pos_emb, kv_caches=None, + use_cache: Optional[bool] = True, + output_hidden_states: Optional[bool] = False, + ): + if not kv_caches: + kv_caches = [None for _ in range(self.num_layers)] + presents = () if use_cache else None + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_self_attentions = None + all_hidden_states = () if output_hidden_states else None + for index in range(self.num_layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer = self._get_layer(index) + if self.gradient_checkpointing and self.training: + layer_ret = torch.utils.checkpoint.checkpoint( + layer, + hidden_states, + attention_mask, + rotary_pos_emb, + kv_caches[index], + use_cache + ) + else: + layer_ret = layer( + hidden_states, + attention_mask, + rotary_pos_emb, + kv_cache=kv_caches[index], + use_cache=use_cache + ) + hidden_states, kv_cache = layer_ret + if use_cache: + presents = presents + (kv_cache,) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states, presents, all_hidden_states, all_self_attentions + + +class ChatGLMPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and + a simple interface for downloading and loading pretrained models. 
+ """ + + is_parallelizable = False + supports_gradient_checkpointing = True + config_class = ChatGLMConfig + base_model_prefix = "transformer" + _no_split_modules = ["GLMBlock"] + + def _init_weights(self, module: nn.Module): + """Initialize the weights.""" + return + + def get_masks(self, input_ids, past_key_values, padding_mask=None): + batch_size, seq_length = input_ids.shape + full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device) + full_attention_mask.tril_() + past_length = 0 + if past_key_values: + past_length = past_key_values[0][0].shape[0] + if past_length: + full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length, + device=input_ids.device), full_attention_mask), dim=-1) + if padding_mask is not None: + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + if not past_length and padding_mask is not None: + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + def get_position_ids(self, input_ids, device): + batch_size, seq_length = input_ids.shape + position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1) + return position_ids + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, GLMTransformer): + module.gradient_checkpointing = value + + +class Embedding(torch.nn.Module): + """Language model embeddings.""" + + def __init__(self, config: ChatGLMConfig, device=None): + super(Embedding, self).__init__() + + self.hidden_size = config.hidden_size + # Word embeddings (parallel). + self.word_embeddings = nn.Embedding( + config.padded_vocab_size, + self.hidden_size, + dtype=config.torch_dtype, + device=device + ) + self.fp32_residual_connection = config.fp32_residual_connection + + def forward(self, input_ids): + # Embeddings. 
+ words_embeddings = self.word_embeddings(input_ids) + embeddings = words_embeddings + # Data format change to avoid explicit tranposes : [b s h] --> [s b h]. + embeddings = embeddings.transpose(0, 1).contiguous() + # If the input flag for fp32 residual connection is set, convert for float. + if self.fp32_residual_connection: + embeddings = embeddings.float() + return embeddings + + +class ChatGLMModel(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, device=None, empty_init=True): + super().__init__(config) + if empty_init: + init_method = skip_init + else: + init_method = default_init + init_kwargs = {} + if device is not None: + init_kwargs["device"] = device + self.embedding = init_method(Embedding, config, **init_kwargs) + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + + # Rotary positional embeddings + self.seq_length = config.seq_length + rotary_dim = ( + config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels + ) + + self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device, + dtype=config.torch_dtype) + self.encoder = init_method(GLMTransformer, config, **init_kwargs) + self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False, + dtype=config.torch_dtype, **init_kwargs) + self.pre_seq_len = config.pre_seq_len + self.prefix_projection = config.prefix_projection + if self.pre_seq_len is not None: + for param in self.parameters(): + param.requires_grad = False + self.prefix_tokens = torch.arange(self.pre_seq_len).long() + self.prefix_encoder = PrefixEncoder(config) + self.dropout = torch.nn.Dropout(0.1) + + def get_input_embeddings(self): + return self.embedding.word_embeddings + + def get_prompt(self, batch_size, device, dtype=torch.half): + prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, 
-1).to(device) + past_key_values = self.prefix_encoder(prefix_tokens).type(dtype) + past_key_values = past_key_values.view( + batch_size, + self.pre_seq_len, + self.num_layers * 2, + self.multi_query_group_num, + self.kv_channels + ) + # seq_len, b, nh, hidden_size + past_key_values = self.dropout(past_key_values) + past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2) + return past_key_values + + def forward( + self, + input_ids, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.BoolTensor] = None, + full_attention_mask: Optional[torch.BoolTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, seq_length = input_ids.shape + + if inputs_embeds is None: + inputs_embeds = self.embedding(input_ids) + + if self.pre_seq_len is not None: + if past_key_values is None: + past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device, + dtype=inputs_embeds.dtype) + if attention_mask is not None: + attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)), + attention_mask], dim=-1) + + if full_attention_mask is None: + if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1): + full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask) + + # Rotary positional embeddings + rotary_pos_emb = self.rotary_pos_emb(self.seq_length) + if position_ids is not None: + rotary_pos_emb = 
rotary_pos_emb[position_ids] + else: + rotary_pos_emb = rotary_pos_emb[None, :seq_length] + rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous() + + # Run encoder. + hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder( + inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb, + kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states + ) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + def quantize(self, weight_bit_width: int): + from .quantization import quantize + quantize(self.encoder, weight_bit_width) + return self + + +class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.max_sequence_length = config.max_length + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + self.config = config + self.quantized = False + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def _update_model_kwargs_for_generation( + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update attention mask + if "attention_mask" in model_kwargs: + attention_mask = model_kwargs["attention_mask"] + model_kwargs["attention_mask"] = torch.cat( + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + ) + + # update position ids + if "position_ids" in 
model_kwargs: + position_ids = model_kwargs["position_ids"] + new_position_id = position_ids[..., -1:].clone() + new_position_id += 1 + model_kwargs["position_ids"] = torch.cat( + [position_ids, new_position_id], dim=-1 + ) + + model_kwargs["is_first_forward"] = False + return model_kwargs + + def prepare_inputs_for_generation( + self, + input_ids: torch.LongTensor, + past_key_values: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + is_first_forward: bool = True, + **kwargs + ) -> dict: + # only last token for input_ids if past is not None + if position_ids is None: + position_ids = self.get_position_ids(input_ids, device=input_ids.device) + if not is_first_forward: + if past_key_values is not None: + position_ids = position_ids[..., -1:] + input_ids = input_ids[:, -1:] + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "position_ids": position_ids, + "attention_mask": attention_mask, + "return_last_logit": True, + "use_cache": use_cache + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + return_last_logit: Optional[bool] = False, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + 
use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + if return_last_logit: + hidden_states = hidden_states[-1:] + lm_logits = self.transformer.output_layer(hidden_states) + lm_logits = lm_logits.transpose(0, 1).contiguous() + + loss = None + if labels is not None: + lm_logits = lm_logits.to(torch.float32) + + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + lm_logits = lm_logits.to(hidden_states.dtype) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor + ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + + Output shares the same memory storage as `past`. 
+ """ + return tuple( + ( + layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)), + layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)), + ) + for layer_past in past + ) + + def process_response(self, output, history): + content = "" + history = deepcopy(history) + for response in output.split("<|assistant|>"): + metadata, content = response.split("\n", maxsplit=1) + if not metadata.strip(): + content = content.strip() + history.append({"role": "assistant", "metadata": metadata, "content": content}) + content = content.replace("[[训练时间]]", "2023年") + else: + history.append({"role": "assistant", "metadata": metadata, "content": content}) + if history[0]["role"] == "system" and "tools" in history[0]: + content = "\n".join(content.split("\n")[1:-1]) + def tool_call(**kwargs): + return kwargs + parameters = eval(content) + content = {"name": metadata.strip(), "parameters": parameters} + else: + content = {"name": metadata.strip(), "content": content} + return content, history + + @torch.inference_mode() + def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", + max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, + **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + inputs = tokenizer.build_chat_input(query, history=history, role=role) + inputs = inputs.to(self.device) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = 
tokenizer.decode(outputs) + history.append({"role": role, "content": query}) + response, history = self.process_response(response, history) + return response, history + + @torch.inference_mode() + def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, role: str = "user", + past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8, + logits_processor=None, return_past_key_values=False, **kwargs): + if history is None: + history = [] + if logits_processor is None: + logits_processor = LogitsProcessorList() + logits_processor.append(InvalidScoreLogitsProcessor()) + eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), + tokenizer.get_command("<|observation|>")] + gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p, + "temperature": temperature, "logits_processor": logits_processor, **kwargs} + if past_key_values is None: + inputs = tokenizer.build_chat_input(query, history=history, role=role) + else: + inputs = tokenizer.build_chat_input(query, role=role) + inputs = inputs.to(self.device) + if past_key_values is not None: + past_length = past_key_values[0][0].shape[0] + if self.transformer.pre_seq_len is not None: + past_length -= self.transformer.pre_seq_len + inputs.position_ids += past_length + attention_mask = inputs.attention_mask + attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) + inputs['attention_mask'] = attention_mask + history.append({"role": role, "content": query}) + for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, + eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, + **gen_kwargs): + if return_past_key_values: + outputs, past_key_values = outputs + outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] + response = tokenizer.decode(outputs) + if response and response[-1] != "�": + response, new_history = self.process_response(response, history) 
+ if return_past_key_values: + yield response, new_history, past_key_values + else: + yield response, new_history + + @torch.inference_mode() + def stream_generate( + self, + input_ids, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, + return_past_key_values=False, + **kwargs, + ): + batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] + + if generation_config is None: + generation_config = self.generation_config + generation_config = copy.deepcopy(generation_config) + model_kwargs = generation_config.update(**kwargs) + model_kwargs["use_cache"] = generation_config.use_cache + bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + if has_default_max_length and generation_config.max_new_tokens is None: + warnings.warn( + f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " + "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" + " recommend using `max_new_tokens` to control the maximum length of the generation.", + UserWarning, + ) + elif generation_config.max_new_tokens is not None: + generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length + if not has_default_max_length: + logger.warn( + f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" + f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. 
" + "Please refer to the documentation for more information. " + "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", + UserWarning, + ) + + if input_ids_seq_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + logger.warning( + f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`." + ) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_seq_length, + encoder_input_ids=input_ids, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + logits_processor=logits_processor, + ) + + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria + ) + logits_warper = self._get_logits_warper(generation_config) + + unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) + scores = None + while True: + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=False, + output_hidden_states=False, + ) + + next_token_logits = outputs.logits[:, -1, :] + + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + + # sample + probs = nn.functional.softmax(next_token_scores, dim=-1) + if generation_config.do_sample: + next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) + 
else: + next_tokens = torch.argmax(probs, dim=-1) + # update generated ids, model inputs, and length for next step + input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) + ) + if return_past_key_values: + yield input_ids, outputs.past_key_values + else: + yield input_ids + # stop when each sentence is finished, or if we exceed the maximum length + if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores): + break + + def quantize(self, bits: int, empty_init=False, device=None, **kwargs): + if bits == 0: + return + + from .quantization import quantize + + if self.quantized: + logger.info("Already quantized.") + return self + + self.quantized = True + + self.config.quantization_bit = bits + + self.transformer.encoder = quantize(self.transformer.encoder, bits, empty_init=empty_init, device=device, + **kwargs) + return self + + +class ChatGLMForSequenceClassification(ChatGLMPreTrainedModel): + def __init__(self, config: ChatGLMConfig, empty_init=True, device=None): + super().__init__(config) + + self.num_labels = config.num_labels + self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device) + + self.classifier_head = nn.Linear(config.hidden_size, config.num_labels, bias=True, dtype=torch.half) + if config.classifier_dropout is not None: + self.dropout = nn.Dropout(config.classifier_dropout) + else: + self.dropout = None + self.config = config + + if self.config.quantization_bit: + self.quantize(self.config.quantization_bit, empty_init=True) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + full_attention_mask: 
Optional[torch.Tensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + inputs_embeds: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor, ...], SequenceClassifierOutputWithPast]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + full_attention_mask=full_attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + pooled_hidden_states = hidden_states[-1] + if self.dropout is not None: + pooled_hidden_states = self.dropout(pooled_hidden_states) + logits = self.classifier_head(pooled_hidden_states) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(logits.squeeze().float(), labels.squeeze()) + else: + loss = loss_fct(logits.float(), labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels).float(), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits.float(), labels.view(-1, self.num_labels)) + + if not 
return_dict: + output = (logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/diffsynth/models/sdxl_unet.py b/diffsynth/models/sdxl_unet.py index a336259..b70a8df 100644 --- a/diffsynth/models/sdxl_unet.py +++ b/diffsynth/models/sdxl_unet.py @@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock class SDXLUNet(torch.nn.Module): - def __init__(self): + def __init__(self, is_kolors=False): super().__init__() self.time_proj = Timesteps(320) self.time_embedding = torch.nn.Sequential( @@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module): ) self.add_time_proj = Timesteps(256) self.add_time_embedding = torch.nn.Sequential( - torch.nn.Linear(2816, 1280), + torch.nn.Linear(5632 if is_kolors else 2816, 1280), torch.nn.SiLU(), torch.nn.Linear(1280, 1280) ) self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1) + self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None self.blocks = torch.nn.ModuleList([ # DownBlock2D @@ -85,7 +86,9 @@ class SDXLUNet(torch.nn.Module): def forward( self, sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds, - tiled=False, tile_size=64, tile_stride=8, **kwargs + tiled=False, tile_size=64, tile_stride=8, + use_gradient_checkpointing=False, + **kwargs ): # 1. time t_emb = self.time_proj(timestep[None]).to(sample.dtype) @@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module): # 2. pre-process height, width = sample.shape[2], sample.shape[3] hidden_states = self.conv_in(sample) - text_emb = encoder_hidden_states + text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states) res_stack = [hidden_states] # 3. 
blocks + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + return custom_forward for i, block in enumerate(self.blocks): - hidden_states, time_emb, text_emb, res_stack = block( - hidden_states, time_emb, text_emb, res_stack, - tiled=tiled, tile_size=tile_size, tile_stride=tile_stride - ) + if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)): + hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, time_emb, text_emb, res_stack, + use_reentrant=False, + ) + else: + hidden_states, time_emb, text_emb, res_stack = block( + hidden_states, time_emb, text_emb, res_stack, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride + ) # 4. output hidden_states = self.conv_norm_out(hidden_states) @@ -148,6 +162,8 @@ class SDXLUNetStateDictConverter: names = name.split(".") if names[0] in ["conv_in", "conv_norm_out", "conv_out"]: pass + elif names[0] in ["encoder_hid_proj"]: + names[0] = "text_intermediate_proj" elif names[0] in ["time_embedding", "add_embedding"]: if names[0] == "add_embedding": names[0] = "add_time_embedding" diff --git a/diffsynth/pipelines/__init__.py b/diffsynth/pipelines/__init__.py index e2b61ad..c475288 100644 --- a/diffsynth/pipelines/__init__.py +++ b/diffsynth/pipelines/__init__.py @@ -5,3 +5,4 @@ from .stable_diffusion_xl_video import SDXLVideoPipeline from .stable_video_diffusion import SVDVideoPipeline from .hunyuan_dit import HunyuanDiTImagePipeline from .stable_diffusion_3 import SD3ImagePipeline +from .kwai_kolors import KolorsImagePipeline diff --git a/diffsynth/pipelines/dancer.py b/diffsynth/pipelines/dancer.py index d01a7f3..986f4ca 100644 --- a/diffsynth/pipelines/dancer.py +++ b/diffsynth/pipelines/dancer.py @@ -147,7 +147,7 @@ def lets_dance_xl( # 3. 
pre-process height, width = sample.shape[2], sample.shape[3] hidden_states = unet.conv_in(sample) - text_emb = encoder_hidden_states + text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states) res_stack = [hidden_states] # 4. blocks diff --git a/diffsynth/pipelines/hunyuan_dit.py b/diffsynth/pipelines/hunyuan_dit.py index 6d592e4..1076727 100644 --- a/diffsynth/pipelines/hunyuan_dit.py +++ b/diffsynth/pipelines/hunyuan_dit.py @@ -216,7 +216,7 @@ class HunyuanDiTImagePipeline(torch.nn.Module): # Prepare latent tensors noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) if input_image is not None: - image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32) latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) else: @@ -293,6 +293,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module): progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) # Decode image - image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) return image diff --git a/diffsynth/pipelines/kwai_kolors.py b/diffsynth/pipelines/kwai_kolors.py new file mode 100644 index 0000000..10cdd69 --- /dev/null +++ b/diffsynth/pipelines/kwai_kolors.py @@ -0,0 +1,168 @@ +from ..models import ModelManager, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder +from ..models.kolors_text_encoder import ChatGLMModel +from ..prompts import KolorsPrompter +from ..schedulers import EnhancedDDIMScheduler +from .dancer import lets_dance_xl +import torch +from 
tqdm import tqdm +from PIL import Image +import numpy as np + + +class KolorsImagePipeline(torch.nn.Module): + + def __init__(self, device="cuda", torch_dtype=torch.float16): + super().__init__() + self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100) + self.prompter = KolorsPrompter() + self.device = device + self.torch_dtype = torch_dtype + # models + self.text_encoder: ChatGLMModel = None + self.unet: SDXLUNet = None + self.vae_decoder: SDXLVAEDecoder = None + self.vae_encoder: SDXLVAEEncoder = None + self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None + self.ipadapter: SDXLIpAdapter = None + + + def fetch_main_models(self, model_manager: ModelManager): + self.text_encoder = model_manager.kolors_text_encoder + self.unet = model_manager.kolors_unet + self.vae_decoder = model_manager.vae_decoder + self.vae_encoder = model_manager.vae_encoder + + + def fetch_ipadapter(self, model_manager: ModelManager): + if "ipadapter_xl" in model_manager.model: + self.ipadapter = model_manager.ipadapter_xl + if "ipadapter_xl_image_encoder" in model_manager.model: + self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder + + + def fetch_prompter(self, model_manager: ModelManager): + self.prompter.load_from_model_manager(model_manager) + + + @staticmethod + def from_model_manager(model_manager: ModelManager): + pipe = KolorsImagePipeline( + device=model_manager.device, + torch_dtype=model_manager.torch_dtype, + ) + pipe.fetch_main_models(model_manager) + pipe.fetch_prompter(model_manager) + pipe.fetch_ipadapter(model_manager) + return pipe + + + def preprocess_image(self, image): + image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0) + return image + + + def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32): + image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0] + image = image.cpu().permute(1, 2, 
0).numpy() + image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8")) + return image + + + @torch.no_grad() + def __call__( + self, + prompt, + negative_prompt="", + cfg_scale=7.5, + clip_skip=2, + input_image=None, + ipadapter_images=None, + ipadapter_scale=1.0, + ipadapter_use_instant_style=False, + denoising_strength=1.0, + height=1024, + width=1024, + num_inference_steps=20, + tiled=False, + tile_size=64, + tile_stride=32, + progress_bar_cmd=tqdm, + progress_bar_st=None, + ): + # Prepare scheduler + self.scheduler.set_timesteps(num_inference_steps, denoising_strength) + + # Prepare latent tensors + if input_image is not None: + image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype) + latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype) + noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0]) + else: + latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype) + + # Encode prompts + add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt( + self.text_encoder, + prompt, + clip_skip=clip_skip, + device=self.device, + positive=True, + ) + if cfg_scale != 1.0: + add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt( + self.text_encoder, + negative_prompt, + clip_skip=clip_skip, + device=self.device, + positive=False, + ) + + # Prepare positional id + add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) + + # IP-Adapter + if ipadapter_images is not None: + if ipadapter_use_instant_style: + self.ipadapter.set_less_adapter() + else: + self.ipadapter.set_full_adapter() + ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images) + ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, 
scale=ipadapter_scale) + ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding)) + else: + ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {} + + # Denoise + for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)): + timestep = torch.IntTensor((timestep,))[0].to(self.device) + + # Classifier-free guidance + noise_pred_posi = lets_dance_xl( + self.unet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list_posi, + ) + if cfg_scale != 1.0: + noise_pred_nega = lets_dance_xl( + self.unet, + sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega, + add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega, + tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, + ipadapter_kwargs_list=ipadapter_kwargs_list_nega, + ) + noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega) + else: + noise_pred = noise_pred_posi + + latents = self.scheduler.step(noise_pred, timestep, latents) + + if progress_bar_st is not None: + progress_bar_st.progress(progress_id / len(self.scheduler.timesteps)) + + # Decode image + image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) + + return image diff --git a/diffsynth/prompts/__init__.py b/diffsynth/prompts/__init__.py index 5c15008..5309bb2 100644 --- a/diffsynth/prompts/__init__.py +++ b/diffsynth/prompts/__init__.py @@ -2,3 +2,4 @@ from .sd_prompter import SDPrompter from .sdxl_prompter import SDXLPrompter from .sd3_prompter import SD3Prompter from .hunyuan_dit_prompter import HunyuanDiTPrompter +from .kolors_prompter import KolorsPrompter diff --git a/diffsynth/prompts/kolors_prompter.py b/diffsynth/prompts/kolors_prompter.py new file mode 100644 index 0000000..51fdcc8 --- 
/dev/null +++ b/diffsynth/prompts/kolors_prompter.py @@ -0,0 +1,347 @@ +from .utils import Prompter +import json, os, re +from typing import List, Optional, Union, Dict +from sentencepiece import SentencePieceProcessor +from transformers import PreTrainedTokenizer +from transformers.utils import PaddingStrategy +from transformers.tokenization_utils_base import EncodedInput, BatchEncoding +from ..models.kolors_text_encoder import ChatGLMModel + + +class SPTokenizer: + def __init__(self, model_path: str): + # reload tokenizer + assert os.path.isfile(model_path), model_path + self.sp_model = SentencePieceProcessor(model_file=model_path) + + # BOS / EOS token IDs + self.n_words: int = self.sp_model.vocab_size() + self.bos_id: int = self.sp_model.bos_id() + self.eos_id: int = self.sp_model.eos_id() + self.pad_id: int = self.sp_model.unk_id() + assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() + + role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"] + special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens + self.special_tokens = {} + self.index_special_tokens = {} + for token in special_tokens: + self.special_tokens[token] = self.n_words + self.index_special_tokens[self.n_words] = token + self.n_words += 1 + self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens]) + + def tokenize(self, s: str, encode_special_tokens=False): + if encode_special_tokens: + last_index = 0 + t = [] + for match in re.finditer(self.role_special_token_expression, s): + if last_index < match.start(): + t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()])) + t.append(s[match.start():match.end()]) + last_index = match.end() + if last_index < len(s): + t.extend(self.sp_model.EncodeAsPieces(s[last_index:])) + return t + else: + return self.sp_model.EncodeAsPieces(s) + + def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]: + assert type(s) is 
str + t = self.sp_model.encode(s) + if bos: + t = [self.bos_id] + t + if eos: + t = t + [self.eos_id] + return t + + def decode(self, t: List[int]) -> str: + text, buffer = "", [] + for token in t: + if token in self.index_special_tokens: + if buffer: + text += self.sp_model.decode(buffer) + buffer = [] + text += self.index_special_tokens[token] + else: + buffer.append(token) + if buffer: + text += self.sp_model.decode(buffer) + return text + + def decode_tokens(self, tokens: List[str]) -> str: + text = self.sp_model.DecodePieces(tokens) + return text + + def convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + if token in self.special_tokens: + return self.special_tokens[token] + return self.sp_model.PieceToId(token) + + def convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + if index in self.index_special_tokens: + return self.index_special_tokens[index] + if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0: + return "" + return self.sp_model.IdToPiece(index) + + + +class ChatGLMTokenizer(PreTrainedTokenizer): + vocab_files_names = {"vocab_file": "tokenizer.model"} + + model_input_names = ["input_ids", "attention_mask", "position_ids"] + + def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False, + **kwargs): + self.name = "GLMTokenizer" + + self.vocab_file = vocab_file + self.tokenizer = SPTokenizer(vocab_file) + self.special_tokens = { + "<bos>": self.tokenizer.bos_id, + "<eos>": self.tokenizer.eos_id, + "<pad>": self.tokenizer.pad_id + } + self.encode_special_tokens = encode_special_tokens + super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, + encode_special_tokens=encode_special_tokens, + **kwargs) + + def get_command(self, token): + if token in self.special_tokens: + return self.special_tokens[token] + assert token in self.tokenizer.special_tokens, 
f"{token} is not a special token for {self.name}" + return self.tokenizer.special_tokens[token] + + @property + def unk_token(self) -> str: + return "<unk>" + + @property + def pad_token(self) -> str: + return "<unk>" + + @property + def pad_token_id(self): + return self.get_command("<pad>") + + @property + def eos_token(self) -> str: + return "</s>" + + @property + def eos_token_id(self): + return self.get_command("<eos>") + + @property + def vocab_size(self): + return self.tokenizer.n_words + + def get_vocab(self): + """ Returns vocab as a dict """ + vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text, **kwargs): + return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens) + + def _convert_token_to_id(self, token): + """ Converts a token (str) in an id using the vocab. """ + return self.tokenizer.convert_token_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + return self.tokenizer.convert_id_to_token(index) + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + return self.tokenizer.decode_tokens(tokens) + + def save_vocabulary(self, save_directory, filename_prefix=None): + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + filename_prefix (`str`, *optional*): + An optional prefix to add to the name of the saved files. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if os.path.isdir(save_directory): + vocab_file = os.path.join( + save_directory, self.vocab_files_names["vocab_file"] + ) + else: + vocab_file = save_directory + + with open(self.vocab_file, 'rb') as fin: + proto_str = fin.read() + + with open(vocab_file, "wb") as writer: + writer.write(proto_str) + + return (vocab_file,) + + def get_prefix_tokens(self): + prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")] + return prefix_tokens + + def build_single_message(self, role, metadata, message): + assert role in ["system", "user", "assistant", "observation"], role + role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n") + message_tokens = self.tokenizer.encode(message) + tokens = role_tokens + message_tokens + return tokens + + def build_chat_input(self, query, history=None, role="user"): + if history is None: + history = [] + input_ids = [] + for item in history: + content = item["content"] + if item["role"] == "system" and "tools" in item: + content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False) + input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content)) + input_ids.extend(self.build_single_message(role, "", query)) + input_ids.extend([self.get_command("<|assistant|>")]) + return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and + adding special tokens. A BERT sequence has the following format: + + - single sequence: `[CLS] X [SEP]` + - pair of sequences: `[CLS] A [SEP] B [SEP]` + + Args: + token_ids_0 (`List[int]`): + List of IDs to which the special tokens will be added. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. 
+ + Returns: + `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. + """ + prefix_tokens = self.get_prefix_tokens() + token_ids_0 = prefix_tokens + token_ids_0 + if token_ids_1 is not None: + token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")] + return token_ids_0 + + def _pad( + self, + encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding], + max_length: Optional[int] = None, + padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD, + pad_to_multiple_of: Optional[int] = None, + return_attention_mask: Optional[bool] = None, + ) -> dict: + """ + Pad encoded inputs (on left/right and up to predefined length or max length in the batch) + + Args: + encoded_inputs: + Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`). + max_length: maximum length of the returned list and optionally padding length (see below). + Will truncate by taking into account the special tokens. + padding_strategy: PaddingStrategy to use for padding. + + - PaddingStrategy.LONGEST Pad to the longest sequence in the batch + - PaddingStrategy.MAX_LENGTH: Pad to the max length (default) + - PaddingStrategy.DO_NOT_PAD: Do not pad + The tokenizer padding sides are defined in self.padding_side: + + - 'left': pads on the left of the sequences + - 'right': pads on the right of the sequences + pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value. + This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability + `>= 7.5` (Volta). 
+ return_attention_mask: + (optional) Set to False to avoid returning attention mask (default: set to model specifics) + """ + # Load from model defaults + assert self.padding_side == "left" + + required_input = encoded_inputs[self.model_input_names[0]] + seq_length = len(required_input) + + if padding_strategy == PaddingStrategy.LONGEST: + max_length = len(required_input) + + if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): + max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of + + needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length + + # Initialize attention mask if not present. + if "attention_mask" not in encoded_inputs: + encoded_inputs["attention_mask"] = [1] * seq_length + + if "position_ids" not in encoded_inputs: + encoded_inputs["position_ids"] = list(range(seq_length)) + + if needs_to_be_padded: + difference = max_length - len(required_input) + + if "attention_mask" in encoded_inputs: + encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"] + if "position_ids" in encoded_inputs: + encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"] + encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input + + return encoded_inputs + + + +class KolorsPrompter(Prompter): + def __init__( + self, + tokenizer_path=None + ): + if tokenizer_path is None: + base_path = os.path.dirname(os.path.dirname(__file__)) + tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer") + super().__init__() + self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path) + + + def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ).to(device) + output = text_encoder( 
+ input_ids=text_inputs['input_ids'] , + attention_mask=text_inputs['attention_mask'], + position_ids=text_inputs['position_ids'], + output_hidden_states=True + ) + prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone() + pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone() + return prompt_emb, pooled_prompt_emb + + + def encode_prompt( + self, + text_encoder: ChatGLMModel, + prompt, + clip_skip=2, + positive=True, + device="cuda" + ): + prompt = self.process_prompt(prompt, positive=positive) + prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, text_encoder, self.tokenizer, 256, clip_skip, device) + + return pooled_prompt_emb, prompt_emb diff --git a/diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model b/diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model new file mode 100644 index 0000000..c8336ad Binary files /dev/null and b/diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model differ diff --git a/diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json b/diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json new file mode 100644 index 0000000..f6f13c8 --- /dev/null +++ b/diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer_config.json @@ -0,0 +1,12 @@ +{ + "name_or_path": "THUDM/chatglm3-6b-base", + "remove_space": false, + "do_lower_case": false, + "tokenizer_class": "ChatGLMTokenizer", + "auto_map": { + "AutoTokenizer": [ + "tokenization_chatglm.ChatGLMTokenizer", + null + ] + } +} diff --git a/diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt b/diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt new file mode 100644 index 0000000..c8336ad Binary files /dev/null and b/diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt differ diff --git a/examples/image_synthesis/README.md b/examples/image_synthesis/README.md index 0aeab86..800952e 100644 --- a/examples/image_synthesis/README.md +++ b/examples/image_synthesis/README.md @@ -28,6 +28,16 @@ LoRA Training: 
[`../train/stable_diffusion_3/`](../train/stable_diffusion_3/) |-|-| |![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/1386c802-e580-4101-939d-f1596802df9d)| +### Example: Kolors + +Example script: [`kolors_text_to_image.py`](./kolors_text_to_image.py) + +LoRA Training: [`../train/kolors/`](../train/kolors/) + +|1024*1024|2048*2048| +|-|-| +|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/66bb7a75-fe31-44e5-90eb-d3140ee4686d)| + ### Example: Hunyuan-DiT Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py) diff --git a/examples/image_synthesis/kolors_text_to_image.py b/examples/image_synthesis/kolors_text_to_image.py new file mode 100644 index 0000000..b9cb583 --- /dev/null +++ b/examples/image_synthesis/kolors_text_to_image.py @@ -0,0 +1,34 @@ +from diffsynth import ModelManager, KolorsImagePipeline, download_models +import torch + +# Download models +# https://huggingface.co/Kwai-Kolors/Kolors +download_models(["Kolors"]) +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", + file_path_list=[ + "models/kolors/Kolors/text_encoder", + "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors", + "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors" + ]) +pipe = KolorsImagePipeline.from_model_manager(model_manager) + +prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下" +negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红" + +torch.manual_seed(7) +image = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_inference_steps=50, + cfg_scale=4, +) +image.save(f"image_1024.jpg") + +image = pipe( + prompt=prompt, + 
negative_prompt=negative_prompt, + input_image=image.resize((2048, 2048)), denoising_strength=0.4, height=2048, width=2048, + num_inference_steps=50, + cfg_scale=4, +) +image.save("image_2048.jpg") diff --git a/examples/train/kolors/README.md b/examples/train/kolors/README.md new file mode 100644 index 0000000..be50561 --- /dev/null +++ b/examples/train/kolors/README.md @@ -0,0 +1,175 @@ +# Kolors + +Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffusion XL. We provide training scripts here. + +## Download models + +The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors). + +``` +models/kolors/Kolors +├── text_encoder +│ ├── config.json +│ ├── pytorch_model-00001-of-00007.bin +│ ├── pytorch_model-00002-of-00007.bin +│ ├── pytorch_model-00003-of-00007.bin +│ ├── pytorch_model-00004-of-00007.bin +│ ├── pytorch_model-00005-of-00007.bin +│ ├── pytorch_model-00006-of-00007.bin +│ ├── pytorch_model-00007-of-00007.bin +│ └── pytorch_model.bin.index.json +├── unet +│ └── diffusion_pytorch_model.safetensors +└── vae + └── diffusion_pytorch_model.safetensors +``` + +You can use the following code to download these files: + +```python +from diffsynth import download_models + +download_models(["Kolors"]) +``` + +## Train + +### Install training dependency + +``` +pip install peft lightning pandas torchvision +``` + +### Prepare your dataset + +We provide an example dataset [here](https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/files). 
You need to manage the training images as follows: + +``` +data/dog/ +└── train + ├── 00.jpg + ├── 01.jpg + ├── 02.jpg + ├── 03.jpg + ├── 04.jpg + └── metadata.csv +``` + +`metadata.csv`: + +``` +file_name,text +00.jpg,一只小狗 +01.jpg,一只小狗 +02.jpg,一只小狗 +03.jpg,一只小狗 +04.jpg,一只小狗 +``` + +### Train a LoRA model + +We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project. + +The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.** + +``` +CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \ + --pretrained_path models/kolors/Kolors \ + --dataset_path data/dog \ + --output_path ./models \ + --max_epochs 10 \ + --center_crop \ + --use_gradient_checkpointing \ + --precision 32 +``` + +Optional arguments: +``` + -h, --help show this help message and exit + --pretrained_path PRETRAINED_PATH + Path to pretrained model. For example, `models/kolors/Kolors`. + --dataset_path DATASET_PATH + The path of the Dataset. + --output_path OUTPUT_PATH + Path to save the model. + --steps_per_epoch STEPS_PER_EPOCH + Number of steps per epoch. + --height HEIGHT Image height. + --width WIDTH Image width. + --center_crop Whether to center crop the input images to the resolution. If not set, the images will be randomly cropped. The images will be resized to the resolution first before cropping. + --random_flip Whether to randomly flip images horizontally + --batch_size BATCH_SIZE + Batch size (per device) for the training dataloader. + --dataloader_num_workers DATALOADER_NUM_WORKERS + Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. + --precision {32,16,16-mixed} + Training precision + --learning_rate LEARNING_RATE + Learning rate. 
+ --lora_rank LORA_RANK + The dimension of the LoRA update matrices. + --lora_alpha LORA_ALPHA + The weight of the LoRA update matrices. + --use_gradient_checkpointing + Whether to use gradient checkpointing. + --accumulate_grad_batches ACCUMULATE_GRAD_BATCHES + The number of batches in gradient accumulation. + --training_strategy {auto,deepspeed_stage_1,deepspeed_stage_2,deepspeed_stage_3} + Training strategy + --max_epochs MAX_EPOCHS + Number of epochs. +``` + +### Inference with your own LoRA model + +After training, you can use your own LoRA model to generate new images. Here are some examples. + +```python +from diffsynth import ModelManager, KolorsImagePipeline +from peft import LoraConfig, inject_adapter_in_model +import torch + + +def load_lora(model, lora_rank, lora_alpha, lora_path): + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights="gaussian", + target_modules=["to_q", "to_k", "to_v", "to_out"], + ) + model = inject_adapter_in_model(lora_config, model) + state_dict = torch.load(lora_path, map_location="cpu") + model.load_state_dict(state_dict, strict=False) + return model + + +# Load models +model_manager = ModelManager(torch_dtype=torch.float16, device="cuda", + file_path_list=[ + "models/kolors/Kolors/text_encoder", + "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors", + "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors" + ]) +pipe = KolorsImagePipeline.from_model_manager(model_manager) + +# Generate an image with lora +pipe.unet = load_lora( + pipe.unet, + lora_rank=4, lora_alpha=4.0, # The two parameters should be consistent with those in your training script. 
+ lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt" +) +torch.manual_seed(0) +image = pipe( + prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉", + negative_prompt="", + cfg_scale=4, + num_inference_steps=50, height=1024, width=1024, +) +image.save("image_with_lora.jpg") +``` + +Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉 + +|Without LoRA|With LoRA| +|-|-| +|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/9d79ed7a-e8cf-4d98-800a-f182809db318)|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/02f62323-6ee5-4788-97a1-549732dbe4f0)| diff --git a/examples/train/kolors/train_kolors_lora.py b/examples/train/kolors/train_kolors_lora.py new file mode 100644 index 0000000..0a0af9d --- /dev/null +++ b/examples/train/kolors/train_kolors_lora.py @@ -0,0 +1,293 @@ +from diffsynth import ModelManager, KolorsImagePipeline +from peft import LoraConfig, inject_adapter_in_model +from torchvision import transforms +from PIL import Image +import lightning as pl +import pandas as pd +import torch, os, argparse +os.environ["TOKENIZERS_PARALLELISM"] = "True" + + + +class TextImageDataset(torch.utils.data.Dataset): + def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False): + self.steps_per_epoch = steps_per_epoch + metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv")) + self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]] + self.text = metadata["text"].to_list() + self.image_processor = transforms.Compose( + [ + transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)), + transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + + def 
__getitem__(self, index): + data_id = torch.randint(0, len(self.path), (1,))[0] + data_id = (data_id + index) % len(self.path) # For fixed seed. + text = self.text[data_id] + image = Image.open(self.path[data_id]).convert("RGB") + image = self.image_processor(image) + return {"text": text, "image": image} + + + def __len__(self): + return self.steps_per_epoch + + + +class LightningModel(pl.LightningModule): + def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True): + super().__init__() + + # Load models + model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device) + model_manager.load_models(pretrained_weights) + self.pipe = KolorsImagePipeline.from_model_manager(model_manager) + + # Freeze parameters + self.pipe.text_encoder.requires_grad_(False) + self.pipe.unet.requires_grad_(False) + self.pipe.vae_decoder.requires_grad_(False) + self.pipe.vae_encoder.requires_grad_(False) + self.pipe.text_encoder.eval() + self.pipe.unet.train() + self.pipe.vae_decoder.eval() + self.pipe.vae_encoder.eval() + + # Add LoRA to UNet + lora_config = LoraConfig( + r=lora_rank, + lora_alpha=lora_alpha, + init_lora_weights="gaussian", + target_modules=["to_q", "to_k", "to_v", "to_out"], + ) + self.pipe.unet = inject_adapter_in_model(lora_config, self.pipe.unet) + for param in self.pipe.unet.parameters(): + # Upcast LoRA parameters into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + + # Set other parameters + self.learning_rate = learning_rate + self.use_gradient_checkpointing = use_gradient_checkpointing + self.pipe.scheduler.set_timesteps(1100) + + + def training_step(self, batch, batch_idx): + # Data + text, image = batch["text"], batch["image"] + + # Prepare input parameters + self.pipe.device = self.device + add_prompt_emb, prompt_emb = self.pipe.prompter.encode_prompt( + self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True, + ) 
+ height, width = image.shape[-2:] + latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype) + noise = torch.randn_like(latents) + timestep = torch.randint(0, 1100, (1,), device=self.device)[0] + add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device) + noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep) + + # Compute loss + noise_pred = self.pipe.unet( + noisy_latents, timestep, prompt_emb, add_time_id, add_prompt_emb, + use_gradient_checkpointing=self.use_gradient_checkpointing + ) + loss = torch.nn.functional.mse_loss(noise_pred, noise) + + # Record log + self.log("train_loss", loss, prog_bar=True) + return loss + + + def configure_optimizers(self): + trainable_modules = filter(lambda p: p.requires_grad, self.pipe.unet.parameters()) + optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate) + return optimizer + + + def on_save_checkpoint(self, checkpoint): + checkpoint.clear() + trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.unet.named_parameters())) + trainable_param_names = set([named_param[0] for named_param in trainable_param_names]) + state_dict = self.pipe.unet.state_dict() + for name, param in state_dict.items(): + if name in trainable_param_names: + checkpoint[name] = param + + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_path", + type=str, + default=None, + required=True, + help="Path to pretrained model. 
For example, `models/kolors/Kolors`.", + ) + parser.add_argument( + "--dataset_path", + type=str, + default=None, + required=True, + help="The path of the Dataset.", + ) + parser.add_argument( + "--output_path", + type=str, + default="./", + help="Path to save the model.", + ) + parser.add_argument( + "--steps_per_epoch", + type=int, + default=500, + help="Number of steps per epoch.", + ) + parser.add_argument( + "--height", + type=int, + default=1024, + help="Image height.", + ) + parser.add_argument( + "--width", + type=int, + default=1024, + help="Image width.", + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + default=False, + action="store_true", + help="Whether to randomly flip images horizontally", + ) + parser.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size (per device) for the training dataloader.", + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help="Number of subprocesses to use for data loading. 
0 means that the data will be loaded in the main process.", + ) + parser.add_argument( + "--precision", + type=str, + default="16-mixed", + choices=["32", "16", "16-mixed"], + help="Training precision", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Learning rate.", + ) + parser.add_argument( + "--lora_rank", + type=int, + default=4, + help="The dimension of the LoRA update matrices.", + ) + parser.add_argument( + "--lora_alpha", + type=float, + default=4.0, + help="The weight of the LoRA update matrices.", + ) + parser.add_argument( + "--use_gradient_checkpointing", + default=False, + action="store_true", + help="Whether to use gradient checkpointing.", + ) + parser.add_argument( + "--accumulate_grad_batches", + type=int, + default=1, + help="The number of batches in gradient accumulation.", + ) + parser.add_argument( + "--training_strategy", + type=str, + default="auto", + choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"], + help="Training strategy", + ) + parser.add_argument( + "--max_epochs", + type=int, + default=1, + help="Number of epochs.", + ) + args = parser.parse_args() + return args + + + +if __name__ == '__main__': + # args + args = parse_args() + + # dataset and data loader + dataset = TextImageDataset( + args.dataset_path, + steps_per_epoch=args.steps_per_epoch * args.batch_size, + height=args.height, + width=args.width, + center_crop=args.center_crop, + random_flip=args.random_flip + ) + train_loader = torch.utils.data.DataLoader( + dataset, + shuffle=True, + batch_size=args.batch_size, + num_workers=args.dataloader_num_workers + ) + + # model + model = LightningModel( + pretrained_weights=[ + os.path.join(args.pretrained_path, "text_encoder"), + os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"), + os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"), + ], + torch_dtype=torch.float32 if args.precision == "32" else torch.float16, 
+ learning_rate=args.learning_rate, + lora_rank=args.lora_rank, + lora_alpha=args.lora_alpha, + use_gradient_checkpointing=args.use_gradient_checkpointing + ) + + # train + trainer = pl.Trainer( + max_epochs=args.max_epochs, + accelerator="gpu", + devices="auto", + precision=args.precision, + strategy=args.training_strategy, + default_root_dir=args.output_path, + accumulate_grad_batches=args.accumulate_grad_batches, + callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)] + ) + trainer.fit(model=model, train_dataloaders=train_loader) diff --git a/examples/train/stable_diffusion_3/README.md b/examples/train/stable_diffusion_3/README.md index 517cf7d..1163e59 100644 --- a/examples/train/stable_diffusion_3/README.md +++ b/examples/train/stable_diffusion_3/README.md @@ -153,7 +153,7 @@ image = pipe( image.save("image_with_lora.jpg") ``` -Prompt: +Prompt: a dog is jumping, flowers around the dog, the background is mountains and clouds |Without LoRA|With LoRA| |-|-| diff --git a/requirements.txt b/requirements.txt index a300a2c..4934741 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ imageio[ffmpeg] safetensors einops sentencepiece +modelscope