add: LoRA Encoder

yjy415
2025-11-18 21:29:35 +08:00
parent 3f9e9cad9d
commit 2d23c897c2
3 changed files with 2 additions and 103 deletions


@@ -106,38 +106,6 @@ class FluxImagePipeline(BasePipeline):
    def enable_lora_magic(self):
        pass

    def load_lora(self, model, lora_config, alpha=1, hotload=False):
        # Resolve the LoRA checkpoint path from either a raw string or a config object.
        if isinstance(lora_config, str):
            path = lora_config
        else:
            lora_config.download_if_necessary()
            path = lora_config.path
        state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
        loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
        state_dict = loader.convert_state_dict(state_dict)
        loaded_count = 0
        for key in tqdm(state_dict, desc="Applying LoRA"):
            if ".lora_A." in key:
                layer_name = key.split(".lora_A.")[0]
                # Walk the dotted module path, indexing into containers for numeric parts.
                module = model
                try:
                    parts = layer_name.split(".")
                    for part in parts:
                        if part.isdigit():
                            module = module[int(part)]
                        else:
                            module = getattr(module, part)
                except (AttributeError, IndexError, KeyError):
                    # Skip keys that do not map onto a module in this model.
                    continue
                w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
                w_b_key = key.replace("lora_A", "lora_B")
                if w_b_key not in state_dict:
                    continue
                w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)
                # Merge the low-rank update into the base weight: W <- W + alpha * (B @ A).
                delta_w = torch.mm(w_b, w_a)
                module.weight.data += delta_w * alpha
                loaded_count += 1

    @staticmethod
    def from_pretrained(
        torch_dtype: torch.dtype = torch.bfloat16,
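
For context, the removed load_lora path folds each LoRA pair directly into the base weights rather than keeping the adapter as a separate module. A minimal sketch of that per-layer merge, assuming a plain nn.Linear and hypothetical shapes and tensor names (lora_A: rank x in_features, lora_B: out_features x rank):

    import torch
    import torch.nn as nn

    # Hypothetical layer and LoRA pair; names and shapes are illustrative only.
    layer = nn.Linear(64, 64, bias=False)
    rank, alpha = 8, 1.0
    lora_A = torch.randn(rank, 64) * 0.01
    lora_B = torch.zeros(64, rank)

    # The same update load_lora applies per layer: W <- W + alpha * (B @ A).
    with torch.no_grad():
        layer.weight += alpha * torch.mm(lora_B, lora_A)

Merging in place keeps inference at base-model speed, at the cost of requiring a weight reload to detach the LoRA, which is presumably why this path is replaced by a dedicated LoRA encoder in this commit.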