improve lora loading efficiency

This commit is contained in:
Artiprocher
2025-03-14 15:15:37 +08:00
parent 3acd9c73be
commit 1fd3d67379

View File

@@ -195,86 +195,6 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
"txt.mod": "txt_mod",
}
class GeneralLoRAFromPeft_:
def __init__(self):
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
def fetch_device_dtype_from_state_dict(self, state_dict):
device, torch_dtype = None, None
for name, param in state_dict.items():
device, torch_dtype = param.device, param.dtype
break
return device, torch_dtype
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
if torch_dtype == torch.float8_e4m3fn:
torch_dtype = torch.float32
state_dict_ = {}
for key in state_dict:
if ".lora_B." not in key:
continue
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
if len(weight_up.shape) == 4:
weight_up = weight_up.squeeze(3).squeeze(2)
weight_down = weight_down.squeeze(3).squeeze(2)
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
lora_weight = alpha * torch.mm(weight_up, weight_down)
keys = key.split(".")
if len(keys) > keys.index("lora_B") + 2:
keys.pop(keys.index("lora_B") + 1)
keys.pop(keys.index("lora_B"))
target_name = ".".join(keys)
if target_name.startswith("diffusion_model."):
target_name = target_name[len("diffusion_model."):]
if target_name not in target_state_dict:
return {}
state_dict_[target_name] = lora_weight.cpu()
return state_dict_
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
state_dict_model = model.state_dict()
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
if len(state_dict_lora) > 0:
print(f" {len(state_dict_lora)} tensors are updated.")
for name in state_dict_lora:
if state_dict_model[name].dtype == torch.float8_e4m3fn:
weight = state_dict_model[name].to(torch.float32)
lora_weight = state_dict_lora[name].to(
dtype=torch.float32,
device=state_dict_model[name].device
)
state_dict_model[name] = (weight + lora_weight).to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
else:
state_dict_model[name] += state_dict_lora[name].to(
dtype=state_dict_model[name].dtype,
device=state_dict_model[name].device
)
model.load_state_dict(state_dict_model)
def match(self, model, state_dict_lora):
for model_class in self.supported_model_classes:
if not isinstance(model, model_class):
continue
state_dict_model = model.state_dict()
try:
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
if len(state_dict_lora_) > 0:
return "", ""
except:
pass
return None
class GeneralLoRAFromPeft: