mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-21 16:18:13 +00:00
support lora
This commit is contained in:
60
diffsynth/models/sd_lora.py
Normal file
60
diffsynth/models/sd_lora.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import torch
|
||||
from .sd_unet import SDUNetStateDictConverter, SDUNet
|
||||
from .sd_text_encoder import SDTextEncoderStateDictConverter, SDTextEncoder
|
||||
|
||||
|
||||
class SDLoRA:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def convert_state_dict(self, state_dict, lora_prefix="lora_unet_", alpha=1.0, device="cuda"):
|
||||
special_keys = {
|
||||
"down.blocks": "down_blocks",
|
||||
"up.blocks": "up_blocks",
|
||||
"mid.block": "mid_block",
|
||||
"proj.in": "proj_in",
|
||||
"proj.out": "proj_out",
|
||||
"transformer.blocks": "transformer_blocks",
|
||||
"to.q": "to_q",
|
||||
"to.k": "to_k",
|
||||
"to.v": "to_v",
|
||||
"to.out": "to_out",
|
||||
}
|
||||
state_dict_ = {}
|
||||
for key in state_dict:
|
||||
if ".lora_up" not in key:
|
||||
continue
|
||||
if not key.startswith(lora_prefix):
|
||||
continue
|
||||
weight_up = state_dict[key].to(device="cuda", dtype=torch.float16)
|
||||
weight_down = state_dict[key.replace(".lora_up", ".lora_down")].to(device="cuda", dtype=torch.float16)
|
||||
if len(weight_up.shape) == 4:
|
||||
weight_up = weight_up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = weight_down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
||||
target_name = key.split(".")[0].replace("_", ".")[len(lora_prefix):] + ".weight"
|
||||
for special_key in special_keys:
|
||||
target_name = target_name.replace(special_key, special_keys[special_key])
|
||||
state_dict_[target_name] = lora_weight.cpu()
|
||||
return state_dict_
|
||||
|
||||
def add_lora_to_unet(self, unet: SDUNet, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_unet = unet.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_unet_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDUNetStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_unet[name] += state_dict_lora[name].to(device=device)
|
||||
unet.load_state_dict(state_dict_unet)
|
||||
|
||||
def add_lora_to_text_encoder(self, text_encoder: SDTextEncoder, state_dict_lora, alpha=1.0, device="cuda"):
|
||||
state_dict_text_encoder = text_encoder.state_dict()
|
||||
state_dict_lora = self.convert_state_dict(state_dict_lora, lora_prefix="lora_te_", alpha=alpha, device=device)
|
||||
state_dict_lora = SDTextEncoderStateDictConverter().from_diffusers(state_dict_lora)
|
||||
if len(state_dict_lora) > 0:
|
||||
for name in state_dict_lora:
|
||||
state_dict_text_encoder[name] += state_dict_lora[name].to(device=device)
|
||||
text_encoder.load_state_dict(state_dict_text_encoder)
|
||||
|
||||
Reference in New Issue
Block a user