mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 08:08:13 +00:00
support kolors! (#106)
This commit is contained in:
11
README.md
11
README.md
@@ -8,6 +8,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
|
||||
Until now, DiffSynth Studio has supported the following models:
|
||||
|
||||
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
|
||||
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
|
||||
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
|
||||
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
|
||||
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
|
||||
@@ -85,11 +86,13 @@ Generate high-resolution images, by breaking the limitation of diffusion models!
|
||||
|
||||
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|
||||
|
||||
|Stable Diffusion|Stable Diffusion XL|
|
||||
|Model|Example|
|
||||
|-|-|
|
||||
||||
|
||||
|Stable Diffusion 3|Hunyuan-DiT|
|
||||
|||
|
||||
|Stable Diffusion||
|
||||
|Stable Diffusion XL||
|
||||
|Stable Diffusion 3||
|
||||
|Kolors||
|
||||
|Hunyuan-DiT||
|
||||
|
||||
### Toon Shading
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import torch, os
|
||||
import torch, os, json
|
||||
from safetensors import safe_open
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
@@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
preset_models_on_huggingface = {
|
||||
@@ -159,6 +160,20 @@ preset_models_on_modelscope = {
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
@@ -184,7 +199,8 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"IP-Adapter-SD",
|
||||
"IP-Adapter-SDXL",
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5"
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
]
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
@@ -272,8 +288,7 @@ class ModelManager:
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
||||
return param_name in state_dict or param_name_2 in state_dict
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff(self, state_dict):
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
@@ -343,6 +358,21 @@ class ModelManager:
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_kolors_text_encoder(self, file_path):
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" in file_list:
|
||||
try:
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if config.get("model_type") == "chatglm":
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def is_kolors_unet(self, state_dict):
|
||||
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
@@ -532,13 +562,13 @@ class ModelManager:
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
@@ -592,6 +622,21 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
|
||||
component = "kolors_text_encoder"
|
||||
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
|
||||
model = model.to(dtype=self.torch_dtype, device=self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_kolors_unet(self, state_dict, file_path=""):
|
||||
component = "kolors_unet"
|
||||
model = SDXLUNet(is_kolors=True)
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -607,7 +652,11 @@ class ModelManager:
|
||||
|
||||
# Load every textual inversion file
|
||||
for file_name in os.listdir(folder):
|
||||
if file_name.endswith(".txt"):
|
||||
if os.path.isdir(os.path.join(folder, file_name)) or \
|
||||
not (file_name.endswith(".bin") or \
|
||||
file_name.endswith(".safetensors") or \
|
||||
file_name.endswith(".pth") or \
|
||||
file_name.endswith(".pt")):
|
||||
continue
|
||||
keyword = os.path.splitext(file_name)[0]
|
||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
||||
@@ -620,6 +669,10 @@ class ModelManager:
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
if os.path.isdir(file_path):
|
||||
if self.is_kolors_text_encoder(file_path):
|
||||
self.load_kolors_text_encoder(file_path=file_path)
|
||||
return
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_stable_video_diffusion(state_dict):
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
@@ -663,6 +716,8 @@ class ModelManager:
|
||||
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion_3_t5(state_dict):
|
||||
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
|
||||
elif self.is_kolors_unet(state_dict):
|
||||
self.load_kolors_unet(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -1,251 +1,6 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List, Union
|
||||
import copy, uuid, requests, io, platform, pickle, os, urllib
|
||||
from requests.adapters import Retry
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _get_sep(path):
|
||||
if isinstance(path, bytes):
|
||||
return b'/'
|
||||
else:
|
||||
return '/'
|
||||
|
||||
|
||||
def expanduser(path):
|
||||
"""Expand ~ and ~user constructions. If user or $HOME is unknown,
|
||||
do nothing."""
|
||||
path = os.fspath(path)
|
||||
if isinstance(path, bytes):
|
||||
tilde = b'~'
|
||||
else:
|
||||
tilde = '~'
|
||||
if not path.startswith(tilde):
|
||||
return path
|
||||
sep = _get_sep(path)
|
||||
i = path.find(sep, 1)
|
||||
if i < 0:
|
||||
i = len(path)
|
||||
if i == 1:
|
||||
if 'HOME' not in os.environ:
|
||||
import pwd
|
||||
try:
|
||||
userhome = pwd.getpwuid(os.getuid()).pw_dir
|
||||
except KeyError:
|
||||
# bpo-10496: if the current user identifier doesn't exist in the
|
||||
# password database, return the path unchanged
|
||||
return path
|
||||
else:
|
||||
userhome = os.environ['HOME']
|
||||
else:
|
||||
import pwd
|
||||
name = path[1:i]
|
||||
if isinstance(name, bytes):
|
||||
name = str(name, 'ASCII')
|
||||
try:
|
||||
pwent = pwd.getpwnam(name)
|
||||
except KeyError:
|
||||
# bpo-10496: if the user name from the path doesn't exist in the
|
||||
# password database, return the path unchanged
|
||||
return path
|
||||
userhome = pwent.pw_dir
|
||||
if isinstance(path, bytes):
|
||||
userhome = os.fsencode(userhome)
|
||||
root = b'/'
|
||||
else:
|
||||
root = '/'
|
||||
userhome = userhome.rstrip(root)
|
||||
return (userhome + path[i:]) or root
|
||||
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
COOKIES_FILE_NAME = 'cookies'
|
||||
GIT_TOKEN_FILE_NAME = 'git_token'
|
||||
USER_INFO_FILE_NAME = 'user'
|
||||
USER_SESSION_ID_FILE_NAME = 'session'
|
||||
|
||||
@staticmethod
|
||||
def make_sure_credential_path_exist():
|
||||
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def get_user_session_id():
|
||||
session_path = os.path.join(ModelScopeConfig.path_credential,
|
||||
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
|
||||
session_id = ''
|
||||
if os.path.exists(session_path):
|
||||
with open(session_path, 'rb') as f:
|
||||
session_id = str(f.readline().strip(), encoding='utf-8')
|
||||
return session_id
|
||||
if session_id == '' or len(session_id) != 32:
|
||||
session_id = str(uuid.uuid4().hex)
|
||||
ModelScopeConfig.make_sure_credential_path_exist()
|
||||
with open(session_path, 'w+') as wf:
|
||||
wf.write(session_id)
|
||||
|
||||
return session_id
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
|
||||
"""Formats a user-agent string with basic info about a request.
|
||||
|
||||
Args:
|
||||
user_agent (`str`, `dict`, *optional*):
|
||||
The user agent info in the form of a dictionary or a single string.
|
||||
|
||||
Returns:
|
||||
The formatted user-agent string.
|
||||
"""
|
||||
|
||||
# include some more telemetrics when executing in dedicated
|
||||
# cloud containers
|
||||
MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
|
||||
MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
|
||||
env = 'custom'
|
||||
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
|
||||
env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
|
||||
user_name = 'unknown'
|
||||
if MODELSCOPE_CLOUD_USERNAME in os.environ:
|
||||
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
|
||||
|
||||
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
|
||||
"1.15.0",
|
||||
platform.python_version(),
|
||||
ModelScopeConfig.get_user_session_id(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
env,
|
||||
user_name,
|
||||
)
|
||||
if isinstance(user_agent, dict):
|
||||
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += '; ' + user_agent
|
||||
return ua
|
||||
|
||||
@staticmethod
|
||||
def get_cookies():
|
||||
cookies_path = os.path.join(ModelScopeConfig.path_credential,
|
||||
ModelScopeConfig.COOKIES_FILE_NAME)
|
||||
if os.path.exists(cookies_path):
|
||||
with open(cookies_path, 'rb') as f:
|
||||
cookies = pickle.load(f)
|
||||
return cookies
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def modelscope_http_get_model_file(
|
||||
url: str,
|
||||
local_dir: str,
|
||||
file_name: str,
|
||||
file_size: int,
|
||||
cookies: CookieJar,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Download remote file, will retry 5 times before giving up on errors.
|
||||
|
||||
Args:
|
||||
url(str):
|
||||
actual download url of the file
|
||||
local_dir(str):
|
||||
local directory where the downloaded file stores
|
||||
file_name(str):
|
||||
name of the file stored in `local_dir`
|
||||
file_size(int):
|
||||
The file size.
|
||||
cookies(CookieJar):
|
||||
cookies used to authentication the user, which is used for downloading private repos
|
||||
headers(Dict[str, str], optional):
|
||||
http headers to carry necessary info when requesting the remote file
|
||||
|
||||
Raises:
|
||||
FileDownloadError: File download failed.
|
||||
|
||||
"""
|
||||
get_headers = {} if headers is None else copy.deepcopy(headers)
|
||||
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
|
||||
temp_file_path = os.path.join(local_dir, file_name)
|
||||
# retry sleep 0.5s, 1s, 2s, 4s
|
||||
retry = Retry(
|
||||
total=5,
|
||||
backoff_factor=1,
|
||||
allowed_methods=['GET'])
|
||||
while True:
|
||||
try:
|
||||
progress = tqdm(
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=file_size,
|
||||
initial=0,
|
||||
desc='Downloading',
|
||||
)
|
||||
partial_length = 0
|
||||
if os.path.exists(
|
||||
temp_file_path): # download partial, continue download
|
||||
with open(temp_file_path, 'rb') as f:
|
||||
partial_length = f.seek(0, io.SEEK_END)
|
||||
progress.update(partial_length)
|
||||
if partial_length > file_size:
|
||||
break
|
||||
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
|
||||
file_size - 1)
|
||||
with open(temp_file_path, 'ab') as f:
|
||||
r = requests.get(
|
||||
url,
|
||||
stream=True,
|
||||
headers=get_headers,
|
||||
cookies=cookies,
|
||||
timeout=60)
|
||||
r.raise_for_status()
|
||||
for chunk in r.iter_content(
|
||||
chunk_size=1024 * 1024 * 1):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
f.write(chunk)
|
||||
progress.close()
|
||||
break
|
||||
except (Exception) as e: # no matter what happen, we will retry.
|
||||
retry = retry.increment('GET', url, error=e)
|
||||
retry.sleep()
|
||||
|
||||
|
||||
def get_endpoint():
|
||||
MODELSCOPE_URL_SCHEME = 'https://'
|
||||
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
|
||||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
|
||||
DEFAULT_MODELSCOPE_DOMAIN)
|
||||
return MODELSCOPE_URL_SCHEME + modelscope_domain
|
||||
|
||||
|
||||
def get_file_download_url(model_id: str, file_path: str, revision: str):
|
||||
"""Format file download url according to `model_id`, `revision` and `file_path`.
|
||||
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
|
||||
the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
|
||||
|
||||
Args:
|
||||
model_id (str): The model_id.
|
||||
file_path (str): File path
|
||||
revision (str): File revision.
|
||||
|
||||
Returns:
|
||||
str: The file url.
|
||||
"""
|
||||
file_path = urllib.parse.quote_plus(file_path)
|
||||
revision = urllib.parse.quote_plus(revision)
|
||||
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
|
||||
return download_url_template.format(
|
||||
endpoint=get_endpoint(),
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
file_path=file_path,
|
||||
)
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
@@ -255,17 +10,12 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
return
|
||||
else:
|
||||
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
|
||||
modelscope_http_get_model_file(
|
||||
url,
|
||||
local_dir,
|
||||
os.path.basename(origin_file_path),
|
||||
file_size=0,
|
||||
headers=headers,
|
||||
cookies=cookies
|
||||
)
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
|
||||
1363
diffsynth/models/kolors_text_encoder.py
Normal file
1363
diffsynth/models/kolors_text_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
||||
|
||||
|
||||
class SDXLUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, is_kolors=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
@@ -85,7 +86,9 @@ class SDXLUNet(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=8, **kwargs
|
||||
tiled=False, tile_size=64, tile_stride=8,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
@@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module):
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
|
||||
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -148,6 +162,8 @@ class SDXLUNetStateDictConverter:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif names[0] in ["encoder_hid_proj"]:
|
||||
names[0] = "text_intermediate_proj"
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
|
||||
@@ -5,3 +5,4 @@ from .stable_diffusion_xl_video import SDXLVideoPipeline
|
||||
from .stable_video_diffusion import SVDVideoPipeline
|
||||
from .hunyuan_dit import HunyuanDiTImagePipeline
|
||||
from .stable_diffusion_3 import SD3ImagePipeline
|
||||
from .kwai_kolors import KolorsImagePipeline
|
||||
|
||||
@@ -147,7 +147,7 @@ def lets_dance_xl(
|
||||
# 3. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = unet.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 4. blocks
|
||||
|
||||
@@ -216,7 +216,7 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
# Prepare latent tensors
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
|
||||
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
@@ -293,6 +293,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
|
||||
168
diffsynth/pipelines/kwai_kolors.py
Normal file
168
diffsynth/pipelines/kwai_kolors.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from ..models import ModelManager, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
from ..prompts import KolorsPrompter
|
||||
from ..schedulers import EnhancedDDIMScheduler
|
||||
from .dancer import lets_dance_xl
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
|
||||
|
||||
class KolorsImagePipeline(torch.nn.Module):
|
||||
|
||||
def __init__(self, device="cuda", torch_dtype=torch.float16):
|
||||
super().__init__()
|
||||
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
|
||||
self.prompter = KolorsPrompter()
|
||||
self.device = device
|
||||
self.torch_dtype = torch_dtype
|
||||
# models
|
||||
self.text_encoder: ChatGLMModel = None
|
||||
self.unet: SDXLUNet = None
|
||||
self.vae_decoder: SDXLVAEDecoder = None
|
||||
self.vae_encoder: SDXLVAEEncoder = None
|
||||
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
|
||||
self.ipadapter: SDXLIpAdapter = None
|
||||
|
||||
|
||||
def fetch_main_models(self, model_manager: ModelManager):
|
||||
self.text_encoder = model_manager.kolors_text_encoder
|
||||
self.unet = model_manager.kolors_unet
|
||||
self.vae_decoder = model_manager.vae_decoder
|
||||
self.vae_encoder = model_manager.vae_encoder
|
||||
|
||||
|
||||
def fetch_ipadapter(self, model_manager: ModelManager):
|
||||
if "ipadapter_xl" in model_manager.model:
|
||||
self.ipadapter = model_manager.ipadapter_xl
|
||||
if "ipadapter_xl_image_encoder" in model_manager.model:
|
||||
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
|
||||
|
||||
|
||||
def fetch_prompter(self, model_manager: ModelManager):
|
||||
self.prompter.load_from_model_manager(model_manager)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_model_manager(model_manager: ModelManager):
|
||||
pipe = KolorsImagePipeline(
|
||||
device=model_manager.device,
|
||||
torch_dtype=model_manager.torch_dtype,
|
||||
)
|
||||
pipe.fetch_main_models(model_manager)
|
||||
pipe.fetch_prompter(model_manager)
|
||||
pipe.fetch_ipadapter(model_manager)
|
||||
return pipe
|
||||
|
||||
|
||||
def preprocess_image(self, image):
|
||||
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
|
||||
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
|
||||
image = image.cpu().permute(1, 2, 0).numpy()
|
||||
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
|
||||
return image
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
negative_prompt="",
|
||||
cfg_scale=7.5,
|
||||
clip_skip=2,
|
||||
input_image=None,
|
||||
ipadapter_images=None,
|
||||
ipadapter_scale=1.0,
|
||||
ipadapter_use_instant_style=False,
|
||||
denoising_strength=1.0,
|
||||
height=1024,
|
||||
width=1024,
|
||||
num_inference_steps=20,
|
||||
tiled=False,
|
||||
tile_size=64,
|
||||
tile_stride=32,
|
||||
progress_bar_cmd=tqdm,
|
||||
progress_bar_st=None,
|
||||
):
|
||||
# Prepare scheduler
|
||||
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
|
||||
|
||||
# Prepare latent tensors
|
||||
if input_image is not None:
|
||||
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
|
||||
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
|
||||
else:
|
||||
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
|
||||
|
||||
# Encode prompts
|
||||
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
prompt,
|
||||
clip_skip=clip_skip,
|
||||
device=self.device,
|
||||
positive=True,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
|
||||
self.text_encoder,
|
||||
negative_prompt,
|
||||
clip_skip=clip_skip,
|
||||
device=self.device,
|
||||
positive=False,
|
||||
)
|
||||
|
||||
# Prepare positional id
|
||||
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
|
||||
|
||||
# IP-Adapter
|
||||
if ipadapter_images is not None:
|
||||
if ipadapter_use_instant_style:
|
||||
self.ipadapter.set_less_adapter()
|
||||
else:
|
||||
self.ipadapter.set_full_adapter()
|
||||
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
|
||||
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
|
||||
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
|
||||
else:
|
||||
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
|
||||
|
||||
# Denoise
|
||||
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
|
||||
timestep = torch.IntTensor((timestep,))[0].to(self.device)
|
||||
|
||||
# Classifier-free guidance
|
||||
noise_pred_posi = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
|
||||
)
|
||||
if cfg_scale != 1.0:
|
||||
noise_pred_nega = lets_dance_xl(
|
||||
self.unet,
|
||||
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
|
||||
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
|
||||
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
|
||||
)
|
||||
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
|
||||
else:
|
||||
noise_pred = noise_pred_posi
|
||||
|
||||
latents = self.scheduler.step(noise_pred, timestep, latents)
|
||||
|
||||
if progress_bar_st is not None:
|
||||
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
|
||||
|
||||
# Decode image
|
||||
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
|
||||
|
||||
return image
|
||||
@@ -2,3 +2,4 @@ from .sd_prompter import SDPrompter
|
||||
from .sdxl_prompter import SDXLPrompter
|
||||
from .sd3_prompter import SD3Prompter
|
||||
from .hunyuan_dit_prompter import HunyuanDiTPrompter
|
||||
from .kolors_prompter import KolorsPrompter
|
||||
|
||||
347
diffsynth/prompts/kolors_prompter.py
Normal file
347
diffsynth/prompts/kolors_prompter.py
Normal file
@@ -0,0 +1,347 @@
|
||||
from .utils import Prompter
|
||||
import json, os, re
|
||||
from typing import List, Optional, Union, Dict
|
||||
from sentencepiece import SentencePieceProcessor
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers.utils import PaddingStrategy
|
||||
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
|
||||
from ..models.kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
class SPTokenizer:
    """SentencePiece-backed tokenizer for ChatGLM with appended special tokens.

    Wraps a ``SentencePieceProcessor`` and registers control/role tokens
    ([MASK], [gMASK], [sMASK], sop, eop, <|system|>, <|user|>, <|assistant|>,
    <|observation|>) at ids directly after the base vocabulary.
    """

    def __init__(self, model_path: str):
        # The SentencePiece model file must exist on disk.
        assert os.path.isfile(model_path), model_path
        self.sp_model = SentencePieceProcessor(model_file=model_path)

        # BOS / EOS / PAD ids come straight from the SentencePiece model;
        # PAD deliberately reuses the <unk> id.
        self.n_words: int = self.sp_model.vocab_size()
        self.bos_id: int = self.sp_model.bos_id()
        self.eos_id: int = self.sp_model.eos_id()
        self.pad_id: int = self.sp_model.unk_id()
        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

        role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
        special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
        self.special_tokens = {}
        self.index_special_tokens = {}
        # Append special tokens immediately after the base vocabulary.
        for token in special_tokens:
            self.special_tokens[token] = self.n_words
            self.index_special_tokens[self.n_words] = token
            self.n_words += 1
        # Regex alternation that matches any role token verbatim.
        self.role_special_token_expression = "|".join(re.escape(token) for token in role_special_tokens)

    def tokenize(self, s: str, encode_special_tokens=False):
        """Split *s* into SentencePiece pieces.

        When *encode_special_tokens* is true, role tokens embedded in the text
        are kept as single pieces instead of being split by SentencePiece.
        """
        if not encode_special_tokens:
            return self.sp_model.EncodeAsPieces(s)
        pieces = []
        cursor = 0
        for match in re.finditer(self.role_special_token_expression, s):
            # Encode the plain-text run before the matched role token.
            if cursor < match.start():
                pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:match.start()]))
            pieces.append(s[match.start():match.end()])
            cursor = match.end()
        # Encode any trailing plain text after the last role token.
        if cursor < len(s):
            pieces.extend(self.sp_model.EncodeAsPieces(s[cursor:]))
        return pieces

    def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
        """Encode *s* to token ids, optionally framed with BOS/EOS."""
        assert type(s) is str
        ids = self.sp_model.encode(s)
        if bos:
            ids = [self.bos_id] + ids
        if eos:
            ids = ids + [self.eos_id]
        return ids

    def decode(self, t: List[int]) -> str:
        """Decode ids to text, rendering special-token ids as their literal form."""
        text = ""
        pending = []
        for token in t:
            if token in self.index_special_tokens:
                # Flush ordinary ids accumulated so far, then emit the
                # special token verbatim (SentencePiece cannot decode it).
                if pending:
                    text += self.sp_model.decode(pending)
                    pending = []
                text += self.index_special_tokens[token]
            else:
                pending.append(token)
        if pending:
            text += self.sp_model.decode(pending)
        return text

    def decode_tokens(self, tokens: List[str]) -> str:
        """Join a list of piece strings back into text."""
        return self.sp_model.DecodePieces(tokens)

    def convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        return self.sp_model.PieceToId(token)

    def convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index in self.index_special_tokens:
            return self.index_special_tokens[index]
        # BOS/EOS/PAD and out-of-range ids render as the empty string.
        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
            return ""
        return self.sp_model.IdToPiece(index)
|
||||
|
||||
|
||||
|
||||
class ChatGLMTokenizer(PreTrainedTokenizer):
    """Hugging Face tokenizer wrapper around :class:`SPTokenizer` for ChatGLM.

    Pads on the left and produces ``input_ids``, ``attention_mask`` and
    ``position_ids`` for the ChatGLM text encoder used by Kolors.
    """

    # Filename the SentencePiece vocabulary is stored under in a model dir.
    vocab_files_names = {"vocab_file": "tokenizer.model"}

    # Tensors this tokenizer produces for the model forward pass.
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
                 **kwargs):
        self.name = "GLMTokenizer"

        self.vocab_file = vocab_file
        self.tokenizer = SPTokenizer(vocab_file)
        # Named control tokens mapped to their SentencePiece ids.
        self.special_tokens = {
            "<bos>": self.tokenizer.bos_id,
            "<eos>": self.tokenizer.eos_id,
            "<pad>": self.tokenizer.pad_id
        }
        self.encode_special_tokens = encode_special_tokens
        # NOTE(review): encode_special_tokens is also forwarded to the base
        # class here — confirm the installed transformers version accepts it.
        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         encode_special_tokens=encode_special_tokens,
                         **kwargs)

    def get_command(self, token):
        """Return the id of a named control token (e.g. ``"<eos>"``, ``"[gMASK]"``)."""
        if token in self.special_tokens:
            return self.special_tokens[token]
        assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
        return self.tokenizer.special_tokens[token]

    @property
    def unk_token(self) -> str:
        return "<unk>"

    @property
    def pad_token(self) -> str:
        # PAD reuses the <unk> surface form; its id is the SentencePiece unk id.
        return "<unk>"

    @property
    def pad_token_id(self):
        return self.get_command("<pad>")

    @property
    def eos_token(self) -> str:
        return "</s>"

    @property
    def eos_token_id(self):
        return self.get_command("<eos>")

    @property
    def vocab_size(self):
        return self.tokenizer.n_words

    def get_vocab(self):
        """ Returns vocab as a dict """
        vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text, **kwargs):
        return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.tokenizer.convert_token_to_id(token)

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.tokenizer.convert_id_to_token(index)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        return self.tokenizer.decode_tokens(tokens)

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                An optional prefix to add to the name of the saved files.

        Returns:
            `Tuple(str)`: Paths to the files saved.
        """
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, self.vocab_files_names["vocab_file"]
            )
        else:
            vocab_file = save_directory

        # Copy the SentencePiece model byte-for-byte to the destination.
        with open(self.vocab_file, 'rb') as fin:
            proto_str = fin.read()

        with open(vocab_file, "wb") as writer:
            writer.write(proto_str)

        return (vocab_file,)

    def get_prefix_tokens(self):
        """Tokens prepended to every encoded sequence: [gMASK] followed by sop."""
        prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
        return prefix_tokens

    def build_single_message(self, role, metadata, message):
        """Encode one chat turn as ``<|role|> metadata\\n`` followed by the message ids."""
        assert role in ["system", "user", "assistant", "observation"], role
        role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
        message_tokens = self.tokenizer.encode(message)
        tokens = role_tokens + message_tokens
        return tokens

    def build_chat_input(self, query, history=None, role="user"):
        """Encode a chat history plus the current query into model inputs.

        System turns carrying a ``tools`` entry have the tool spec appended to
        their content as pretty-printed JSON.
        """
        if history is None:
            history = []
        input_ids = []
        for item in history:
            content = item["content"]
            if item["role"] == "system" and "tools" in item:
                content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
            input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
        input_ids.extend(self.build_single_message(role, "", query))
        # End with the assistant marker so generation continues as the assistant.
        input_ids.extend([self.get_command("<|assistant|>")])
        return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)

    def build_inputs_with_special_tokens(
            self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.

        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        prefix_tokens = self.get_prefix_tokens()
        token_ids_0 = prefix_tokens + token_ids_0
        if token_ids_1 is not None:
            token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
        return token_ids_0

    def _pad(
            self,
            encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
            max_length: Optional[int] = None,
            padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
            pad_to_multiple_of: Optional[int] = None,
            return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad encoded inputs (on left/right and up to predefined length or max length in the batch)

        Args:
            encoded_inputs:
                Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
            max_length: maximum length of the returned list and optionally padding length (see below).
                Will truncate by taking into account the special tokens.
            padding_strategy: PaddingStrategy to use for padding.

                - PaddingStrategy.LONGEST Pad to the longest sequence in the batch
                - PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
                - PaddingStrategy.DO_NOT_PAD: Do not pad
                The tokenizer padding sides are defined in self.padding_side:

                - 'left': pads on the left of the sequences
                - 'right': pads on the right of the sequences
            pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                `>= 7.5` (Volta).
            return_attention_mask:
                (optional) Set to False to avoid returning attention mask (default: set to model specifics)
        """
        # Load from model defaults
        # This implementation only supports left padding.
        assert self.padding_side == "left"

        required_input = encoded_inputs[self.model_input_names[0]]
        seq_length = len(required_input)

        if padding_strategy == PaddingStrategy.LONGEST:
            max_length = len(required_input)

        # Round max_length up to the next multiple if requested.
        if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
            max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of

        needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

        # Initialize attention mask if not present.
        if "attention_mask" not in encoded_inputs:
            encoded_inputs["attention_mask"] = [1] * seq_length

        if "position_ids" not in encoded_inputs:
            encoded_inputs["position_ids"] = list(range(seq_length))

        if needs_to_be_padded:
            difference = max_length - len(required_input)

            # Left-pad every produced tensor: mask with 0s, positions with 0s,
            # input ids with the pad id.
            if "attention_mask" in encoded_inputs:
                encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
            if "position_ids" in encoded_inputs:
                encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
            encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input

        return encoded_inputs
|
||||
|
||||
|
||||
|
||||
class KolorsPrompter(Prompter):
    """Prompter for Kolors: tokenizes with ChatGLMTokenizer and encodes prompts
    with the ChatGLM text encoder, returning pooled and per-token embeddings.
    """

    def __init__(
        self,
        tokenizer_path=None
    ):
        # Default to the tokenizer config bundled with the diffsynth package.
        if tokenizer_path is None:
            base_path = os.path.dirname(os.path.dirname(__file__))
            tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
        super().__init__()
        self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)


    def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
        """Tokenize *prompt* and run it through the ChatGLM encoder.

        Returns ``(prompt_emb, pooled_prompt_emb)`` where the per-token
        embeddings are taken ``clip_skip`` hidden states from the end.
        """
        batch = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        encoder_output = text_encoder(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            position_ids=batch['position_ids'],
            output_hidden_states=True
        )
        hidden_states = encoder_output.hidden_states
        # Hidden states are (seq, batch, dim); permute to batch-first.
        prompt_emb = hidden_states[-clip_skip].permute(1, 0, 2).clone()
        # Pooled embedding: last position of the final hidden state.
        pooled_prompt_emb = hidden_states[-1][-1, :, :].clone()
        return prompt_emb, pooled_prompt_emb


    def encode_prompt(
        self,
        text_encoder: ChatGLMModel,
        prompt,
        clip_skip=2,
        positive=True,
        device="cuda"
    ):
        """Process *prompt* and encode it; returns ``(pooled_emb, token_embs)``."""
        processed = self.process_prompt(prompt, positive=positive)
        prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(
            processed, text_encoder, self.tokenizer, 256, clip_skip, device
        )
        return pooled_prompt_emb, prompt_emb
|
||||
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
Normal file
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/tokenizer.model
Normal file
Binary file not shown.
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"name_or_path": "THUDM/chatglm3-6b-base",
|
||||
"remove_space": false,
|
||||
"do_lower_case": false,
|
||||
"tokenizer_class": "ChatGLMTokenizer",
|
||||
"auto_map": {
|
||||
"AutoTokenizer": [
|
||||
"tokenization_chatglm.ChatGLMTokenizer",
|
||||
null
|
||||
]
|
||||
}
|
||||
}
|
||||
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
Normal file
BIN
diffsynth/tokenizer_configs/kolors/tokenizer/vocab.txt
Normal file
Binary file not shown.
@@ -28,6 +28,16 @@ LoRA Training: [`../train/stable_diffusion_3/`](../train/stable_diffusion_3/)
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example: Kolors
|
||||
|
||||
Example script: [`kolors_text_to_image.py`](./kolors_text_to_image.py)
|
||||
|
||||
LoRA Training: [`../train/kolors/`](../train/kolors/)
|
||||
|
||||
|1024*1024|2048*2048|
|
||||
|-|-|
|
||||
|||
|
||||
|
||||
### Example: Hunyuan-DiT
|
||||
|
||||
Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py)
|
||||
|
||||
34
examples/image_synthesis/kolors_text_to_image.py
Normal file
34
examples/image_synthesis/kolors_text_to_image.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from diffsynth import ModelManager, KolorsImagePipeline, download_models
import torch

# Download models
# https://huggingface.co/Kwai-Kolors/Kolors
download_models(["Kolors"])
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
                             file_path_list=[
                                 "models/kolors/Kolors/text_encoder",
                                 "models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
                                 "models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
                             ])
pipe = KolorsImagePipeline.from_model_manager(model_manager)

prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红"

# Fixed seed so the example output is reproducible.
torch.manual_seed(7)

# Stage 1: text-to-image at the model's base 1024x1024 resolution.
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=50,
    cfg_scale=4,
)
# Fix: the original used an f-string with no placeholders.
image.save("image_1024.jpg")

# Stage 2: upscale via image-to-image at 2048x2048 with partial denoising,
# refining the upsampled stage-1 result instead of generating from scratch.
image = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    input_image=image.resize((2048, 2048)), denoising_strength=0.4, height=2048, width=2048,
    num_inference_steps=50,
    cfg_scale=4,
)
image.save("image_2048.jpg")
|
||||
175
examples/train/kolors/README.md
Normal file
175
examples/train/kolors/README.md
Normal file
@@ -0,0 +1,175 @@
|
||||
# Kolors
|
||||
|
||||
Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffusion XL. We provide training scripts here.
|
||||
|
||||
## Download models
|
||||
|
||||
The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors).
|
||||
|
||||
```
|
||||
models/kolors/Kolors
|
||||
├── text_encoder
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model-00001-of-00007.bin
|
||||
│ ├── pytorch_model-00002-of-00007.bin
|
||||
│ ├── pytorch_model-00003-of-00007.bin
|
||||
│ ├── pytorch_model-00004-of-00007.bin
|
||||
│ ├── pytorch_model-00005-of-00007.bin
|
||||
│ ├── pytorch_model-00006-of-00007.bin
|
||||
│ ├── pytorch_model-00007-of-00007.bin
|
||||
│ └── pytorch_model.bin.index.json
|
||||
├── unet
|
||||
│ └── diffusion_pytorch_model.safetensors
|
||||
└── vae
|
||||
└── diffusion_pytorch_model.safetensors
|
||||
```
|
||||
|
||||
You can use the following code to download these files:
|
||||
|
||||
```python
|
||||
from diffsynth import download_models
|
||||
|
||||
download_models(["Kolors"])
|
||||
```
|
||||
|
||||
## Train
|
||||
|
||||
### Install training dependency
|
||||
|
||||
```
|
||||
pip install peft lightning pandas torchvision
|
||||
```
|
||||
|
||||
### Prepare your dataset
|
||||
|
||||
We provide an example dataset [here](https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/files). You need to manage the training images as follows:
|
||||
|
||||
```
|
||||
data/dog/
|
||||
└── train
|
||||
├── 00.jpg
|
||||
├── 01.jpg
|
||||
├── 02.jpg
|
||||
├── 03.jpg
|
||||
├── 04.jpg
|
||||
└── metadata.csv
|
||||
```
|
||||
|
||||
`metadata.csv`:
|
||||
|
||||
```
|
||||
file_name,text
|
||||
00.jpg,一只小狗
|
||||
01.jpg,一只小狗
|
||||
02.jpg,一只小狗
|
||||
03.jpg,一只小狗
|
||||
04.jpg,一只小狗
|
||||
```
|
||||
|
||||
### Train a LoRA model
|
||||
|
||||
We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project.
|
||||
|
||||
The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.**
|
||||
|
||||
```
|
||||
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
|
||||
--pretrained_path models/kolors/Kolors \
|
||||
--dataset_path data/dog \
|
||||
--output_path ./models \
|
||||
--max_epochs 10 \
|
||||
--center_crop \
|
||||
--use_gradient_checkpointing \
|
||||
--precision 32
|
||||
```
|
||||
|
||||
Optional arguments:
|
||||
```
|
||||
-h, --help show this help message and exit
|
||||
--pretrained_path PRETRAINED_PATH
|
||||
Path to pretrained model. For example, `models/kolors/Kolors`.
|
||||
--dataset_path DATASET_PATH
|
||||
The path of the Dataset.
|
||||
--output_path OUTPUT_PATH
|
||||
Path to save the model.
|
||||
--steps_per_epoch STEPS_PER_EPOCH
|
||||
Number of steps per epoch.
|
||||
--height HEIGHT Image height.
|
||||
--width WIDTH Image width.
|
||||
--center_crop Whether to center crop the input images to the resolution. If not set, the images will be randomly cropped. The images will be resized to the resolution first before cropping.
|
||||
--random_flip Whether to randomly flip images horizontally
|
||||
--batch_size BATCH_SIZE
|
||||
Batch size (per device) for the training dataloader.
|
||||
--dataloader_num_workers DATALOADER_NUM_WORKERS
|
||||
Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
|
||||
--precision {32,16,16-mixed}
|
||||
Training precision
|
||||
--learning_rate LEARNING_RATE
|
||||
Learning rate.
|
||||
--lora_rank LORA_RANK
|
||||
The dimension of the LoRA update matrices.
|
||||
--lora_alpha LORA_ALPHA
|
||||
The weight of the LoRA update matrices.
|
||||
--use_gradient_checkpointing
|
||||
Whether to use gradient checkpointing.
|
||||
--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
|
||||
The number of batches in gradient accumulation.
|
||||
--training_strategy {auto,deepspeed_stage_1,deepspeed_stage_2,deepspeed_stage_3}
|
||||
Training strategy
|
||||
--max_epochs MAX_EPOCHS
|
||||
Number of epochs.
|
||||
```
|
||||
|
||||
### Inference with your own LoRA model
|
||||
|
||||
After training, you can use your own LoRA model to generate new images. Here are some examples.
|
||||
|
||||
```python
|
||||
from diffsynth import ModelManager, KolorsImagePipeline
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
import torch
|
||||
|
||||
|
||||
def load_lora(model, lora_rank, lora_alpha, lora_path):
|
||||
lora_config = LoraConfig(
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
init_lora_weights="gaussian",
|
||||
target_modules=["to_q", "to_k", "to_v", "to_out"],
|
||||
)
|
||||
model = inject_adapter_in_model(lora_config, model)
|
||||
state_dict = torch.load(lora_path, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
# Load models
|
||||
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
|
||||
file_path_list=[
|
||||
"models/kolors/Kolors/text_encoder",
|
||||
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
|
||||
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
|
||||
])
|
||||
pipe = KolorsImagePipeline.from_model_manager(model_manager)
|
||||
|
||||
# Generate an image with lora
|
||||
pipe.unet = load_lora(
|
||||
pipe.unet,
|
||||
lora_rank=4, lora_alpha=4.0, # The two parameters should be consistent with those in your training script.
|
||||
lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
|
||||
negative_prompt="",
|
||||
cfg_scale=4,
|
||||
num_inference_steps=50, height=1024, width=1024,
|
||||
)
|
||||
image.save("image_with_lora.jpg")
|
||||
```
|
||||
|
||||
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|||
|
||||
293
examples/train/kolors/train_kolors_lora.py
Normal file
293
examples/train/kolors/train_kolors_lora.py
Normal file
@@ -0,0 +1,293 @@
|
||||
from diffsynth import ModelManager, KolorsImagePipeline
|
||||
from peft import LoraConfig, inject_adapter_in_model
|
||||
from torchvision import transforms
|
||||
from PIL import Image
|
||||
import lightning as pl
|
||||
import pandas as pd
|
||||
import torch, os, argparse
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "True"
|
||||
|
||||
|
||||
|
||||
class TextImageDataset(torch.utils.data.Dataset):
    """Text/image pairs read from ``<dataset_path>/train/metadata.csv``.

    ``__len__`` reports ``steps_per_epoch`` rather than the file count, so one
    epoch is a fixed number of (randomly sampled) steps over the data.
    """

    def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
        self.steps_per_epoch = steps_per_epoch
        metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
        self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
        self.text = metadata["text"].to_list()
        # Optional crop/flip ops chosen once at construction time.
        crop_op = transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width))
        flip_op = transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x)
        self.image_processor = transforms.Compose(
            [
                transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
                crop_op,
                flip_op,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )


    def __getitem__(self, index):
        # Random offset plus index keeps sampling reproducible under a fixed seed
        # while still covering the whole dataset.
        offset = torch.randint(0, len(self.path), (1,))[0]
        sample_id = (offset + index) % len(self.path)  # For fixed seed.
        caption = self.text[sample_id]
        picture = Image.open(self.path[sample_id]).convert("RGB")
        picture = self.image_processor(picture)
        return {"text": caption, "image": picture}


    def __len__(self):
        return self.steps_per_epoch
|
||||
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
    """Lightning module that fine-tunes LoRA adapters on the Kolors UNet.

    The text encoder, VAE and base UNet weights are frozen; only the injected
    LoRA parameters train, and checkpoints contain just those LoRA weights.
    """

    def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=None, lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
        super().__init__()
        # Fix: avoid a mutable default argument (was `pretrained_weights=[]`).
        if pretrained_weights is None:
            pretrained_weights = []

        # Load models
        model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
        model_manager.load_models(pretrained_weights)
        self.pipe = KolorsImagePipeline.from_model_manager(model_manager)

        # Freeze all base parameters; only LoRA weights (injected below) train.
        # The UNet stays in train() mode so e.g. dropout behaves as in training.
        self.pipe.text_encoder.requires_grad_(False)
        self.pipe.unet.requires_grad_(False)
        self.pipe.vae_decoder.requires_grad_(False)
        self.pipe.vae_encoder.requires_grad_(False)
        self.pipe.text_encoder.eval()
        self.pipe.unet.train()
        self.pipe.vae_decoder.eval()
        self.pipe.vae_encoder.eval()

        # Add LoRA to UNet attention projections.
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            init_lora_weights="gaussian",
            target_modules=["to_q", "to_k", "to_v", "to_out"],
        )
        self.pipe.unet = inject_adapter_in_model(lora_config, self.pipe.unet)
        for param in self.pipe.unet.parameters():
            # Upcast LoRA parameters into fp32 for numerically stable updates.
            if param.requires_grad:
                param.data = param.to(torch.float32)

        # Set other parameters
        self.learning_rate = learning_rate
        self.use_gradient_checkpointing = use_gradient_checkpointing
        # 1100 scheduler timesteps; training_step samples uniformly from them.
        self.pipe.scheduler.set_timesteps(1100)


    def training_step(self, batch, batch_idx):
        """One denoising-MSE training step on a (text, image) batch."""
        # Data
        text, image = batch["text"], batch["image"]

        # Prepare input parameters
        self.pipe.device = self.device
        add_prompt_emb, prompt_emb = self.pipe.prompter.encode_prompt(
            self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True,
        )
        height, width = image.shape[-2:]
        # VAE encoding runs in fp32, then latents are cast back to the pipeline dtype.
        latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype)
        noise = torch.randn_like(latents)
        timestep = torch.randint(0, 1100, (1,), device=self.device)[0]
        # SDXL-style additional conditioning: (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
        add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
        noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)

        # Compute loss: predict the added noise (epsilon objective).
        noise_pred = self.pipe.unet(
            noisy_latents, timestep, prompt_emb, add_time_id, add_prompt_emb,
            use_gradient_checkpointing=self.use_gradient_checkpointing
        )
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        # Record log
        self.log("train_loss", loss, prog_bar=True)
        return loss


    def configure_optimizers(self):
        """AdamW over the trainable (LoRA) UNet parameters only."""
        trainable_modules = filter(lambda p: p.requires_grad, self.pipe.unet.parameters())
        optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
        return optimizer


    def on_save_checkpoint(self, checkpoint):
        """Strip the checkpoint down to just the trainable LoRA weights."""
        checkpoint.clear()
        trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.unet.named_parameters()))
        trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
        state_dict = self.pipe.unet.state_dict()
        for name, param in state_dict.items():
            if name in trainable_param_names:
                checkpoint[name] = param
|
||||
|
||||
|
||||
|
||||
def parse_args():
    """Parse the command-line arguments for Kolors LoRA training.

    Returns:
        argparse.Namespace with all training configuration values.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")

    # Required paths.
    parser.add_argument("--pretrained_path", type=str, default=None, required=True,
                        help="Path to pretrained model. For example, `models/kolors/Kolors`.")
    parser.add_argument("--dataset_path", type=str, default=None, required=True,
                        help="The path of the Dataset.")
    parser.add_argument("--output_path", type=str, default="./",
                        help="Path to save the model.")

    # Data / preprocessing.
    parser.add_argument("--steps_per_epoch", type=int, default=500,
                        help="Number of steps per epoch.")
    parser.add_argument("--height", type=int, default=1024,
                        help="Image height.")
    parser.add_argument("--width", type=int, default=1024,
                        help="Image width.")
    parser.add_argument("--center_crop", default=False, action="store_true",
                        help=(
                            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
                            " cropped. The images will be resized to the resolution first before cropping."
                        ))
    parser.add_argument("--random_flip", default=False, action="store_true",
                        help="Whether to randomly flip images horizontally")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size (per device) for the training dataloader.")
    parser.add_argument("--dataloader_num_workers", type=int, default=0,
                        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.")

    # Optimization.
    parser.add_argument("--precision", type=str, default="16-mixed", choices=["32", "16", "16-mixed"],
                        help="Training precision")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate.")
    parser.add_argument("--lora_rank", type=int, default=4,
                        help="The dimension of the LoRA update matrices.")
    parser.add_argument("--lora_alpha", type=float, default=4.0,
                        help="The weight of the LoRA update matrices.")
    parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true",
                        help="Whether to use gradient checkpointing.")
    parser.add_argument("--accumulate_grad_batches", type=int, default=1,
                        help="The number of batches in gradient accumulation.")
    parser.add_argument("--training_strategy", type=str, default="auto",
                        choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
                        help="Training strategy")
    parser.add_argument("--max_epochs", type=int, default=1,
                        help="Number of epochs.")

    return parser.parse_args()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # args
    args = parse_args()

    # dataset and data loader
    # steps_per_epoch is multiplied by batch_size so an "epoch" covers
    # steps_per_epoch optimizer steps regardless of batch size.
    dataset = TextImageDataset(
        args.dataset_path,
        steps_per_epoch=args.steps_per_epoch * args.batch_size,
        height=args.height,
        width=args.width,
        center_crop=args.center_crop,
        random_flip=args.random_flip
    )
    train_loader = torch.utils.data.DataLoader(
        dataset,
        shuffle=True,
        batch_size=args.batch_size,
        num_workers=args.dataloader_num_workers
    )

    # model
    # The three weight paths follow the Kolors checkpoint layout documented
    # in this example's README (text_encoder dir, unet and vae safetensors).
    model = LightningModel(
        pretrained_weights=[
            os.path.join(args.pretrained_path, "text_encoder"),
            os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"),
            os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"),
        ],
        # Only "32" selects full precision; "16" and "16-mixed" both load fp16 weights.
        torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
        learning_rate=args.learning_rate,
        lora_rank=args.lora_rank,
        lora_alpha=args.lora_alpha,
        use_gradient_checkpointing=args.use_gradient_checkpointing
    )

    # train
    # save_top_k=-1 keeps a checkpoint for every epoch instead of only the best.
    trainer = pl.Trainer(
        max_epochs=args.max_epochs,
        accelerator="gpu",
        devices="auto",
        precision=args.precision,
        strategy=args.training_strategy,
        default_root_dir=args.output_path,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
    )
    trainer.fit(model=model, train_dataloaders=train_loader)
|
||||
@@ -153,7 +153,7 @@ image = pipe(
|
||||
image.save("image_with_lora.jpg")
|
||||
```
|
||||
|
||||
Prompt:
|
||||
Prompt: a dog is jumping, flowers around the dog, the background is mountains and clouds
|
||||
|
||||
|Without LoRA|With LoRA|
|
||||
|-|-|
|
||||
|
||||
@@ -10,3 +10,4 @@ imageio[ffmpeg]
|
||||
safetensors
|
||||
einops
|
||||
sentencepiece
|
||||
modelscope
|
||||
|
||||
Reference in New Issue
Block a user