This commit is contained in:
Artiprocher
2025-11-20 17:44:00 +08:00
parent eeb55a0ce6
commit 96daa30bcc
7 changed files with 72 additions and 24 deletions

View File

@@ -28,27 +28,52 @@ class ModelConfig:
if self.path is None and self.model_id is None:
raise ValueError(f"""No valid model files. Please use `ModelConfig(path="xxx")` or `ModelConfig(model_id="xxx/yyy", origin_file_pattern="zzz")`. `skip_download=True` only supports the first one.""")
def download(self):
origin_file_pattern = self.origin_file_pattern + ("*" if self.origin_file_pattern.endswith("/") else "")
downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
def parse_original_file_pattern(self):
    """Normalize `self.origin_file_pattern` into a usable glob pattern.

    Returns:
        str: "*" when the pattern is unset or empty (match everything);
        the pattern with "*" appended when it ends in "/" (directory
        prefix); otherwise the pattern unchanged.
    """
    pattern = self.origin_file_pattern
    if not pattern:
        return "*"
    return pattern + "*" if pattern.endswith("/") else pattern
def parse_download_resource(self):
    """Resolve which hub to download from, caching the choice on `self`.

    Resolution order: explicit `self.download_resource`, then the
    DIFFSYNTH_DOWNLOAD_RESOURCE environment variable, then "modelscope".

    Returns:
        str: "modelscope" (normalized case-insensitively) or whatever other
        resource name was configured. The original code returned the raw
        env value without normalization; now both paths are normalized so
        e.g. "ModelScope" from the environment also yields "modelscope".
    """
    if self.download_resource is None:
        env_resource = os.environ.get('DIFFSYNTH_DOWNLOAD_RESOURCE')
        # Cache the decision so later calls see a concrete value.
        self.download_resource = env_resource if env_resource is not None else "modelscope"
    if self.download_resource.lower() == "modelscope":
        return "modelscope"
    return self.download_resource
def parse_skip_download(self):
    """Decide whether downloading should be skipped.

    An explicit `self.skip_download` wins; otherwise the
    DIFFSYNTH_SKIP_DOWNLOAD environment variable is consulted
    (case-insensitive "true" enables skipping; any other value does not).

    Returns:
        bool: always a bool. The original fell through and implicitly
        returned None when `skip_download` was None and the env var was
        unset; that case now returns False explicitly.
    """
    if self.skip_download is not None:
        return self.skip_download
    env_value = os.environ.get('DIFFSYNTH_SKIP_DOWNLOAD')
    if env_value is None:
        return False
    # Any value other than "true" (case-insensitive) means "do download".
    return env_value.lower() == "true"
def download(self):
    """Fetch model files for `self.model_id` from the configured hub.

    The diff rendering had interleaved old/new lines here (duplicated
    `allow_file_pattern`/`allow_patterns` kwargs and a duplicated elif),
    which is a SyntaxError; this is the reconstructed post-change version.

    Files under `local_model_path/model_id` that already match the
    normalized pattern are excluded from the download. Only "modelscope"
    and "huggingface" resources trigger a download; any other value is a
    silent no-op.
    """
    origin_file_pattern = self.parse_original_file_pattern()
    # Skip files that are already present locally.
    downloaded_files = glob.glob(origin_file_pattern, root_dir=os.path.join(self.local_model_path, self.model_id))
    download_resource = self.parse_download_resource()
    if download_resource.lower() == "modelscope":
        snapshot_download(
            self.model_id,
            local_dir=os.path.join(self.local_model_path, self.model_id),
            allow_file_pattern=origin_file_pattern,
            ignore_file_pattern=downloaded_files,
            local_files_only=False
        )
    elif download_resource.lower() == "huggingface":
        hf_snapshot_download(
            self.model_id,
            local_dir=os.path.join(self.local_model_path, self.model_id),
            allow_patterns=origin_file_pattern,
            ignore_patterns=downloaded_files,
            local_files_only=False
        )
@@ -58,15 +83,8 @@ class ModelConfig:
def require_downloading(self):
    """Return True when the model must be fetched from a remote hub.

    A preset local `self.path` short-circuits downloading; otherwise the
    decision is delegated to `parse_skip_download()`. The diff residue in
    this span (leftover inline env-var parsing) duplicated that helper and
    is removed here.

    Returns:
        bool: True when a download is required.
    """
    if self.path is not None:
        return False
    return not self.parse_skip_download()
def reset_local_model_path(self):
if os.environ.get('DIFFSYNTH_MODEL_BASE_PATH') is not None:
@@ -79,7 +97,10 @@ class ModelConfig:
self.reset_local_model_path()
if self.require_downloading():
self.download()
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if self.origin_file_pattern is None or self.origin_file_pattern == "":
self.path = os.path.join(self.local_model_path, self.model_id)
else:
self.path = glob.glob(os.path.join(self.local_model_path, self.model_id, self.origin_file_pattern))
if isinstance(self.path, list) and len(self.path) == 1:
self.path = self.path[0]

View File

@@ -60,9 +60,10 @@ class DiskMap:
if self.rename_dict is not None: name = self.rename_dict[name]
file_id = self.name_map[name]
param = self.files[file_id].get_tensor(name)
if self.torch_dtype is not None:
if self.torch_dtype is not None and isinstance(param, torch.Tensor):
param = param.to(self.torch_dtype)
self.num_params += param.numel()
if isinstance(param, torch.Tensor):
self.num_params += param.numel()
if self.num_params > self.buffer_size:
self.flush_files()
return param

View File

@@ -0,0 +1,20 @@
import torch
from typing import Dict, List
def merge_lora_weight(tensors_A, tensors_B):
    """Combine per-adapter LoRA factor pairs into a single pair.

    The A-factors are concatenated along dim 0 and the B-factors along
    dim 1, so the merged pair keeps the original inner dimensions.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (merged_A, merged_B).
    """
    merged_A = torch.cat(tensors_A, dim=0)
    merged_B = torch.cat(tensors_B, dim=1)
    return merged_A, merged_B
def merge_lora(loras: List[Dict[str, torch.Tensor]]):
    """Merge several LoRA state dicts into one state dict.

    For every ".lora_A." key of the first dict, the matching A and B
    tensors are gathered across all dicts and combined with
    `merge_lora_weight`; assumes every dict shares the same key set.

    Returns:
        dict[str, torch.Tensor]: merged A/B tensors under their keys.
    """
    merged = {}
    for down_key in (k for k in loras[0] if ".lora_A." in k):
        up_key = down_key.replace(".lora_A.", ".lora_B.")
        merged[down_key], merged[up_key] = merge_lora_weight(
            [state[down_key] for state in loras],
            [state[up_key] for state in loras],
        )
    return merged