mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
support kolors! (#106)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
import torch, os
|
||||
import torch, os, json
|
||||
from safetensors import safe_open
|
||||
from typing_extensions import Literal, TypeAlias
|
||||
from typing import List
|
||||
@@ -36,6 +36,7 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
|
||||
|
||||
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
|
||||
from .hunyuan_dit import HunyuanDiT
|
||||
from .kolors_text_encoder import ChatGLMModel
|
||||
|
||||
|
||||
preset_models_on_huggingface = {
|
||||
@@ -159,6 +160,20 @@ preset_models_on_modelscope = {
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/image_encoder/model.safetensors", "models/IpAdapter/stable_diffusion_xl/image_encoder"),
|
||||
("AI-ModelScope/IP-Adapter", "sdxl_models/ip-adapter_sdxl.bin", "models/IpAdapter/stable_diffusion_xl"),
|
||||
],
|
||||
# Kolors
|
||||
"Kolors": [
|
||||
("Kwai-Kolors/Kolors", "text_encoder/config.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model.bin.index.json", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00001-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00002-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00003-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00004-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00005-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00006-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "text_encoder/pytorch_model-00007-of-00007.bin", "models/kolors/Kolors/text_encoder"),
|
||||
("Kwai-Kolors/Kolors", "unet/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/unet"),
|
||||
("Kwai-Kolors/Kolors", "vae/diffusion_pytorch_model.safetensors", "models/kolors/Kolors/vae"),
|
||||
],
|
||||
}
|
||||
Preset_model_id: TypeAlias = Literal[
|
||||
"HunyuanDiT",
|
||||
@@ -184,7 +199,8 @@ Preset_model_id: TypeAlias = Literal[
|
||||
"IP-Adapter-SD",
|
||||
"IP-Adapter-SDXL",
|
||||
"StableDiffusion3",
|
||||
"StableDiffusion3_without_T5"
|
||||
"StableDiffusion3_without_T5",
|
||||
"Kolors",
|
||||
]
|
||||
Preset_model_website: TypeAlias = Literal[
|
||||
"HuggingFace",
|
||||
@@ -272,8 +288,7 @@ class ModelManager:
|
||||
|
||||
def is_controlnet(self, state_dict):
|
||||
param_name = "control_model.time_embed.0.weight"
|
||||
param_name_2 = "mid_block.resnets.1.time_emb_proj.weight" # For controlnets in diffusers format
|
||||
return param_name in state_dict or param_name_2 in state_dict
|
||||
return param_name in state_dict
|
||||
|
||||
def is_animatediff(self, state_dict):
|
||||
param_name = "mid_block.motion_modules.0.temporal_transformer.proj_out.weight"
|
||||
@@ -343,6 +358,21 @@ class ModelManager:
|
||||
param_name = "encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight"
|
||||
return param_name in state_dict
|
||||
|
||||
def is_kolors_text_encoder(self, file_path):
|
||||
file_list = os.listdir(file_path)
|
||||
if "config.json" in file_list:
|
||||
try:
|
||||
with open(os.path.join(file_path, "config.json"), "r") as f:
|
||||
config = json.load(f)
|
||||
if config.get("model_type") == "chatglm":
|
||||
return True
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
|
||||
def is_kolors_unet(self, state_dict):
|
||||
return "up_blocks.2.resnets.2.time_emb_proj.weight" in state_dict and "encoder_hid_proj.weight" in state_dict
|
||||
|
||||
def load_stable_video_diffusion(self, state_dict, components=None, file_path="", add_positional_conv=None):
|
||||
component_dict = {
|
||||
"image_encoder": SVDImageEncoder,
|
||||
@@ -532,13 +562,13 @@ class ModelManager:
|
||||
component = "vae_encoder"
|
||||
model = SDXLVAEEncoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
component = "vae_decoder"
|
||||
model = SDXLVAEDecoder()
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
model.to(torch.float32).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
@@ -592,6 +622,21 @@ class ModelManager:
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_kolors_text_encoder(self, state_dict=None, file_path=""):
|
||||
component = "kolors_text_encoder"
|
||||
model = ChatGLMModel.from_pretrained(file_path, torch_dtype=self.torch_dtype)
|
||||
model = model.to(dtype=self.torch_dtype, device=self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def load_kolors_unet(self, state_dict, file_path=""):
|
||||
component = "kolors_unet"
|
||||
model = SDXLUNet(is_kolors=True)
|
||||
model.load_state_dict(model.state_dict_converter().from_diffusers(state_dict))
|
||||
model.to(self.torch_dtype).to(self.device)
|
||||
self.model[component] = model
|
||||
self.model_path[component] = file_path
|
||||
|
||||
def search_for_embeddings(self, state_dict):
|
||||
embeddings = []
|
||||
for k in state_dict:
|
||||
@@ -607,7 +652,11 @@ class ModelManager:
|
||||
|
||||
# Load every textual inversion file
|
||||
for file_name in os.listdir(folder):
|
||||
if file_name.endswith(".txt"):
|
||||
if os.path.isdir(os.path.join(folder, file_name)) or \
|
||||
not (file_name.endswith(".bin") or \
|
||||
file_name.endswith(".safetensors") or \
|
||||
file_name.endswith(".pth") or \
|
||||
file_name.endswith(".pt")):
|
||||
continue
|
||||
keyword = os.path.splitext(file_name)[0]
|
||||
state_dict = load_state_dict(os.path.join(folder, file_name))
|
||||
@@ -620,6 +669,10 @@ class ModelManager:
|
||||
break
|
||||
|
||||
def load_model(self, file_path, components=None, lora_alphas=[]):
|
||||
if os.path.isdir(file_path):
|
||||
if self.is_kolors_text_encoder(file_path):
|
||||
self.load_kolors_text_encoder(file_path=file_path)
|
||||
return
|
||||
state_dict = load_state_dict(file_path, torch_dtype=self.torch_dtype)
|
||||
if self.is_stable_video_diffusion(state_dict):
|
||||
self.load_stable_video_diffusion(state_dict, file_path=file_path)
|
||||
@@ -663,6 +716,8 @@ class ModelManager:
|
||||
self.load_stable_diffusion_3(state_dict, components=components, file_path=file_path)
|
||||
elif self.is_stable_diffusion_3_t5(state_dict):
|
||||
self.load_stable_diffusion_3_t5(state_dict, file_path=file_path)
|
||||
elif self.is_kolors_unet(state_dict):
|
||||
self.load_kolors_unet(state_dict, file_path=file_path)
|
||||
|
||||
def load_models(self, file_path_list, lora_alphas=[]):
|
||||
for file_path in file_path_list:
|
||||
|
||||
@@ -1,251 +1,6 @@
|
||||
from huggingface_hub import hf_hub_download
|
||||
from http.cookiejar import CookieJar
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, List, Union
|
||||
import copy, uuid, requests, io, platform, pickle, os, urllib
|
||||
from requests.adapters import Retry
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def _get_sep(path):
|
||||
if isinstance(path, bytes):
|
||||
return b'/'
|
||||
else:
|
||||
return '/'
|
||||
|
||||
|
||||
def expanduser(path):
|
||||
"""Expand ~ and ~user constructions. If user or $HOME is unknown,
|
||||
do nothing."""
|
||||
path = os.fspath(path)
|
||||
if isinstance(path, bytes):
|
||||
tilde = b'~'
|
||||
else:
|
||||
tilde = '~'
|
||||
if not path.startswith(tilde):
|
||||
return path
|
||||
sep = _get_sep(path)
|
||||
i = path.find(sep, 1)
|
||||
if i < 0:
|
||||
i = len(path)
|
||||
if i == 1:
|
||||
if 'HOME' not in os.environ:
|
||||
import pwd
|
||||
try:
|
||||
userhome = pwd.getpwuid(os.getuid()).pw_dir
|
||||
except KeyError:
|
||||
# bpo-10496: if the current user identifier doesn't exist in the
|
||||
# password database, return the path unchanged
|
||||
return path
|
||||
else:
|
||||
userhome = os.environ['HOME']
|
||||
else:
|
||||
import pwd
|
||||
name = path[1:i]
|
||||
if isinstance(name, bytes):
|
||||
name = str(name, 'ASCII')
|
||||
try:
|
||||
pwent = pwd.getpwnam(name)
|
||||
except KeyError:
|
||||
# bpo-10496: if the user name from the path doesn't exist in the
|
||||
# password database, return the path unchanged
|
||||
return path
|
||||
userhome = pwent.pw_dir
|
||||
if isinstance(path, bytes):
|
||||
userhome = os.fsencode(userhome)
|
||||
root = b'/'
|
||||
else:
|
||||
root = '/'
|
||||
userhome = userhome.rstrip(root)
|
||||
return (userhome + path[i:]) or root
|
||||
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
DEFAULT_CREDENTIALS_PATH = Path.home().joinpath('.modelscope', 'credentials')
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
COOKIES_FILE_NAME = 'cookies'
|
||||
GIT_TOKEN_FILE_NAME = 'git_token'
|
||||
USER_INFO_FILE_NAME = 'user'
|
||||
USER_SESSION_ID_FILE_NAME = 'session'
|
||||
|
||||
@staticmethod
|
||||
def make_sure_credential_path_exist():
|
||||
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
def get_user_session_id():
|
||||
session_path = os.path.join(ModelScopeConfig.path_credential,
|
||||
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
|
||||
session_id = ''
|
||||
if os.path.exists(session_path):
|
||||
with open(session_path, 'rb') as f:
|
||||
session_id = str(f.readline().strip(), encoding='utf-8')
|
||||
return session_id
|
||||
if session_id == '' or len(session_id) != 32:
|
||||
session_id = str(uuid.uuid4().hex)
|
||||
ModelScopeConfig.make_sure_credential_path_exist()
|
||||
with open(session_path, 'w+') as wf:
|
||||
wf.write(session_id)
|
||||
|
||||
return session_id
|
||||
|
||||
@staticmethod
|
||||
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
|
||||
"""Formats a user-agent string with basic info about a request.
|
||||
|
||||
Args:
|
||||
user_agent (`str`, `dict`, *optional*):
|
||||
The user agent info in the form of a dictionary or a single string.
|
||||
|
||||
Returns:
|
||||
The formatted user-agent string.
|
||||
"""
|
||||
|
||||
# include some more telemetrics when executing in dedicated
|
||||
# cloud containers
|
||||
MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
|
||||
MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
|
||||
env = 'custom'
|
||||
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
|
||||
env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
|
||||
user_name = 'unknown'
|
||||
if MODELSCOPE_CLOUD_USERNAME in os.environ:
|
||||
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
|
||||
|
||||
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
|
||||
"1.15.0",
|
||||
platform.python_version(),
|
||||
ModelScopeConfig.get_user_session_id(),
|
||||
platform.platform(),
|
||||
platform.processor(),
|
||||
env,
|
||||
user_name,
|
||||
)
|
||||
if isinstance(user_agent, dict):
|
||||
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
|
||||
elif isinstance(user_agent, str):
|
||||
ua += '; ' + user_agent
|
||||
return ua
|
||||
|
||||
@staticmethod
|
||||
def get_cookies():
|
||||
cookies_path = os.path.join(ModelScopeConfig.path_credential,
|
||||
ModelScopeConfig.COOKIES_FILE_NAME)
|
||||
if os.path.exists(cookies_path):
|
||||
with open(cookies_path, 'rb') as f:
|
||||
cookies = pickle.load(f)
|
||||
return cookies
|
||||
return None
|
||||
|
||||
|
||||
|
||||
def modelscope_http_get_model_file(
|
||||
url: str,
|
||||
local_dir: str,
|
||||
file_name: str,
|
||||
file_size: int,
|
||||
cookies: CookieJar,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
"""Download remote file, will retry 5 times before giving up on errors.
|
||||
|
||||
Args:
|
||||
url(str):
|
||||
actual download url of the file
|
||||
local_dir(str):
|
||||
local directory where the downloaded file stores
|
||||
file_name(str):
|
||||
name of the file stored in `local_dir`
|
||||
file_size(int):
|
||||
The file size.
|
||||
cookies(CookieJar):
|
||||
cookies used to authentication the user, which is used for downloading private repos
|
||||
headers(Dict[str, str], optional):
|
||||
http headers to carry necessary info when requesting the remote file
|
||||
|
||||
Raises:
|
||||
FileDownloadError: File download failed.
|
||||
|
||||
"""
|
||||
get_headers = {} if headers is None else copy.deepcopy(headers)
|
||||
get_headers['X-Request-ID'] = str(uuid.uuid4().hex)
|
||||
temp_file_path = os.path.join(local_dir, file_name)
|
||||
# retry sleep 0.5s, 1s, 2s, 4s
|
||||
retry = Retry(
|
||||
total=5,
|
||||
backoff_factor=1,
|
||||
allowed_methods=['GET'])
|
||||
while True:
|
||||
try:
|
||||
progress = tqdm(
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
total=file_size,
|
||||
initial=0,
|
||||
desc='Downloading',
|
||||
)
|
||||
partial_length = 0
|
||||
if os.path.exists(
|
||||
temp_file_path): # download partial, continue download
|
||||
with open(temp_file_path, 'rb') as f:
|
||||
partial_length = f.seek(0, io.SEEK_END)
|
||||
progress.update(partial_length)
|
||||
if partial_length > file_size:
|
||||
break
|
||||
get_headers['Range'] = 'bytes=%s-%s' % (partial_length,
|
||||
file_size - 1)
|
||||
with open(temp_file_path, 'ab') as f:
|
||||
r = requests.get(
|
||||
url,
|
||||
stream=True,
|
||||
headers=get_headers,
|
||||
cookies=cookies,
|
||||
timeout=60)
|
||||
r.raise_for_status()
|
||||
for chunk in r.iter_content(
|
||||
chunk_size=1024 * 1024 * 1):
|
||||
if chunk: # filter out keep-alive new chunks
|
||||
progress.update(len(chunk))
|
||||
f.write(chunk)
|
||||
progress.close()
|
||||
break
|
||||
except (Exception) as e: # no matter what happen, we will retry.
|
||||
retry = retry.increment('GET', url, error=e)
|
||||
retry.sleep()
|
||||
|
||||
|
||||
def get_endpoint():
|
||||
MODELSCOPE_URL_SCHEME = 'https://'
|
||||
DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
|
||||
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN',
|
||||
DEFAULT_MODELSCOPE_DOMAIN)
|
||||
return MODELSCOPE_URL_SCHEME + modelscope_domain
|
||||
|
||||
|
||||
def get_file_download_url(model_id: str, file_path: str, revision: str):
|
||||
"""Format file download url according to `model_id`, `revision` and `file_path`.
|
||||
e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
|
||||
the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
|
||||
|
||||
Args:
|
||||
model_id (str): The model_id.
|
||||
file_path (str): File path
|
||||
revision (str): File revision.
|
||||
|
||||
Returns:
|
||||
str: The file url.
|
||||
"""
|
||||
file_path = urllib.parse.quote_plus(file_path)
|
||||
revision = urllib.parse.quote_plus(revision)
|
||||
download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
|
||||
return download_url_template.format(
|
||||
endpoint=get_endpoint(),
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
file_path=file_path,
|
||||
)
|
||||
from modelscope import snapshot_download
|
||||
import os, shutil
|
||||
|
||||
|
||||
def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
@@ -255,17 +10,12 @@ def download_from_modelscope(model_id, origin_file_path, local_dir):
|
||||
return
|
||||
else:
|
||||
print(f"Start downloading {os.path.join(local_dir, os.path.basename(origin_file_path))}")
|
||||
headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None)}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
url = get_file_download_url(model_id=model_id, file_path=origin_file_path, revision="master")
|
||||
modelscope_http_get_model_file(
|
||||
url,
|
||||
local_dir,
|
||||
os.path.basename(origin_file_path),
|
||||
file_size=0,
|
||||
headers=headers,
|
||||
cookies=cookies
|
||||
)
|
||||
snapshot_download(model_id, allow_file_pattern=origin_file_path, local_dir=local_dir)
|
||||
downloaded_file_path = os.path.join(local_dir, origin_file_path)
|
||||
target_file_path = os.path.join(local_dir, os.path.split(origin_file_path)[-1])
|
||||
if downloaded_file_path != target_file_path:
|
||||
shutil.move(downloaded_file_path, target_file_path)
|
||||
shutil.rmtree(os.path.join(local_dir, origin_file_path.split("/")[0]))
|
||||
|
||||
|
||||
def download_from_huggingface(model_id, origin_file_path, local_dir):
|
||||
|
||||
1363
diffsynth/models/kolors_text_encoder.py
Normal file
1363
diffsynth/models/kolors_text_encoder.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@ from .sd_unet import Timesteps, ResnetBlock, AttentionBlock, PushBlock, PopBlock
|
||||
|
||||
|
||||
class SDXLUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, is_kolors=False):
|
||||
super().__init__()
|
||||
self.time_proj = Timesteps(320)
|
||||
self.time_embedding = torch.nn.Sequential(
|
||||
@@ -13,11 +13,12 @@ class SDXLUNet(torch.nn.Module):
|
||||
)
|
||||
self.add_time_proj = Timesteps(256)
|
||||
self.add_time_embedding = torch.nn.Sequential(
|
||||
torch.nn.Linear(2816, 1280),
|
||||
torch.nn.Linear(5632 if is_kolors else 2816, 1280),
|
||||
torch.nn.SiLU(),
|
||||
torch.nn.Linear(1280, 1280)
|
||||
)
|
||||
self.conv_in = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
|
||||
self.text_intermediate_proj = torch.nn.Linear(4096, 2048) if is_kolors else None
|
||||
|
||||
self.blocks = torch.nn.ModuleList([
|
||||
# DownBlock2D
|
||||
@@ -85,7 +86,9 @@ class SDXLUNet(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
sample, timestep, encoder_hidden_states, add_time_id, add_text_embeds,
|
||||
tiled=False, tile_size=64, tile_stride=8, **kwargs
|
||||
tiled=False, tile_size=64, tile_stride=8,
|
||||
use_gradient_checkpointing=False,
|
||||
**kwargs
|
||||
):
|
||||
# 1. time
|
||||
t_emb = self.time_proj(timestep[None]).to(sample.dtype)
|
||||
@@ -102,15 +105,26 @@ class SDXLUNet(torch.nn.Module):
|
||||
# 2. pre-process
|
||||
height, width = sample.shape[2], sample.shape[3]
|
||||
hidden_states = self.conv_in(sample)
|
||||
text_emb = encoder_hidden_states
|
||||
text_emb = encoder_hidden_states if self.text_intermediate_proj is None else self.text_intermediate_proj(encoder_hidden_states)
|
||||
res_stack = [hidden_states]
|
||||
|
||||
# 3. blocks
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
return custom_forward
|
||||
for i, block in enumerate(self.blocks):
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
if self.training and use_gradient_checkpointing and not (isinstance(block, PushBlock) or isinstance(block, PopBlock)):
|
||||
hidden_states, time_emb, text_emb, res_stack = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
use_reentrant=False,
|
||||
)
|
||||
else:
|
||||
hidden_states, time_emb, text_emb, res_stack = block(
|
||||
hidden_states, time_emb, text_emb, res_stack,
|
||||
tiled=tiled, tile_size=tile_size, tile_stride=tile_stride
|
||||
)
|
||||
|
||||
# 4. output
|
||||
hidden_states = self.conv_norm_out(hidden_states)
|
||||
@@ -148,6 +162,8 @@ class SDXLUNetStateDictConverter:
|
||||
names = name.split(".")
|
||||
if names[0] in ["conv_in", "conv_norm_out", "conv_out"]:
|
||||
pass
|
||||
elif names[0] in ["encoder_hid_proj"]:
|
||||
names[0] = "text_intermediate_proj"
|
||||
elif names[0] in ["time_embedding", "add_embedding"]:
|
||||
if names[0] == "add_embedding":
|
||||
names[0] = "add_time_embedding"
|
||||
|
||||
Reference in New Issue
Block a user