Merge branch 'main' into omost

Zhongjie Duan authored on 2024-09-04 12:52:03 +08:00, committed by GitHub
24 changed files with 2711 additions and 113 deletions

@@ -10,7 +10,7 @@ from .sd_text_encoder import SDTextEncoder
from .sd_unet import SDUNet
from .sd_vae_encoder import SDVAEEncoder
from .sd_vae_decoder import SDVAEDecoder
-from .lora import SDLoRAFromCivitai, SDXLLoRAFromCivitai, GeneralLoRAFromPeft
+from .lora import get_lora_loaders
from .sdxl_text_encoder import SDXLTextEncoder, SDXLTextEncoder2
from .sdxl_unet import SDXLUNet
@@ -43,93 +43,17 @@ from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from .cog_vae import CogVAEEncoder, CogVAEDecoder
from .cog_dit import CogDiT
from ..extensions.RIFE import IFNet
from ..extensions.ESRGAN import RRDBNet
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
from .utils import load_state_dict
def load_state_dict(file_path, torch_dtype=None):
    if file_path.endswith(".safetensors"):
        return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
    else:
        return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)

def load_state_dict_from_safetensors(file_path, torch_dtype=None):
    state_dict = {}
    with safe_open(file_path, framework="pt", device="cpu") as f:
        for k in f.keys():
            state_dict[k] = f.get_tensor(k)
            if torch_dtype is not None:
                state_dict[k] = state_dict[k].to(torch_dtype)
    return state_dict

def load_state_dict_from_bin(file_path, torch_dtype=None):
    state_dict = torch.load(file_path, map_location="cpu")
    if torch_dtype is not None:
        for i in state_dict:
            if isinstance(state_dict[i], torch.Tensor):
                state_dict[i] = state_dict[i].to(torch_dtype)
    return state_dict

def search_for_embeddings(state_dict):
    embeddings = []
    for k in state_dict:
        if isinstance(state_dict[k], torch.Tensor):
            embeddings.append(state_dict[k])
        elif isinstance(state_dict[k], dict):
            embeddings += search_for_embeddings(state_dict[k])
    return embeddings

def search_parameter(param, state_dict):
    for name, param_ in state_dict.items():
        if param.numel() == param_.numel():
            if param.shape == param_.shape:
                if torch.dist(param, param_) < 1e-3:
                    return name
            else:
                if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
                    return name
    return None

def build_rename_dict(source_state_dict, target_state_dict, split_qkv=False):
    matched_keys = set()
    with torch.no_grad():
        for name in source_state_dict:
            rename = search_parameter(source_state_dict[name], target_state_dict)
            if rename is not None:
                print(f'"{name}": "{rename}",')
                matched_keys.add(rename)
            elif split_qkv and len(source_state_dict[name].shape)>=1 and source_state_dict[name].shape[0]%3==0:
                length = source_state_dict[name].shape[0] // 3
                rename = []
                for i in range(3):
                    rename.append(search_parameter(source_state_dict[name][i*length: i*length+length], target_state_dict))
                if None not in rename:
                    print(f'"{name}": {rename},')
                    for rename_ in rename:
                        matched_keys.add(rename_)
    for name in target_state_dict:
        if name not in matched_keys:
            print("Cannot find", name, target_state_dict[name].shape)

def search_for_files(folder, extensions):
    files = []
    if os.path.isdir(folder):
        for file in sorted(os.listdir(folder)):
            files += search_for_files(os.path.join(folder, file), extensions)
    elif os.path.isfile(folder):
        for extension in extensions:
            if folder.endswith(extension):
                files.append(folder)
                break
    return files

def convert_state_dict_keys_to_single_str(state_dict, with_shape=True):
    keys = []
    for key, value in state_dict.items():
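These state-dict helpers are removed from this module; the new "from .utils import load_state_dict" import above indicates they now live in a shared utils module. A minimal usage sketch under that assumption (the checkpoint path below is hypothetical, and the relocated helper is assumed to keep its old signature):

import torch
from diffsynth.models.utils import load_state_dict  # assumed new home of the helper

# Same call pattern as before the refactor: the helper dispatches on the file
# extension (.safetensors vs. torch-pickled .bin/.pth) and optionally casts tensors.
state_dict = load_state_dict("path/to/checkpoint.safetensors", torch_dtype=torch.float16)
print(f"{len(state_dict)} tensors loaded")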
@@ -359,7 +283,7 @@ class ModelDetectorFromHuggingfaceFolder:
            return False
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
-        if "architectures" not in config:
+        if "architectures" not in config and "_class_name" not in config:
            return False
        return True
@@ -368,7 +292,8 @@ class ModelDetectorFromHuggingfaceFolder:
        with open(os.path.join(file_path, "config.json"), "r") as f:
            config = json.load(f)
        loaded_model_names, loaded_models = [], []
-        for architecture in config["architectures"]:
+        architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
+        for architecture in architectures:
            huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
            if redirected_architecture is not None:
                architecture = redirected_architecture
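The two hunks above teach ModelDetectorFromHuggingfaceFolder to accept diffusers-style folders, whose config.json records the model class under "_class_name" instead of the transformers-style "architectures" list. A minimal sketch of that fallback, using a hypothetical helper name:

import json
import os

def resolve_architectures(folder):
    # Hypothetical helper mirroring the new detector logic: prefer the
    # transformers-style "architectures" list, fall back to the diffusers-style
    # "_class_name" string, and return nothing for folders with neither key.
    with open(os.path.join(folder, "config.json"), "r") as f:
        config = json.load(f)
    if "architectures" in config:
        return config["architectures"]
    if "_class_name" in config:
        return [config["_class_name"]]
    return []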
@@ -481,7 +406,7 @@ class ModelManager:
            if len(state_dict) == 0:
                state_dict = load_state_dict(file_path)
            for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
-                for lora in [SDLoRAFromCivitai(), SDXLLoRAFromCivitai(), GeneralLoRAFromPeft()]:
+                for lora in get_lora_loaders():
                    match_results = lora.match(model, state_dict)
                    if match_results is not None:
                        print(f" Adding LoRA to {model_name} ({model_path}).")