lora merger

This commit is contained in:
Artiprocher
2025-04-21 15:48:25 +08:00
parent 04260801a2
commit 44da204dbd
7 changed files with 516 additions and 30 deletions

View File

@@ -62,25 +62,26 @@ def load_state_dict_from_folder(file_path, torch_dtype=None):
return state_dict
def load_state_dict(file_path, torch_dtype=None):
def load_state_dict(file_path, torch_dtype=None, device="cpu"):
if file_path.endswith(".safetensors"):
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype)
return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype, device=device)
else:
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype)
return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype, device=device)
def load_state_dict_from_safetensors(file_path, torch_dtype=None):
def load_state_dict_from_safetensors(file_path, torch_dtype=None, device="cpu"):
state_dict = {}
with safe_open(file_path, framework="pt", device="cpu") as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
if torch_dtype is not None:
state_dict[k] = state_dict[k].to(torch_dtype)
state_dict[k] = state_dict[k].to(device)
return state_dict
def load_state_dict_from_bin(file_path, torch_dtype=None):
state_dict = torch.load(file_path, map_location="cpu", weights_only=True)
def load_state_dict_from_bin(file_path, torch_dtype=None, device="cpu"):
state_dict = torch.load(file_path, map_location=device, weights_only=True)
if torch_dtype is not None:
for i in state_dict:
if isinstance(state_dict[i], torch.Tensor):