mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-19 06:23:43 +00:00
46 lines
1.9 KiB
Python
46 lines
1.9 KiB
Python
import torch
|
|
|
|
|
|
|
|
class GeneralLoRALoader:
    """Merge LoRA low-rank updates directly into a model's weights.

    For every module whose name matches a ``lora_B``/``lora_A`` pair in a LoRA
    state dict, ``alpha * (lora_B @ lora_A)`` is added to that module's
    ``weight`` entry and written back with ``load_state_dict``.
    """

    def __init__(self, device="cpu", torch_dtype=torch.float32):
        # Device and dtype the LoRA factors and the target weight are cast to
        # before the merge arithmetic is performed.
        self.device = device
        self.torch_dtype = torch_dtype

    def get_name_dict(self, lora_state_dict):
        """Map target module names to their ``(lora_B_key, lora_A_key)`` pair.

        Each key containing a ``.lora_B.`` component is turned into a module
        name by:
          * dropping the component right after ``lora_B`` when it is not the
            last one (e.g. an adapter name such as ``default``),
          * dropping the ``lora_B`` component itself,
          * dropping a leading ``diffusion_model`` component,
          * dropping the trailing parameter name (e.g. ``weight``).

        The matching ``lora_A`` key is derived by textual substitution of
        ``.lora_B.`` with ``.lora_A.``; its presence is not checked here.

        Example:
            ``"diffusion_model.blocks.0.q.lora_B.default.weight"`` maps to
            ``"blocks.0.q"``.

        Args:
            lora_state_dict: Mapping whose keys are LoRA parameter names.

        Returns:
            dict mapping module name -> (lora_B key, lora_A key).
        """
        lora_name_dict = {}
        for key in lora_state_dict:
            if ".lora_B." not in key:
                continue
            keys = key.split(".")
            # ".lora_B." guarantees "lora_B" is a standalone component.
            b_idx = keys.index("lora_B")
            # An extra component between "lora_B" and the parameter name
            # (such as an adapter name) is not part of the module path.
            if len(keys) > b_idx + 2:
                keys.pop(b_idx + 1)
            keys.pop(b_idx)
            if keys[0] == "diffusion_model":
                keys.pop(0)
            # Drop the trailing parameter name (e.g. "weight").
            keys.pop(-1)
            target_name = ".".join(keys)
            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
        return lora_name_dict

    def load(self, model: torch.nn.Module, state_dict_lora, alpha=1.0):
        """Add ``alpha * (lora_B @ lora_A)`` to the weight of each matching module.

        Linear LoRA factors are multiplied directly. Conv LoRA factors (4-D
        tensors) are flattened, multiplied, and reshaped to
        ``(out, in, kh, kw)`` taken from ``lora_A``'s kernel size. This covers
        1x1 kernels as well as larger ``lora_A`` kernels.

        Args:
            model: Module whose submodules are patched in place.
            state_dict_lora: LoRA state dict. Keys are interpreted by
                ``get_name_dict``.
            alpha: Scale applied to the low-rank update.

        Returns:
            Number of modules updated. The count is also printed.

        Raises:
            KeyError: if the derived ``lora_A`` key is missing from
                ``state_dict_lora``, or a matched module has no ``weight``
                entry in its state dict.
        """
        updated_num = 0
        lora_name_dict = self.get_name_dict(state_dict_lora)
        for name, module in model.named_modules():
            if name not in lora_name_dict:
                continue
            key_up, key_down = lora_name_dict[name]
            weight_up = state_dict_lora[key_up].to(device=self.device, dtype=self.torch_dtype)
            weight_down = state_dict_lora[key_down].to(device=self.device, dtype=self.torch_dtype)
            if weight_up.dim() == 4:
                # Conv LoRA: lora_B is assumed (out, rank, 1, 1) and lora_A is
                # (rank, in, kh, kw). Flattening both factors turns the
                # composition into a matrix product. The result is reshaped
                # back to a conv kernel using lora_A's spatial size.
                delta = torch.mm(weight_up.flatten(1), weight_down.flatten(1))
                weight_lora = alpha * delta.reshape(weight_up.shape[0], *weight_down.shape[1:])
            else:
                weight_lora = alpha * torch.mm(weight_up, weight_down)
            state_dict = module.state_dict()
            state_dict["weight"] = state_dict["weight"].to(device=self.device, dtype=self.torch_dtype) + weight_lora
            # load_state_dict copies the merged tensor back into the existing
            # parameter, casting to that parameter's own device/dtype.
            module.load_state_dict(state_dict)
            updated_num += 1
        print(f"{updated_num} tensors are updated by LoRA.")
        return updated_num
|