support kolors! (#106)

This commit is contained in:
Zhongjie Duan
2024-07-11 21:43:45 +08:00
committed by GitHub
parent 2a4709e572
commit 9c6607f78d
20 changed files with 2510 additions and 281 deletions

View File

@@ -8,6 +8,7 @@ DiffSynth Studio is a Diffusion engine. We have restructured architectures inclu
Until now, DiffSynth Studio has supported the following models:
* [ExVideo](https://huggingface.co/ECNU-CILab/ExVideo-SVD-128f-v1)
* [Kolors](https://huggingface.co/Kwai-Kolors/Kolors)
* [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium)
* [Stable Video Diffusion](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt)
* [Hunyuan-DiT](https://github.com/Tencent/HunyuanDiT)
@@ -85,11 +86,13 @@ Generate high-resolution images, by breaking the limitation of diffusion models!
LoRA fine-tuning is supported in [`examples/train`](./examples/train/).
|Stable Diffusion|Stable Diffusion XL|
|Model|Example|
|-|-|
|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|![2048](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/584186bc-9855-4140-878e-99541f9a757f)|
|Stable Diffusion 3|Hunyuan-DiT|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
|Stable Diffusion|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/6fc84611-8da6-4a1f-8fee-9a34eba3b4a5)|
|Stable Diffusion XL|![1024](https://github.com/Artiprocher/DiffSynth-Studio/assets/35051019/67687748-e738-438c-aee5-96096f09ac90)|
|Stable Diffusion 3|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|
|Kolors|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|
|Hunyuan-DiT|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/60b022c8-df3f-4541-95ab-bf39f2fa8bb5)|
### Toon Shading

View File

@@ -1,4 +1,4 @@
import torch, os
import torch, os, json
from safetensors import safe_open
from typing_extensions import Literal, TypeAlias
from typing import List
@@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .kolors_text_encoder import ChatGLMModel
preset_models_on_huggingface = {
@@ -159,6 +160,20 @@ preset_models_on_modelscope = {
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
],
# Kolors
"Kolors": [
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
],
}
Preset_model_id: TypeAlias = Literal[
"HunyuanDiT",
@@ -184,7 +199,8 @@ Preset_model_id: TypeAlias = Literal[
"IP-Adapter-SD",
"IP-Adapter-SDXL",
"StableDiffusion3",
"StableDiffusion3_without_T5"
"StableDiffusion3_without_T5",
"Kolors",
]
Preset_model_website: TypeAlias = Literal[
"HuggingFace",
@@ -272,8 +288,7 @@ class ModelManager:
def is_controlnet(self, state_dict):
param_name = "control_model.time_embed.0.weight"
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
return param_name in state_dict or param_name_2 in state_dict
return param_name in state_dict
def is_animatediff(self, state_dict):
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
@@ -343,6 +358,21 @@ class ModelManager:
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
return param_name in state_dict
def is_kolors_text_encoder(self, file_path):
file_list = os.listdir(file_path)
if "config.json" in file_list:
try:
with open(os.path.join(file_path, "config.json"), "r") as f:
config = json.load(f)
if config.get("model_type") == "chatglm":
return True
except:
pass
return False
def is_kolors_unet(self, state_dict):
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
component_dict = {
"image_encoder": SVDImageEncoder,
@@ -532,13 +562,13 @@ class ModelManager:
component = "vae_encoder"
model = SDXLVAEEncoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
component = "vae_decoder"
model = SDXLVAEDecoder()
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
model.to(torch.float32).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
@@ -592,6 +622,21 @@ class ModelManager:
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
component = "kolors_text_encoder"
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
model = model.to(dtype=self.torch_dtype, device=self.device)
self.model[component] = model
self.model_path[component] = file_path
def load_kolors_unet(self, state_dict, file_path=""):
component = "kolors_unet"
model = SDXLUNet(is_kolors=True)
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
model.to(self.torch_dtype).to(self.device)
self.model[component] = model
self.model_path[component] = file_path
def search_for_embeddings(self, state_dict):
embeddings = []
for k in state_dict:
@@ -607,7 +652,11 @@ class ModelManager:
# Load every textual inversion file
for file_name in os.listdir(folder):
if file_name.endswith(".txt"):
if os.path.isdir(os.path.join(folder, file_name)) or \
not (file_name.endswith(".bin") or \
file_name.endswith(".safetensors") or \
file_name.endswith(".pth") or \
file_name.endswith(".pt")):
continue
keyword = os.path.splitext(file_name)[0]
state_dict = load_state_dict(os.path.join(folder, file_name))
@@ -620,6 +669,10 @@ class ModelManager:
break
def load_model(self, file_path, components=None, lora_alphas=[]):
if os.path.isdir(file_path):
if self.is_kolors_text_encoder(file_path):
self.load_kolors_text_encoder(file_path=file_path)
return
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
if self.is_stable_video_diffusion(state_dict):
self.load_stable_video_diffusion(state_dict, file_path=file_path)
@@ -663,6 +716,8 @@ class ModelManager:
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
elif self.is_stable_diffusion_3_t5(state_dict):
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
elif self.is_kolors_unet(state_dict):
self.load_kolors_unet(state_dict, file_path=file_path)
def load_models(self, file_path_list, lora_alphas=[]):
for file_path in file_path_list:

View File

@@ -1,251 +1,6 @@
from huggingface_hub import hf_hub_download
from http.cookiejar import CookieJar
from pathlib import Path
from typing import Dict, Optional, List, Union
import copy, uuid, requests, io, platform, pickle, os, urllib
from requests.adapters import Retry
from tqdm import tqdm
def _get_sep(path):
if isinstance(path, bytes):
return b'/'
else:
return '/'
def expanduser(path):
"""Expand ~ and ~user constructions. If user or $HOME is unknown,
do nothing."""
path = os.fspath(path)
if isinstance(path, bytes):
tilde = b'~'
else:
tilde = '~'
if not path.startswith(tilde):
return path
sep = _get_sep(path)
i = path.find(sep, 1)
if i < 0:
i = len(path)
if i == 1:
if 'HOME' not in os.environ:
import pwd
try:
userhome = pwd.getpwuid(os.getuid()).pw_dir
except KeyError:
# bpo-10496: if the current user identifier doesn't exist in the
# password database, return the path unchanged
return path
else:
userhome = os.environ['HOME']
else:
import pwd
name = path[1:i]
if isinstance(name, bytes):
name = str(name, 'ASCII')
try:
pwent = pwd.getpwnam(name)
except KeyError:
# bpo-10496: if the user name from the path doesn't exist in the
# password database, return the path unchanged
return path
userhome = pwent.pw_dir
if isinstance(path, bytes):
userhome = os.fsencode(userhome)
root = b'/'
else:
root = '/'
userhome = userhome.rstrip(root)
return (userhome + path[i:]) or root
class ModelScopeConfig:
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
COOKIES_FILE_NAME = 'cookies'
GIT_TOKEN_FILE_NAME = 'git_token'
USER_INFO_FILE_NAME = 'user'
USER_SESSION_ID_FILE_NAME = 'session'
@staticmethod
def make_sure_credential_path_exist():
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
@staticmethod
def get_user_session_id():
session_path = os.path.join(ModelScopeConfig.path_credential,
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
session_id = ''
if os.path.exists(session_path):
with open(session_path, 'rb') as f:
session_id = str(f.readline().strip(), encoding='utf-8')
return session_id
if session_id == '' or len(session_id) != 32:
session_id = str(uuid.uuid4().hex)
ModelScopeConfig.make_sure_credential_path_exist()
with open(session_path, 'w+') as wf:
wf.write(session_id)
return session_id
@staticmethod
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
"""Formats a user-agent string with basic info about a request.
Args:
user_agent (`str`, `dict`, *optional*):
The user agent info in the form of a dictionary or a single string.
Returns:
The formatted user-agent string.
"""
# include some more telemetrics when executing in dedicated
# cloud containers
MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
env = 'custom'
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
user_name = 'unknown'
if MODELSCOPE_CLOUD_USERNAME in os.environ:
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
"1.15.0",
platform.python_version(),
ModelScopeConfig.get_user_session_id(),
platform.platform(),
platform.processor(),
env,
user_name,
)
if isinstance(user_agent, dict):
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
elif isinstance(user_agent, str):
ua += '; ' + user_agent
return ua
@staticmethod
def get_cookies():
cookies_path = os.path.join(ModelScopeConfig.path_credential,
ModelScopeConfig.COOKIES_FILE_NAME)
if os.path.exists(cookies_path):
with open(cookies_path, 'rb') as f:
cookies = pickle.load(f)
return cookies
return None
def modelscope_http_get_model_file(
url: str,
local_dir: str,
file_name: str,
file_size: int,
cookies: CookieJar,
headers: Optional[Dict[str, str]] = None,
):
"""Download remote file, will retry 5 times before giving up on errors.
Args:
url(str):
actual download url of the file
local_dir(str):
local directory where the downloaded file stores
file_name(str):
name of the file stored in `local_dir`
file_size(int):
The file size.
cookies(CookieJar):
cookies used to authentication the user, which is used for downloading private repos
headers(Dict[str, str], optional):
http headers to carry necessary info when requesting the remote file
Raises:
FileDownloadError: File download failed.
"""
get_headers = {} if headers is None else copy.deepcopy(headers)
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
temp_file_path = os.path.join(local_dir, file_name)
# retry sleep 0.5s, 1s, 2s, 4s
retry = Retry(
total=5,
backoff_factor=1,
allowed_methods=['GET'])
while True:
try:
progress = tqdm(
unit='B',
unit_scale=True,
unit_divisor=1024,
total=file_size,
initial=0,
desc='Downloading',
)
partial_length = 0
if os.path.exists(
temp_file_path): # download partial, continue download
with open(temp_file_path, 'rb') as f:
partial_length = f.seek(0, io.SEEK_END)
progress.update(partial_length)
if partial_length > file_size:
break
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
file_size - 1)
with open(temp_file_path, 'ab') as f:
r = requests.get(
url,
stream=True,
headers=get_headers,
cookies=cookies,
timeout=60)
r.raise_for_status()
for chunk in r.iter_content(
chunk_size=1024 * 1024 * 1):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
f.write(chunk)
progress.close()
break
except (Exception) as e: # no matter what happen, we will retry.
retry = retry.increment('GET', url, error=e)
retry.sleep()
def get_endpoint():
MODELSCOPE_URL_SCHEME = 'https://'
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
DEFAULT_MODELSCOPE_DOMAIN)
return MODELSCOPE_URL_SCHEME + modelscope_domain
def get_file_download_url(model_id: str, file_path: str, revision: str):
"""Format file download url according to `model_id`, `revision` and `file_path`.
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
Args:
model_id (str): The model_id.
file_path (str): File path
revision (str): File revision.
Returns:
str: The file url.
"""
file_path = urllib.parse.quote_plus(file_path)
revision = urllib.parse.quote_plus(revision)
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
return download_url_template.format(
endpoint=get_endpoint(),
model_id=model_id,
revision=revision,
file_path=file_path,
)
from modelscope import snapshot_download
import os, shutil
def download_from_modelscope(model_id, origin_file_path, local_dir):
@@ -255,17 +10,12 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
return
else:
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
cookies = ModelScopeConfig.get_cookies()
url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
modelscope_http_get_model_file(
url,
local_dir,
os.path.basename(origin_file_path),
file_size=0,
headers=headers,
cookies=cookies
)
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
downloaded_file_path = os.path.join(local_dir, origin_file_path)
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
if downloaded_file_path != target_file_path:
shutil.move(downloaded_file_path, target_file_path)
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
def download_from_huggingface(model_id, origin_file_path, local_dir):

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
class SDXLUNet(torch.nn.Module):
def __init__(self):
def __init__(self, is_kolors=False):
super().__init__()
self.time_proj = Timesteps(320)
self.time_embedding = torch.nn.Sequential(
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
)
self.add_time_proj = Timesteps(256)
self.add_time_embedding = torch.nn.Sequential(
torch.nn.Linear(2816, 1280),
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
torch.nn.SiLU(),
torch.nn.Linear(1280, 1280)
)
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
self.blocks = torch.nn.ModuleList([
# DownBlock2D
@@ -85,7 +86,9 @@ class SDXLUNet(torch.nn.Module):
def forward(
self,
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
tiled=False, tile_size=64, tile_stride=8, **kwargs
tiled=False, tile_size=64, tile_stride=8,
use_gradient_checkpointing=False,
**kwargs
):
# 1. time
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
@@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module):
# 2. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = self.conv_in(sample)
text_emb = encoder_hidden_states
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
res_stack = [hidden_states]
# 3. blocks
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
for i, block in enumerate(self.blocks):
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, time_emb, text_emb, res_stack,
use_reentrant=False,
)
else:
hidden_states, time_emb, text_emb, res_stack = block(
hidden_states, time_emb, text_emb, res_stack,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
)
# 4. output
hidden_states = self.conv_norm_out(hidden_states)
@@ -148,6 +162,8 @@ class SDXLUNetStateDictConverter:
names = name.split(".")
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
pass
elif names[0] in ["encoder_hid_proj"]:
names[0] = "text_intermediate_proj"
elif names[0] in ["time_embedding", "add_embedding"]:
if names[0] == "add_embedding":
names[0] = "add_time_embedding"

View File

@@ -5,3 +5,4 @@ from .stable_diffusion_xl_video import SDXLVideoPipeline
from .stable_video_diffusion import SVDVideoPipeline
from .hunyuan_dit import HunyuanDiTImagePipeline
from .stable_diffusion_3 import SD3ImagePipeline
from .kwai_kolors import KolorsImagePipeline

View File

@@ -147,7 +147,7 @@ def lets_dance_xl(
# 3. pre-process
height, width = sample.shape[2], sample.shape[3]
hidden_states = unet.conv_in(sample)
text_emb = encoder_hidden_states
text_emb = encoder_hidden_states if unet.text_intermediate_proj is None else unet.text_intermediate_proj(encoder_hidden_states)
res_stack = [hidden_states]
# 4. blocks

View File

@@ -216,7 +216,7 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
# Prepare latent tensors
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
image = self.preprocess_image(input_image).to(device=self.device, dtype=torch.float32)
latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
@@ -293,6 +293,6 @@ class HunyuanDiTImagePipeline(torch.nn.Module):
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -0,0 +1,168 @@
from ..models import ModelManager, SDXLUNet, SDXLVAEDecoder, SDXLVAEEncoder, SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from ..models.kolors_text_encoder import ChatGLMModel
from ..prompts import KolorsPrompter
from ..schedulers import EnhancedDDIMScheduler
from .dancer import lets_dance_xl
import torch
from tqdm import tqdm
from PIL import Image
import numpy as np
class KolorsImagePipeline(torch.nn.Module):
def __init__(self, device="cuda", torch_dtype=torch.float16):
super().__init__()
self.scheduler = EnhancedDDIMScheduler(beta_end=0.014, num_train_timesteps=1100)
self.prompter = KolorsPrompter()
self.device = device
self.torch_dtype = torch_dtype
# models
self.text_encoder: ChatGLMModel = None
self.unet: SDXLUNet = None
self.vae_decoder: SDXLVAEDecoder = None
self.vae_encoder: SDXLVAEEncoder = None
self.ipadapter_image_encoder: IpAdapterXLCLIPImageEmbedder = None
self.ipadapter: SDXLIpAdapter = None
def fetch_main_models(self, model_manager: ModelManager):
self.text_encoder = model_manager.kolors_text_encoder
self.unet = model_manager.kolors_unet
self.vae_decoder = model_manager.vae_decoder
self.vae_encoder = model_manager.vae_encoder
def fetch_ipadapter(self, model_manager: ModelManager):
if "ipadapter_xl" in model_manager.model:
self.ipadapter = model_manager.ipadapter_xl
if "ipadapter_xl_image_encoder" in model_manager.model:
self.ipadapter_image_encoder = model_manager.ipadapter_xl_image_encoder
def fetch_prompter(self, model_manager: ModelManager):
self.prompter.load_from_model_manager(model_manager)
@staticmethod
def from_model_manager(model_manager: ModelManager):
pipe = KolorsImagePipeline(
device=model_manager.device,
torch_dtype=model_manager.torch_dtype,
)
pipe.fetch_main_models(model_manager)
pipe.fetch_prompter(model_manager)
pipe.fetch_ipadapter(model_manager)
return pipe
def preprocess_image(self, image):
image = torch.Tensor(np.array(image, dtype=np.float32) * (2 / 255) - 1).permute(2, 0, 1).unsqueeze(0)
return image
def decode_image(self, latent, tiled=False, tile_size=64, tile_stride=32):
image = self.vae_decoder(latent.to(self.device), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
image = image.cpu().permute(1, 2, 0).numpy()
image = Image.fromarray(((image / 2 + 0.5).clip(0, 1) * 255).astype("uint8"))
return image
@torch.no_grad()
def __call__(
self,
prompt,
negative_prompt="",
cfg_scale=7.5,
clip_skip=2,
input_image=None,
ipadapter_images=None,
ipadapter_scale=1.0,
ipadapter_use_instant_style=False,
denoising_strength=1.0,
height=1024,
width=1024,
num_inference_steps=20,
tiled=False,
tile_size=64,
tile_stride=32,
progress_bar_cmd=tqdm,
progress_bar_st=None,
):
# Prepare scheduler
self.scheduler.set_timesteps(num_inference_steps, denoising_strength)
# Prepare latent tensors
if input_image is not None:
image = self.preprocess_image(input_image).to(device=self.device, dtype=self.torch_dtype)
latents = self.vae_encoder(image.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(self.torch_dtype)
noise = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
latents = self.scheduler.add_noise(latents, noise, timestep=self.scheduler.timesteps[0])
else:
latents = torch.randn((1, 4, height//8, width//8), device=self.device, dtype=self.torch_dtype)
# Encode prompts
add_prompt_emb_posi, prompt_emb_posi = self.prompter.encode_prompt(
self.text_encoder,
prompt,
clip_skip=clip_skip,
device=self.device,
positive=True,
)
if cfg_scale != 1.0:
add_prompt_emb_nega, prompt_emb_nega = self.prompter.encode_prompt(
self.text_encoder,
negative_prompt,
clip_skip=clip_skip,
device=self.device,
positive=False,
)
# Prepare positional id
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
# IP-Adapter
if ipadapter_images is not None:
if ipadapter_use_instant_style:
self.ipadapter.set_less_adapter()
else:
self.ipadapter.set_full_adapter()
ipadapter_image_encoding = self.ipadapter_image_encoder(ipadapter_images)
ipadapter_kwargs_list_posi = self.ipadapter(ipadapter_image_encoding, scale=ipadapter_scale)
ipadapter_kwargs_list_nega = self.ipadapter(torch.zeros_like(ipadapter_image_encoding))
else:
ipadapter_kwargs_list_posi, ipadapter_kwargs_list_nega = {}, {}
# Denoise
for progress_id, timestep in enumerate(progress_bar_cmd(self.scheduler.timesteps)):
timestep = torch.IntTensor((timestep,))[0].to(self.device)
# Classifier-free guidance
noise_pred_posi = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_posi,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_posi,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_posi,
)
if cfg_scale != 1.0:
noise_pred_nega = lets_dance_xl(
self.unet,
sample=latents, timestep=timestep, encoder_hidden_states=prompt_emb_nega,
add_time_id=add_time_id, add_text_embeds=add_prompt_emb_nega,
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
ipadapter_kwargs_list=ipadapter_kwargs_list_nega,
)
noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
else:
noise_pred = noise_pred_posi
latents = self.scheduler.step(noise_pred, timestep, latents)
if progress_bar_st is not None:
progress_bar_st.progress(progress_id / len(self.scheduler.timesteps))
# Decode image
image = self.decode_image(latents.to(torch.float32), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return image

View File

@@ -2,3 +2,4 @@ from .sd_prompter import SDPrompter
from .sdxl_prompter import SDXLPrompter
from .sd3_prompter import SD3Prompter
from .hunyuan_dit_prompter import HunyuanDiTPrompter
from .kolors_prompter import KolorsPrompter

View File

@@ -0,0 +1,347 @@
from .utils import Prompter
import json, os, re
from typing import List, Optional, Union, Dict
from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer
from transformers.utils import PaddingStrategy
from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
from ..models.kolors_text_encoder import ChatGLMModel
class SPTokenizer:
def __init__(self, model_path: str):
# reload tokenizer
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)
# BOS / EOS token IDs
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.unk_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
self.special_tokens = {}
self.index_special_tokens = {}
for token in special_tokens:
self.special_tokens[token] = self.n_words
self.index_special_tokens[self.n_words] = token
self.n_words += 1
self.role_special_token_expression = "|".join([re.escape(token) for token in role_special_tokens])
def tokenize(self, s: str, encode_special_tokens=False):
if encode_special_tokens:
last_index = 0
t = []
for match in re.finditer(self.role_special_token_expression, s):
if last_index < match.start():
t.extend(self.sp_model.EncodeAsPieces(s[last_index:match.start()]))
t.append(s[match.start():match.end()])
last_index = match.end()
if last_index < len(s):
t.extend(self.sp_model.EncodeAsPieces(s[last_index:]))
return t
else:
return self.sp_model.EncodeAsPieces(s)
def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
assert type(s) is str
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t
def decode(self, t: List[int]) -> str:
text, buffer = "", []
for token in t:
if token in self.index_special_tokens:
if buffer:
text += self.sp_model.decode(buffer)
buffer = []
text += self.index_special_tokens[token]
else:
buffer.append(token)
if buffer:
text += self.sp_model.decode(buffer)
return text
def decode_tokens(self, tokens: List[str]) -> str:
text = self.sp_model.DecodePieces(tokens)
return text
def convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
if token in self.special_tokens:
return self.special_tokens[token]
return self.sp_model.PieceToId(token)
def convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
if index in self.index_special_tokens:
return self.index_special_tokens[index]
if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
return ""
return self.sp_model.IdToPiece(index)
class ChatGLMTokenizer(PreTrainedTokenizer):
vocab_files_names = {"vocab_file": "tokenizer.model"}
model_input_names = ["input_ids", "attention_mask", "position_ids"]
def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
**kwargs):
self.name = "GLMTokenizer"
self.vocab_file = vocab_file
self.tokenizer = SPTokenizer(vocab_file)
self.special_tokens = {
"<bos>": self.tokenizer.bos_id,
"<eos>": self.tokenizer.eos_id,
"<pad>": self.tokenizer.pad_id
}
self.encode_special_tokens = encode_special_tokens
super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
encode_special_tokens=encode_special_tokens,
**kwargs)
def get_command(self, token):
if token in self.special_tokens:
return self.special_tokens[token]
assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
return self.tokenizer.special_tokens[token]
@property
def unk_token(self) -> str:
return "<unk>"
@property
def pad_token(self) -> str:
return "<unk>"
@property
def pad_token_id(self):
return self.get_command("<pad>")
@property
def eos_token(self) -> str:
return "</s>"
@property
def eos_token_id(self):
return self.get_command("<eos>")
@property
def vocab_size(self):
return self.tokenizer.n_words
def get_vocab(self):
""" Returns vocab as a dict """
vocab = {self._convert_id_to_token(i): i for i in range(self.vocab_size)}
vocab.update(self.added_tokens_encoder)
return vocab
def _tokenize(self, text, **kwargs):
return self.tokenizer.tokenize(text, encode_special_tokens=self.encode_special_tokens)
def _convert_token_to_id(self, token):
""" Converts a token (str) in an id using the vocab. """
return self.tokenizer.convert_token_to_id(token)
def _convert_id_to_token(self, index):
"""Converts an index (integer) in a token (str) using the vocab."""
return self.tokenizer.convert_id_to_token(index)
def convert_tokens_to_string(self, tokens: List[str]) -> str:
return self.tokenizer.decode_tokens(tokens)
def save_vocabulary(self, save_directory, filename_prefix=None):
"""
Save the vocabulary and special tokens file to a directory.
Args:
save_directory (`str`):
The directory in which to save the vocabulary.
filename_prefix (`str`, *optional*):
An optional prefix to add to the named of the saved files.
Returns:
`Tuple(str)`: Paths to the files saved.
"""
if os.path.isdir(save_directory):
vocab_file = os.path.join(
save_directory, self.vocab_files_names["vocab_file"]
)
else:
vocab_file = save_directory
with open(self.vocab_file, 'rb') as fin:
proto_str = fin.read()
with open(vocab_file, "wb") as writer:
writer.write(proto_str)
return (vocab_file,)
def get_prefix_tokens(self):
prefix_tokens = [self.get_command("[gMASK]"), self.get_command("sop")]
return prefix_tokens
def build_single_message(self, role, metadata, message):
assert role in ["system", "user", "assistant", "observation"], role
role_tokens = [self.get_command(f"<|{role}|>")] + self.tokenizer.encode(f"{metadata}\n")
message_tokens = self.tokenizer.encode(message)
tokens = role_tokens + message_tokens
return tokens
def build_chat_input(self, query, history=None, role="user"):
if history is None:
history = []
input_ids = []
for item in history:
content = item["content"]
if item["role"] == "system" and "tools" in item:
content = content + "\n" + json.dumps(item["tools"], indent=4, ensure_ascii=False)
input_ids.extend(self.build_single_message(item["role"], item.get("metadata", ""), content))
input_ids.extend(self.build_single_message(role, "", query))
input_ids.extend([self.get_command("<|assistant|>")])
return self.batch_encode_plus([input_ids], return_tensors="pt", is_split_into_words=True)
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
adding special tokens. A BERT sequence has the following format:
- single sequence: `[CLS] X [SEP]`
- pair of sequences: `[CLS] A [SEP] B [SEP]`
Args:
token_ids_0 (`List[int]`):
List of IDs to which the special tokens will be added.
token_ids_1 (`List[int]`, *optional*):
Optional second list of IDs for sequence pairs.
Returns:
`List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
"""
prefix_tokens = self.get_prefix_tokens()
token_ids_0 = prefix_tokens + token_ids_0
if token_ids_1 is not None:
token_ids_0 = token_ids_0 + token_ids_1 + [self.get_command("<eos>")]
return token_ids_0
def _pad(
self,
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
max_length: Optional[int] = None,
padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
pad_to_multiple_of: Optional[int] = None,
return_attention_mask: Optional[bool] = None,
) -> dict:
"""
Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
Args:
encoded_inputs:
Dictionary of tokenized inputs (`List[int]`) or batch of tokenized inputs (`List[List[int]]`).
max_length: maximum length of the returned list and optionally padding length (see below).
Will truncate by taking into account the special tokens.
padding_strategy: PaddingStrategy to use for padding.
- PaddingStrategy.LONGEST Pad to the longest sequence in the batch
- PaddingStrategy.MAX_LENGTH: Pad to the max length (default)
- PaddingStrategy.DO_NOT_PAD: Do not pad
The tokenizer padding sides are defined in self.padding_side:
- 'left': pads on the left of the sequences
- 'right': pads on the right of the sequences
pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
`>= 7.5` (Volta).
return_attention_mask:
(optional) Set to False to avoid returning attention mask (default: set to model specifics)
"""
# Load from model defaults
assert self.padding_side == "left"
required_input = encoded_inputs[self.model_input_names[0]]
seq_length = len(required_input)
if padding_strategy == PaddingStrategy.LONGEST:
max_length = len(required_input)
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
# Initialize attention mask if not present.
if "attention_mask" not in encoded_inputs:
encoded_inputs["attention_mask"] = [1] * seq_length
if "position_ids" not in encoded_inputs:
encoded_inputs["position_ids"] = list(range(seq_length))
if needs_to_be_padded:
difference = max_length - len(required_input)
if "attention_mask" in encoded_inputs:
encoded_inputs["attention_mask"] = [0] * difference + encoded_inputs["attention_mask"]
if "position_ids" in encoded_inputs:
encoded_inputs["position_ids"] = [0] * difference + encoded_inputs["position_ids"]
encoded_inputs[self.model_input_names[0]] = [self.pad_token_id] * difference + required_input
return encoded_inputs
class KolorsPrompter(Prompter):
def __init__(
self,
tokenizer_path=None
):
if tokenizer_path is None:
base_path = os.path.dirname(os.path.dirname(__file__))
tokenizer_path = os.path.join(base_path, "tokenizer_configs/kolors/tokenizer")
super().__init__()
self.tokenizer = ChatGLMTokenizer.from_pretrained(tokenizer_path)
def encode_prompt_using_ChatGLM(self, prompt, text_encoder, tokenizer, max_length, clip_skip, device):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
).to(device)
output = text_encoder(
input_ids=text_inputs['input_ids'] ,
attention_mask=text_inputs['attention_mask'],
position_ids=text_inputs['position_ids'],
output_hidden_states=True
)
prompt_emb = output.hidden_states[-clip_skip].permute(1, 0, 2).clone()
pooled_prompt_emb = output.hidden_states[-1][-1, :, :].clone()
return prompt_emb, pooled_prompt_emb
def encode_prompt(
self,
text_encoder: ChatGLMModel,
prompt,
clip_skip=2,
positive=True,
device="cuda"
):
prompt = self.process_prompt(prompt, positive=positive)
prompt_emb, pooled_prompt_emb = self.encode_prompt_using_ChatGLM(prompt, text_encoder, self.tokenizer, 256, clip_skip, device)
return pooled_prompt_emb, prompt_emb

View File

@@ -0,0 +1,12 @@
{
"name_or_path": "THUDM/chatglm3-6b-base",
"remove_space": false,
"do_lower_case": false,
"tokenizer_class": "ChatGLMTokenizer",
"auto_map": {
"AutoTokenizer": [
"tokenization_chatglm.ChatGLMTokenizer",
null
]
}
}

Binary file not shown.

View File

@@ -28,6 +28,16 @@ LoRA Training: [`../train/stable_diffusion_3/`](../train/stable_diffusion_3/)
|-|-|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/4df346db-6f91-420a-b4c1-26e205376098)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/1386c802-e580-4101-939d-f1596802df9d)|
### Example: Kolors
Example script: [`kolors_text_to_image.py`](./kolors_text_to_image.py)
LoRA Training: [`../train/kolors/`](../train/kolors/)
|1024*1024|2048*2048|
|-|-|
|![image_1024](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/53ef6f41-da11-4701-8665-9f64392607bf)|![image_2048](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/66bb7a75-fe31-44e5-90eb-d3140ee4686d)|
### Example: Hunyuan-DiT
Example script: [`hunyuan_dit_text_to_image.py`](./hunyuan_dit_text_to_image.py)

View File

@@ -0,0 +1,34 @@
from diffsynth import ModelManager, KolorsImagePipeline, download_models
import torch
# Download models
# https://huggingface.co/Kwai-Kolors/Kolors
download_models(["Kolors"])
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
])
pipe = KolorsImagePipeline.from_model_manager(model_manager)
prompt = "一幅充满诗意美感的全身画,泛红的肤色,画中一位银色长发、蓝色眼睛、肤色红润、身穿蓝色吊带连衣裙的少女漂浮在水下,面向镜头,周围是光彩的气泡,和煦的阳光透过水面折射进水下"
negative_prompt = "半身,苍白的肤色,蜡黄的肤色,尸体,错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,错误的手指,口红,腮红"
torch.manual_seed(7)
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
num_inference_steps=50,
cfg_scale=4,
)
image.save(f"image_1024.jpg")
image = pipe(
prompt=prompt,
negative_prompt=negative_prompt,
input_image=image.resize((2048, 2048)), denoising_strength=0.4, height=2048, width=2048,
num_inference_steps=50,
cfg_scale=4,
)
image.save("image_2048.jpg")

View File

@@ -0,0 +1,175 @@
# Kolors
Kolors is a Chinese diffusion model, which is based on ChatGLM and Stable Diffusion XL. We provide training scripts here.
## Download models
The following files will be used for constructing Kolors. You can download them from [huggingface](https://huggingface.co/Kwai-Kolors/Kolors) or [modelscope](https://modelscope.cn/models/Kwai-Kolors/Kolors).
```
models/kolors/Kolors
├── text_encoder
│ ├── config.json
│ ├── pytorch_model-00001-of-00007.bin
│ ├── pytorch_model-00002-of-00007.bin
│ ├── pytorch_model-00003-of-00007.bin
│ ├── pytorch_model-00004-of-00007.bin
│ ├── pytorch_model-00005-of-00007.bin
│ ├── pytorch_model-00006-of-00007.bin
│ ├── pytorch_model-00007-of-00007.bin
│ └── pytorch_model.bin.index.json
├── unet
│ └── diffusion_pytorch_model.safetensors
└── vae
└── diffusion_pytorch_model.safetensors
```
You can use the following code to download these files:
```python
from diffsynth import download_models
download_models(["Kolors"])
```
## Train
### Install training dependency
```
pip install peft lightning pandas torchvision
```
### Prepare your dataset
We provide an example dataset [here](https://modelscope.cn/datasets/buptwq/lora-stable-diffusion-finetune/files). You need to manage the training images as follows:
```
data/dog/
└── train
├── 00.jpg
├── 01.jpg
├── 02.jpg
├── 03.jpg
├── 04.jpg
└── metadata.csv
```
`metadata.csv`:
```
file_name,text
00.jpg,一只小狗
01.jpg,一只小狗
02.jpg,一只小狗
03.jpg,一只小狗
04.jpg,一只小狗
```
### Train a LoRA model
We provide a training script `train_kolors_lora.py`. Before you run this training script, please copy it to the root directory of this project.
The following settings are recommended. **We found the UNet model suffers from precision overflow issues, thus the training script doesn't support float16. 40GB VRAM is required. We are working on overcoming this pitfall.**
```
CUDA_VISIBLE_DEVICES="0" python examples/train/kolors/train_kolors_lora.py \
--pretrained_path models/kolors/Kolors \
--dataset_path data/dog \
--output_path ./models \
--max_epochs 10 \
--center_crop \
--use_gradient_checkpointing \
--precision 32
```
Optional arguments:
```
-h, --help show this help message and exit
--pretrained_path PRETRAINED_PATH
Path to pretrained model. For example, `models/kolors/Kolors`.
--dataset_path DATASET_PATH
The path of the Dataset.
--output_path OUTPUT_PATH
Path to save the model.
--steps_per_epoch STEPS_PER_EPOCH
Number of steps per epoch.
--height HEIGHT Image height.
--width WIDTH Image width.
--center_crop Whether to center crop the input images to the resolution. If not set, the images will be randomly cropped. The images will be resized to the resolution first before cropping.
--random_flip Whether to randomly flip images horizontally
--batch_size BATCH_SIZE
Batch size (per device) for the training dataloader.
--dataloader_num_workers DATALOADER_NUM_WORKERS
Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.
--precision {32,16,16-mixed}
Training precision
--learning_rate LEARNING_RATE
Learning rate.
--lora_rank LORA_RANK
The dimension of the LoRA update matrices.
--lora_alpha LORA_ALPHA
The weight of the LoRA update matrices.
--use_gradient_checkpointing
Whether to use gradient checkpointing.
--accumulate_grad_batches ACCUMULATE_GRAD_BATCHES
The number of batches in gradient accumulation.
--training_strategy {auto,deepspeed_stage_1,deepspeed_stage_2,deepspeed_stage_3}
Training strategy
--max_epochs MAX_EPOCHS
Number of epochs.
```
### Inference with your own LoRA model
After training, you can use your own LoRA model to generate new images. Here are some examples.
```python
from diffsynth import ModelManager, KolorsImagePipeline
from peft import LoraConfig, inject_adapter_in_model
import torch
def load_lora(model, lora_rank, lora_alpha, lora_path):
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights="gaussian",
target_modules=["to_q", "to_k", "to_v", "to_out"],
)
model = inject_adapter_in_model(lora_config, model)
state_dict = torch.load(lora_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
return model
# Load models
model_manager = ModelManager(torch_dtype=torch.float16, device="cuda",
file_path_list=[
"models/kolors/Kolors/text_encoder",
"models/kolors/Kolors/unet/diffusion_pytorch_model.safetensors",
"models/kolors/Kolors/vae/diffusion_pytorch_model.safetensors"
])
pipe = KolorsImagePipeline.from_model_manager(model_manager)
# Generate an image with lora
pipe.unet = load_lora(
pipe.unet,
lora_rank=4, lora_alpha=4.0, # The two parameters should be consistent with those in your training script.
lora_path="path/to/your/lora/model/lightning_logs/version_x/checkpoints/epoch=x-step=xxx.ckpt"
)
torch.manual_seed(0)
image = pipe(
prompt="一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉",
negative_prompt="",
cfg_scale=4,
num_inference_steps=50, height=1024, width=1024,
)
image.save("image_with_lora.jpg")
```
Prompt: 一只小狗蹦蹦跳跳,周围是姹紫嫣红的鲜花,远处是山脉
|Without LoRA|With LoRA|
|-|-|
|![image_without_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/9d79ed7a-e8cf-4d98-800a-f182809db318)|![image_with_lora](https://github.com/modelscope/DiffSynth-Studio/assets/35051019/02f62323-6ee5-4788-97a1-549732dbe4f0)|

View File

@@ -0,0 +1,293 @@
from diffsynth import ModelManager, KolorsImagePipeline
from peft import LoraConfig, inject_adapter_in_model
from torchvision import transforms
from PIL import Image
import lightning as pl
import pandas as pd
import torch, os, argparse
os.environ["TOKENIZERS_PARALLELISM"] = "True"
class TextImageDataset(torch.utils.data.Dataset):
def __init__(self, dataset_path, steps_per_epoch=10000, height=1024, width=1024, center_crop=True, random_flip=False):
self.steps_per_epoch = steps_per_epoch
metadata = pd.read_csv(os.path.join(dataset_path, "train/metadata.csv"))
self.path = [os.path.join(dataset_path, "train", file_name) for file_name in metadata["file_name"]]
self.text = metadata["text"].to_list()
self.image_processor = transforms.Compose(
[
transforms.Resize(max(height, width), interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop((height, width)) if center_crop else transforms.RandomCrop((height, width)),
transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __getitem__(self, index):
data_id = torch.randint(0, len(self.path), (1,))[0]
data_id = (data_id + index) % len(self.path) # For fixed seed.
text = self.text[data_id]
image = Image.open(self.path[data_id]).convert("RGB")
image = self.image_processor(image)
return {"text": text, "image": image}
def __len__(self):
return self.steps_per_epoch
class LightningModel(pl.LightningModule):
def __init__(self, torch_dtype=torch.float16, learning_rate=1e-4, pretrained_weights=[], lora_rank=4, lora_alpha=4, use_gradient_checkpointing=True):
super().__init__()
# Load models
model_manager = ModelManager(torch_dtype=torch_dtype, device=self.device)
model_manager.load_models(pretrained_weights)
self.pipe = KolorsImagePipeline.from_model_manager(model_manager)
# Freeze parameters
self.pipe.text_encoder.requires_grad_(False)
self.pipe.unet.requires_grad_(False)
self.pipe.vae_decoder.requires_grad_(False)
self.pipe.vae_encoder.requires_grad_(False)
self.pipe.text_encoder.eval()
self.pipe.unet.train()
self.pipe.vae_decoder.eval()
self.pipe.vae_encoder.eval()
# Add LoRA to UNet
lora_config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
init_lora_weights="gaussian",
target_modules=["to_q", "to_k", "to_v", "to_out"],
)
self.pipe.unet = inject_adapter_in_model(lora_config, self.pipe.unet)
for param in self.pipe.unet.parameters():
# Upcast LoRA parameters into fp32
if param.requires_grad:
param.data = param.to(torch.float32)
# Set other parameters
self.learning_rate = learning_rate
self.use_gradient_checkpointing = use_gradient_checkpointing
self.pipe.scheduler.set_timesteps(1100)
def training_step(self, batch, batch_idx):
# Data
text, image = batch["text"], batch["image"]
# Prepare input parameters
self.pipe.device = self.device
add_prompt_emb, prompt_emb = self.pipe.prompter.encode_prompt(
self.pipe.text_encoder, text, clip_skip=2, device=self.device, positive=True,
)
height, width = image.shape[-2:]
latents = self.pipe.vae_encoder(image.to(dtype=torch.float32, device=self.device)).to(self.pipe.torch_dtype)
noise = torch.randn_like(latents)
timestep = torch.randint(0, 1100, (1,), device=self.device)[0]
add_time_id = torch.tensor([height, width, 0, 0, height, width], device=self.device)
noisy_latents = self.pipe.scheduler.add_noise(latents, noise, timestep)
# Compute loss
noise_pred = self.pipe.unet(
noisy_latents, timestep, prompt_emb, add_time_id, add_prompt_emb,
use_gradient_checkpointing=self.use_gradient_checkpointing
)
loss = torch.nn.functional.mse_loss(noise_pred, noise)
# Record log
self.log("train_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
trainable_modules = filter(lambda p: p.requires_grad, self.pipe.unet.parameters())
optimizer = torch.optim.AdamW(trainable_modules, lr=self.learning_rate)
return optimizer
def on_save_checkpoint(self, checkpoint):
checkpoint.clear()
trainable_param_names = list(filter(lambda named_param: named_param[1].requires_grad, self.pipe.unet.named_parameters()))
trainable_param_names = set([named_param[0] for named_param in trainable_param_names])
state_dict = self.pipe.unet.state_dict()
for name, param in state_dict.items():
if name in trainable_param_names:
checkpoint[name] = param
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_path",
type=str,
default=None,
required=True,
help="Path to pretrained model. For example, `models/kolors/Kolors`.",
)
parser.add_argument(
"--dataset_path",
type=str,
default=None,
required=True,
help="The path of the Dataset.",
)
parser.add_argument(
"--output_path",
type=str,
default="./",
help="Path to save the model.",
)
parser.add_argument(
"--steps_per_epoch",
type=int,
default=500,
help="Number of steps per epoch.",
)
parser.add_argument(
"--height",
type=int,
default=1024,
help="Image height.",
)
parser.add_argument(
"--width",
type=int,
default=1024,
help="Image width.",
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="Whether to randomly flip images horizontally",
)
parser.add_argument(
"--batch_size",
type=int,
default=1,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--precision",
type=str,
default="16-mixed",
choices=["32", "16", "16-mixed"],
help="Training precision",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate.",
)
parser.add_argument(
"--lora_rank",
type=int,
default=4,
help="The dimension of the LoRA update matrices.",
)
parser.add_argument(
"--lora_alpha",
type=float,
default=4.0,
help="The weight of the LoRA update matrices.",
)
parser.add_argument(
"--use_gradient_checkpointing",
default=False,
action="store_true",
help="Whether to use gradient checkpointing.",
)
parser.add_argument(
"--accumulate_grad_batches",
type=int,
default=1,
help="The number of batches in gradient accumulation.",
)
parser.add_argument(
"--training_strategy",
type=str,
default="auto",
choices=["auto", "deepspeed_stage_1", "deepspeed_stage_2", "deepspeed_stage_3"],
help="Training strategy",
)
parser.add_argument(
"--max_epochs",
type=int,
default=1,
help="Number of epochs.",
)
args = parser.parse_args()
return args
if __name__ == '__main__':
# args
args = parse_args()
# dataset and data loader
dataset = TextImageDataset(
args.dataset_path,
steps_per_epoch=args.steps_per_epoch * args.batch_size,
height=args.height,
width=args.width,
center_crop=args.center_crop,
random_flip=args.random_flip
)
train_loader = torch.utils.data.DataLoader(
dataset,
shuffle=True,
batch_size=args.batch_size,
num_workers=args.dataloader_num_workers
)
# model
model = LightningModel(
pretrained_weights=[
os.path.join(args.pretrained_path, "text_encoder"),
os.path.join(args.pretrained_path, "unet/diffusion_pytorch_model.safetensors"),
os.path.join(args.pretrained_path, "vae/diffusion_pytorch_model.safetensors"),
],
torch_dtype=torch.float32 if args.precision == "32" else torch.float16,
learning_rate=args.learning_rate,
lora_rank=args.lora_rank,
lora_alpha=args.lora_alpha,
use_gradient_checkpointing=args.use_gradient_checkpointing
)
# train
trainer = pl.Trainer(
max_epochs=args.max_epochs,
accelerator="gpu",
devices="auto",
precision=args.precision,
strategy=args.training_strategy,
default_root_dir=args.output_path,
accumulate_grad_batches=args.accumulate_grad_batches,
callbacks=[pl.pytorch.callbacks.ModelCheckpoint(save_top_k=-1)]
)
trainer.fit(model=model, train_dataloaders=train_loader)

View File

@@ -153,7 +153,7 @@ image = pipe(
image.save("image_with_lora.jpg")
```
Prompt:
Prompt: a dog is jumping, flowers around the dog, the background is mountains and clouds
|Without LoRA|With LoRA|
|-|-|

View File

@@ -10,3 +10,4 @@ imageio[ffmpeg]
safetensors
einops
sentencepiece
modelscope