mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-18 22:08:13 +00:00
add: LoRA Encoder
This commit is contained in:
@@ -468,80 +468,13 @@ class LoRAEmbedder(torch.nn.Module):
|
||||
"type": suffix,
|
||||
})
|
||||
return lora_patterns
|
||||
|
||||
def get_lora_param_pair(self, lora, name, dim, device, dtype):
|
||||
key_A = name + ".lora_A.default.weight"
|
||||
key_B = name + ".lora_B.default.weight"
|
||||
if key_A in lora and key_B in lora:
|
||||
return lora[key_A], lora[key_B]
|
||||
|
||||
if "to_qkv" in name:
|
||||
base_name = name.replace("to_qkv", "")
|
||||
suffixes = ["to_q", "to_k", "to_v"]
|
||||
|
||||
found_As = []
|
||||
found_Bs = []
|
||||
|
||||
all_found = True
|
||||
for suffix in suffixes:
|
||||
sub_name = base_name + suffix
|
||||
k_A = sub_name + ".lora_A.default.weight"
|
||||
k_B = sub_name + ".lora_B.default.weight"
|
||||
|
||||
if k_A in lora and k_B in lora:
|
||||
found_As.append(lora[k_A])
|
||||
found_Bs.append(lora[k_B])
|
||||
else:
|
||||
all_found = False
|
||||
break
|
||||
if all_found:
|
||||
pass
|
||||
|
||||
rank = 16
|
||||
for k, v in lora.items():
|
||||
if "lora_A" in k:
|
||||
rank = v.shape[0]
|
||||
device = v.device
|
||||
dtype = v.dtype
|
||||
break
|
||||
|
||||
lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype)
|
||||
lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype)
|
||||
|
||||
return lora_A, lora_B
|
||||
|
||||
def forward(self, lora):
|
||||
lora_emb = []
|
||||
device = None
|
||||
dtype = None
|
||||
for v in lora.values():
|
||||
device = v.device
|
||||
dtype = v.dtype
|
||||
break
|
||||
|
||||
for lora_pattern in self.lora_patterns:
|
||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||
dim = lora_pattern["dim"]
|
||||
|
||||
lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype)
|
||||
|
||||
if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))):
|
||||
base_name = name.replace("to_qkv", "")
|
||||
try:
|
||||
q_name = base_name + "to_q"
|
||||
k_name = base_name + "to_k"
|
||||
v_name = base_name + "to_v"
|
||||
|
||||
real_A = lora[q_name + ".lora_A.default.weight"]
|
||||
B_q = lora[q_name + ".lora_B.default.weight"]
|
||||
B_k = lora[k_name + ".lora_B.default.weight"]
|
||||
B_v = lora[v_name + ".lora_B.default.weight"]
|
||||
real_B = torch.cat([B_q, B_k, B_v], dim=0)
|
||||
|
||||
lora_A, lora_B = real_A, real_B
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
lora_A = lora[name + ".lora_A.weight"]
|
||||
lora_B = lora[name + ".lora_B.weight"]
|
||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||
lora_emb.append(lora_out)
|
||||
|
||||
@@ -106,38 +106,6 @@ class FluxImagePipeline(BasePipeline):
|
||||
def enable_lora_magic(self):
|
||||
pass
|
||||
|
||||
def load_lora(self, model, lora_config, alpha=1, hotload=False):
|
||||
if isinstance(lora_config, str):
|
||||
path = lora_config
|
||||
else:
|
||||
lora_config.download_if_necessary()
|
||||
path = lora_config.path
|
||||
|
||||
state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
|
||||
loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
|
||||
state_dict = loader.convert_state_dict(state_dict)
|
||||
loaded_count = 0
|
||||
for key in tqdm(state_dict, desc="Applying LoRA"):
|
||||
if ".lora_A." in key:
|
||||
layer_name = key.split(".lora_A.")[0]
|
||||
module = model
|
||||
try:
|
||||
parts = layer_name.split(".")
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
module = module[int(part)]
|
||||
else:
|
||||
module = getattr(module, part)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
|
||||
w_b_key = key.replace("lora_A", "lora_B")
|
||||
if w_b_key not in state_dict: continue
|
||||
w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
|
||||
delta_w = torch.mm(w_b, w_a)
|
||||
module.weight.data += delta_w * alpha
|
||||
loaded_count += 1
|
||||
@staticmethod
|
||||
def from_pretrained(
|
||||
torch_dtype: torch.dtype = torch.bfloat16,
|
||||
|
||||
@@ -18,12 +18,10 @@ pipe.enable_lora_magic()
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
|
||||
hotload=True,
|
||||
)
|
||||
pipe.load_lora(
|
||||
pipe.dit,
|
||||
ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
|
||||
hotload=True,
|
||||
)
|
||||
image = pipe(prompt="a cat", seed=0)
|
||||
image.save("image_fused.jpg")
|
||||
|
||||
Reference in New Issue
Block a user