mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-20 07:18:14 +00:00
update
This commit is contained in:
20
diffsynth/utils/lora/merge.py
Normal file
20
diffsynth/utils/lora/merge.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def merge_lora_weight(tensors_A, tensors_B):
|
||||
lora_A = torch.concat(tensors_A, dim=0)
|
||||
lora_B = torch.concat(tensors_B, dim=1)
|
||||
return lora_A, lora_B
|
||||
|
||||
|
||||
def merge_lora(loras: List[Dict[str, torch.Tensor]]):
|
||||
lora_merged = {}
|
||||
keys = [i for i in loras[0].keys() if ".lora_A." in i]
|
||||
for key in keys:
|
||||
tensors_A = [lora[key] for lora in loras]
|
||||
tensors_B = [lora[key.replace(".lora_A.", ".lora_B.")] for lora in loras]
|
||||
lora_A, lora_B = merge_lora_weight(tensors_A, tensors_B)
|
||||
lora_merged[key] = lora_A
|
||||
lora_merged[key.replace(".lora_A.", ".lora_B.")] = lora_B
|
||||
return lora_merged
|
||||
Reference in New Issue
Block a user