support stepvideo

This commit is contained in:
Artiprocher
2025-02-17 17:32:25 +08:00
parent 7434ec8fcd
commit 3681adc5ac
12 changed files with 2866 additions and 8 deletions

View File

@@ -158,7 +158,7 @@ class ModelDetectorFromSingleFile:
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
@@ -200,7 +200,7 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
if isinstance(file_path, str) and os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
@@ -243,7 +243,7 @@ class ModelDetectorFromHuggingfaceFolder:
def match(self, file_path="", state_dict={}):
if os.path.isfile(file_path):
if not isinstance(file_path, str) or os.path.isfile(file_path):
return False
file_list = os.listdir(file_path)
if "config.json" not in file_list:
@@ -284,7 +284,7 @@ class ModelDetectorFromPatchedSingleFile:
def match(self, file_path="", state_dict={}):
if os.path.isdir(file_path):
if not isinstance(file_path, str) or os.path.isdir(file_path):
return False
if len(state_dict) == 0:
state_dict = load_state_dict(file_path)
@@ -390,7 +390,11 @@ class ModelManager:
print(f"Loading models from: {file_path}")
if device is None: device = self.device
if torch_dtype is None: torch_dtype = self.torch_dtype
if os.path.isfile(file_path):
if isinstance(file_path, list):
state_dict = {}
for path in file_path:
state_dict.update(load_state_dict(path))
elif os.path.isfile(file_path):
state_dict = load_state_dict(file_path)
else:
state_dict = None