mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
13 lines
402 B
Python
13 lines
402 B
Python
from diffsynth import load_state_dict
|
|
import math, torch
|
|
|
|
|
|
def load_lora(file_path, device):
|
|
sd = load_state_dict(file_path, torch_dtype=torch.bfloat16, device=device)
|
|
scale = math.sqrt(sd["lora_unet_single_blocks_9_modulation_lin.alpha"] / sd["lora_unet_single_blocks_9_modulation_lin.lora_down.weight"].shape[0])
|
|
if scale != 1:
|
|
sd = {i: sd[i] * scale for i in sd}
|
|
return sd
|
|
|
|
|