Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-18 22:08:13 +00:00)
support i2L
diffsynth/utils/lora/__init__.py
@@ -1 +1,3 @@
 from .general import GeneralLoRALoader
+from .merge import merge_lora
+from .reset_rank import reset_lora_rank
diffsynth/utils/lora/reset_rank.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+import torch
+
+def decomposite(tensor_A, tensor_B, rank):
+    dtype, device = tensor_A.dtype, tensor_A.device
+    weight = tensor_B @ tensor_A
+    U, S, V = torch.pca_lowrank(weight.float(), q=rank)
+    tensor_A = (V.T).to(dtype=dtype, device=device).contiguous()
+    tensor_B = (U @ torch.diag(S)).to(dtype=dtype, device=device).contiguous()
+    return tensor_A, tensor_B
+
+def reset_lora_rank(lora, rank):
+    lora_merged = {}
+    keys = [i for i in lora.keys() if ".lora_A." in i]
+    for key in keys:
+        tensor_A = lora[key]
+        tensor_B = lora[key.replace(".lora_A.", ".lora_B.")]
+        tensor_A, tensor_B = decomposite(tensor_A, tensor_B, rank)
+        lora_merged[key] = tensor_A
+        lora_merged[key.replace(".lora_A.", ".lora_B.")] = tensor_B
+    return lora_merged
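
For illustration only (not part of the commit): a minimal usage sketch of the new helper. It assumes the re-export added to diffsynth/utils/lora/__init__.py above; the layer key and tensor shapes are invented and simply follow the ".lora_A." / ".lora_B." naming that reset_lora_rank scans for. Note that torch.pca_lowrank centers its input by default, so the resized factors approximate the original update B @ A only up to that centering and the truncation to the new rank.

# Hypothetical usage sketch; key names and shapes are made up for the example.
import torch

from diffsynth.utils.lora import reset_lora_rank

# Toy LoRA state dict: one rank-32 adapter for a 1024x1024 projection.
lora = {
    "blocks.0.attn.to_q.lora_A.weight": torch.randn(32, 1024),
    "blocks.0.attn.to_q.lora_B.weight": torch.randn(1024, 32) * 0.01,
}

# Re-factor the adapter down to rank 8 via the low-rank decomposite() helper.
lora_r8 = reset_lora_rank(lora, rank=8)
print(lora_r8["blocks.0.attn.to_q.lora_A.weight"].shape)  # torch.Size([8, 1024])
print(lora_r8["blocks.0.attn.to_q.lora_B.weight"].shape)  # torch.Size([1024, 8])

# Relative error between the original update B @ A and the rank-8 reconstruction.
delta_old = lora["blocks.0.attn.to_q.lora_B.weight"] @ lora["blocks.0.attn.to_q.lora_A.weight"]
delta_new = lora_r8["blocks.0.attn.to_q.lora_B.weight"] @ lora_r8["blocks.0.attn.to_q.lora_A.weight"]
print((delta_old - delta_new).norm() / delta_old.norm())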