lora retrieval

This commit is contained in:
Artiprocher
2025-06-23 17:34:30 +08:00
parent 44da204dbd
commit 50d2c86ae5
14 changed files with 698 additions and 462 deletions

12
lora/utils.py Normal file
View File

@@ -0,0 +1,12 @@
from diffsynth import load_state_dict
import math, torch
def load_lora(file_path, device):
sd = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device)
scale = math.sqrt(sd["lora_unet_single_blocks_9_modulation_lin.alpha"] / sd["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0])
if scale != 1:
sd = {i: sd[i] * scale for i in sd}
return sd