Files
DiffSynth-Studio/lora/utils.py
2025-06-23 17:34:30 +08:00

13 lines
402 B
Python

from diffsynth import load_state_dict
import math, torch
def load_lora(file_path, device):
    """Load a LoRA state dict and rescale every entry by one global factor.

    The file is read through ``load_state_dict`` with ``torch.bfloat16`` as
    the requested dtype and ``device`` as the target device. A single scale,
    ``sqrt(alpha / rank)``, is derived from one reference layer
    (``lora_unet_single_blocks_9_modulation_lin``) and, unless it equals 1,
    multiplied into every entry of the state dict.

    Args:
        file_path: Location of the LoRA checkpoint, forwarded to
            ``load_state_dict``.
        device: Device argument forwarded to ``load_state_dict``.

    Returns:
        The loaded state dict itself when the scale equals 1; otherwise a new
        dict with the same keys whose values are the originals multiplied by
        the scale.

    Raises:
        KeyError: If either reference key is absent from the loaded file.
    """
    state_dict = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device)

    # The scale is read from one hard-coded reference layer and then applied
    # globally. NOTE(review): presumably every layer in the file shares the
    # same alpha and rank -- confirm against the checkpoints in use.
    alpha = state_dict["lora_unet_single_blocks_9_modulation_lin.alpha"]
    # Presumably dim 0 of the down-projection weight is the LoRA rank --
    # TODO confirm the layout.
    rank = state_dict["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0]
    scale = math.sqrt(alpha / rank)

    if scale == 1:
        # Nothing to fold in; hand back the loaded mapping untouched.
        return state_dict

    # Every entry is multiplied, not only the weight matrices, so the
    # ``.alpha`` entries are scaled as well.
    # NOTE(review): verify that downstream code does not reuse them.
    return {name: value * scale for name, value in state_dict.items()}