support FLUX

This commit is contained in:
Artiprocher
2024-08-16 20:04:10 +08:00
parent 1116e6dbc7
commit 99e11112a7
20 changed files with 230033 additions and 48 deletions

View File

@@ -39,6 +39,10 @@ from .sdxl_ipadapter import SDXLIpAdapter, IpAdapterXLCLIPImageEmbedder
from .hunyuan_dit_text_encoder import HunyuanDiTCLIPTextEncoder, HunyuanDiTT5TextEncoder
from .hunyuan_dit import HunyuanDiT
from .flux_dit import FluxDiT
from .flux_text_encoder import FluxTextEncoder1, FluxTextEncoder2
from .flux_vae import FluxVAEEncoder, FluxVAEDecoder
from ..configs.model_config import model_loader_configs, huggingface_model_loader_configs, patch_model_loader_configs
@@ -83,10 +87,10 @@ def search_parameter(param, state_dict):
for name, param_ in state_dict.items():
if param.numel() == param_.numel():
if param.shape == param_.shape:
if torch.dist(param, param_) < 1e-6:
if torch.dist(param, param_) < 1e-3:
return name
else:
if torch.dist(param.flatten(), param_.flatten()) < 1e-6:
if torch.dist(param.flatten(), param_.flatten()) < 1e-3:
return name
return None
@@ -340,8 +344,8 @@ class ModelDetectorFromHuggingfaceFolder:
self.add_model_metadata(*metadata)
def add_model_metadata(self, architecture, huggingface_lib, model_name):
self.architecture_dict[architecture] = (huggingface_lib, model_name)
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
def match(self, file_path="", state_dict={}):
@@ -362,7 +366,9 @@ class ModelDetectorFromHuggingfaceFolder:
config = json.load(f)
loaded_model_names, loaded_models = [], []
for architecture in config["architectures"]:
huggingface_lib, model_name = self.architecture_dict[architecture]
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
if redirected_architecture is not None:
architecture = redirected_architecture
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
loaded_model_names += loaded_model_names_