support flux-fp8

This commit is contained in:
Artiprocher
2024-09-19 10:32:16 +08:00
parent a9fbfa108f
commit 091df1f1e7
4 changed files with 91 additions and 125 deletions

View File

@@ -415,8 +415,10 @@ class ModelManager:
break
def load_model(self, file_path, model_names=None):
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
@@ -425,7 +427,7 @@ class ModelManager:
if model_detector.match(file_path, state_dict):
model_names, models = model_detector.load(
file_path, state_dict,
device=self.device, torch_dtype=self.torch_dtype,
device=device, torch_dtype=torch_dtype,
allowed_model_names=model_names, model_manager=self
)
for model_name, model in zip(model_names, models):
@@ -438,9 +440,9 @@ class ModelManager:
print(f" We cannot detect the model type. No models are loaded.")
def load_models(self, file_path_list, model_names=None):
def load_models(self, file_path_list, model_names=None, device=None, torch_dtype=None):
for file_path in file_path_list:
self.load_model(file_path, model_names)
self.load_model(file_path, model_names, device=device, torch_dtype=torch_dtype)
def fetch_model(self, model_name, file_path=None, require_model_path=False):