add: LoRA Encoder

This commit is contained in:
yjy415
2025-11-18 21:29:35 +08:00
parent 3f9e9cad9d
commit 2d23c897c2
3 changed files with 2 additions and 103 deletions

View File

@@ -468,80 +468,13 @@ class LoRAEmbedder(torch.nn.Module):
"type": suffix,
})
return lora_patterns
def get_lora_param_pair(self, lora, name, dim, device, dtype):
    """Look up the (lora_A, lora_B) weight pair for layer `name` in a LoRA state dict.

    Resolution order:
      1. Direct hit on `<name>.lora_A.default.weight` / `.lora_B.default.weight`.
      2. For fused "to_qkv" layers, fuse separate to_q/to_k/to_v entries.
      3. Fall back to a zero pair shaped from `dim`, borrowing rank/device/dtype
         from any lora_A tensor present (rank defaults to 16 when the dict is empty).

    Args:
        lora: mapping of parameter name -> tensor (PEFT-style key layout).
        name: dotted layer name to resolve.
        dim: (in_features, out_features) pair used to shape the zero fallback.
        device, dtype: defaults for the zero fallback when `lora` is empty.
    Returns:
        Tuple (lora_A, lora_B) of tensors.
    """
    key_A = name + ".lora_A.default.weight"
    key_B = name + ".lora_B.default.weight"
    if key_A in lora and key_B in lora:
        return lora[key_A], lora[key_B]
    if "to_qkv" in name:
        # Fused attention projection: try to assemble it from the three
        # separate q/k/v LoRA entries.
        base_name = name.replace("to_qkv", "")
        found_As = []
        found_Bs = []
        all_found = True
        for suffix in ("to_q", "to_k", "to_v"):
            sub_name = base_name + suffix
            k_A = sub_name + ".lora_A.default.weight"
            k_B = sub_name + ".lora_B.default.weight"
            if k_A in lora and k_B in lora:
                found_As.append(lora[k_A])
                found_Bs.append(lora[k_B])
            else:
                all_found = False
                break
        if all_found:
            # BUGFIX: this branch was `pass`, silently discarding the tensors it
            # had just collected and returning zeros instead. Mirror forward()'s
            # fusion: reuse q's lora_A (assumes q/k/v share the down-projection
            # -- TODO confirm against the trainer that produced these dicts) and
            # stack the three lora_B matrices along the output dimension.
            return found_As[0], torch.cat(found_Bs, dim=0)
    # Fallback: a zero pair so the caller always gets tensors of the right shape.
    # Borrow rank/device/dtype from any lora_A tensor in the dict when possible.
    rank = 16
    for k, v in lora.items():
        if "lora_A" in k:
            rank = v.shape[0]
            device = v.device
            dtype = v.dtype
            break
    lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype)
    lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype)
    return lora_A, lora_B
# Encode a LoRA state dict into a list of per-layer embeddings.
# NOTE(review): this hunk comes from a rendered diff -- indentation was stripped
# and the function appears truncated (no `return lora_emb` is visible), so the
# lines are kept byte-identical and only annotated.
def forward(self, lora):
lora_emb = []
# Borrow device/dtype from the first tensor in the dict (None if empty).
device = None
dtype = None
for v in lora.values():
device = v.device
dtype = v.dtype
break
for lora_pattern in self.lora_patterns:
name, layer_type = lora_pattern["name"], lora_pattern["type"]
dim = lora_pattern["dim"]
lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype)
# If the fused qkv lookup came back empty (None or all-zeros fallback),
# retry by fusing the separate to_q/to_k/to_v entries directly.
if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))):
base_name = name.replace("to_qkv", "")
try:
q_name = base_name + "to_q"
k_name = base_name + "to_k"
v_name = base_name + "to_v"
# Reuse q's down-projection; stack the three up-projections row-wise.
real_A = lora[q_name + ".lora_A.default.weight"]
B_q = lora[q_name + ".lora_B.default.weight"]
B_k = lora[k_name + ".lora_B.default.weight"]
B_v = lora[v_name + ".lora_B.default.weight"]
real_B = torch.cat([B_q, B_k, B_v], dim=0)
lora_A, lora_B = real_A, real_B
except KeyError:
# Best-effort: keep whatever get_lora_param_pair returned.
pass
# NOTE(review): these two lines unconditionally overwrite the pair
# resolved above using a different key layout (".lora_A.weight", no
# ".default."). If those keys are absent this raises KeyError; this
# looks like a leftover from the diff -- verify against the repo HEAD.
lora_A = lora[name + ".lora_A.weight"]
lora_B = lora[name + ".lora_B.weight"]
# Per-layer encoder, then a per-layer-type projection head; dict keys use
# "___" in place of "." (ModuleDict keys cannot contain dots).
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
lora_emb.append(lora_out)

View File

@@ -106,38 +106,6 @@ class FluxImagePipeline(BasePipeline):
# Intentional no-op in this revision: nothing in the visible code prepares or
# toggles any state here. Kept as a hook the example script calls before
# load_lora -- presumably a feature flag; verify against the full pipeline class.
def enable_lora_magic(self):
pass
def load_lora(self, model, lora_config, alpha=1, hotload=False):
    """Load a LoRA checkpoint and merge it into `model`'s weights in place.

    For every `*.lora_A.*` key in the converted state dict, the matching module
    is located by walking the dotted layer path, and `alpha * (B @ A)` is added
    to its weight. Entries whose layer is absent from `model`, or whose lora_B
    partner key is missing, are skipped silently (best-effort merge).

    Args:
        model: module tree whose weights receive the LoRA delta.
        lora_config: checkpoint path string, or a config object exposing
            `download_if_necessary()` and `.path`.
        alpha: scale applied to each delta before merging.
        hotload: accepted for interface compatibility; not used by this
            merge-based implementation.
    Returns:
        Number of layers that were patched (previously computed but discarded).
    """
    if isinstance(lora_config, str):
        path = lora_config
    else:
        lora_config.download_if_necessary()
        path = lora_config.path
    state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
    loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
    state_dict = loader.convert_state_dict(state_dict)
    loaded_count = 0
    for key in tqdm(state_dict, desc="Applying LoRA"):
        if ".lora_A." not in key:
            continue
        layer_name = key.split(".lora_A.")[0]
        w_b_key = key.replace("lora_A", "lora_B")
        # Check for the partner key before any device transfer work.
        if w_b_key not in state_dict:
            continue
        module = model
        try:
            for part in layer_name.split("."):
                # Numeric path components index into sequential containers.
                module = module[int(part)] if part.isdigit() else getattr(module, part)
        except (AttributeError, IndexError, KeyError, TypeError):
            # Original caught only AttributeError, but container indexing can
            # raise Index/Key/TypeError too; all mean "layer not in this model".
            continue
        w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
        w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
        # Merge: W += alpha * (B @ A)
        module.weight.data += torch.mm(w_b, w_a) * alpha
        loaded_count += 1
    return loaded_count
@staticmethod
def from_pretrained(
torch_dtype: torch.dtype = torch.bfloat16,

View File

@@ -18,12 +18,10 @@ pipe.enable_lora_magic()
# Example usage: hot-load two LoRAs into the DiT and render a fused result.
# (`pipe` and `ModelConfig` are defined earlier in this file, outside this hunk.)
pipe.load_lora(
pipe.dit,
ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
hotload=True,
)
# Second LoRA stacks on top of the first.
pipe.load_lora(
pipe.dit,
ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
hotload=True,
)
# Fixed seed keeps the example reproducible.
image = pipe(prompt="a cat", seed=0)
image.save("image_fused.jpg")