mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 18:28:10 +00:00
add: LoRA Encoder
This commit is contained in:
@@ -468,80 +468,13 @@ class LoRAEmbedder(torch.nn.Module):
|
|||||||
"type": suffix,
|
"type": suffix,
|
||||||
})
|
})
|
||||||
return lora_patterns
|
return lora_patterns
|
||||||
|
|
||||||
def get_lora_param_pair(self, lora, name, dim, device, dtype):
    """Look up the (lora_A, lora_B) weight pair for layer ``name``.

    Resolution order:
      1. Direct hit on ``<name>.lora_A.default.weight`` / ``...lora_B...``.
      2. For fused ``to_qkv`` layers, assemble the pair from the separate
         ``to_q`` / ``to_k`` / ``to_v`` entries (B matrices concatenated
         along dim 0, in q, k, v order).
      3. Fall back to all-zero tensors shaped from ``dim``, with rank /
         device / dtype inferred from any ``lora_A`` tensor in ``lora``.

    Args:
        lora: state-dict-like mapping of LoRA parameter names to tensors.
        name: target layer name (dotted path, no lora suffix).
        dim: (in_features, out_features) of the target layer.
        device: default device used when ``lora`` holds no lora_A tensor.
        dtype: default dtype used when ``lora`` holds no lora_A tensor.

    Returns:
        Tuple ``(lora_A, lora_B)``; zeros of shape ``(rank, dim[0])`` and
        ``(dim[1], rank)`` when the layer has no LoRA weights at all.
    """
    key_A = name + ".lora_A.default.weight"
    key_B = name + ".lora_B.default.weight"
    if key_A in lora and key_B in lora:
        return lora[key_A], lora[key_B]

    if "to_qkv" in name:
        # Fused qkv layer: reconstruct the pair from split q/k/v entries.
        # NOTE(review): this assumes the three branches share lora_A (the
        # q entry is used as the shared A) — standard for fused-attention
        # LoRAs; confirm against the checkpoints this loader targets.
        base_name = name.replace("to_qkv", "")
        found_As = []
        found_Bs = []
        all_found = True
        for suffix in ("to_q", "to_k", "to_v"):
            sub_name = base_name + suffix
            k_A = sub_name + ".lora_A.default.weight"
            k_B = sub_name + ".lora_B.default.weight"
            if k_A in lora and k_B in lora:
                found_As.append(lora[k_A])
                found_Bs.append(lora[k_B])
            else:
                all_found = False
                break
        if all_found:
            # BUGFIX: the original branch discarded the collected pairs
            # with a bare `pass`, so fused-qkv lookups always fell through
            # to the zero fallback and callers had to re-resolve manually.
            return found_As[0], torch.cat(found_Bs, dim=0)

    # No LoRA weights for this layer: return zero tensors whose rank /
    # device / dtype match whatever lora_A tensor exists in the checkpoint
    # (falling back to rank 16 and the given device/dtype if none does).
    rank = 16
    for k, v in lora.items():
        if "lora_A" in k:
            rank = v.shape[0]
            device = v.device
            dtype = v.dtype
            break

    lora_A = torch.zeros((rank, dim[0]), device=device, dtype=dtype)
    lora_B = torch.zeros((dim[1], rank), device=device, dtype=dtype)
    return lora_A, lora_B
||||||
def forward(self, lora):
|
def forward(self, lora):
|
||||||
lora_emb = []
|
lora_emb = []
|
||||||
device = None
|
|
||||||
dtype = None
|
|
||||||
for v in lora.values():
|
|
||||||
device = v.device
|
|
||||||
dtype = v.dtype
|
|
||||||
break
|
|
||||||
|
|
||||||
for lora_pattern in self.lora_patterns:
|
for lora_pattern in self.lora_patterns:
|
||||||
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
name, layer_type = lora_pattern["name"], lora_pattern["type"]
|
||||||
dim = lora_pattern["dim"]
|
lora_A = lora[name + ".lora_A.weight"]
|
||||||
|
lora_B = lora[name + ".lora_B.weight"]
|
||||||
lora_A, lora_B = self.get_lora_param_pair(lora, name, dim, device, dtype)
|
|
||||||
|
|
||||||
if "to_qkv" in name and (lora_A is None or (torch.equal(lora_A, torch.zeros_like(lora_A)))):
|
|
||||||
base_name = name.replace("to_qkv", "")
|
|
||||||
try:
|
|
||||||
q_name = base_name + "to_q"
|
|
||||||
k_name = base_name + "to_k"
|
|
||||||
v_name = base_name + "to_v"
|
|
||||||
|
|
||||||
real_A = lora[q_name + ".lora_A.default.weight"]
|
|
||||||
B_q = lora[q_name + ".lora_B.default.weight"]
|
|
||||||
B_k = lora[k_name + ".lora_B.default.weight"]
|
|
||||||
B_v = lora[v_name + ".lora_B.default.weight"]
|
|
||||||
real_B = torch.cat([B_q, B_k, B_v], dim=0)
|
|
||||||
|
|
||||||
lora_A, lora_B = real_A, real_B
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
lora_out = self.model_dict[name.replace(".", "___")](lora_A, lora_B)
|
||||||
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
lora_out = self.proj_dict[layer_type.replace(".", "___")](lora_out)
|
||||||
lora_emb.append(lora_out)
|
lora_emb.append(lora_out)
|
||||||
|
|||||||
@@ -106,38 +106,6 @@ class FluxImagePipeline(BasePipeline):
|
|||||||
def enable_lora_magic(self):
    """Intentionally a no-op in this version.

    Kept so existing scripts that call ``pipe.enable_lora_magic()`` keep
    running; ``load_lora`` now applies LoRA deltas directly to module
    weights, so no pipeline-level preparation appears to be needed here
    (NOTE(review): confirm no caller relies on a previous side effect).
    """
    pass
|
||||||
def load_lora(self, model, lora_config, alpha=1, hotload=False):
    """Load a LoRA checkpoint and fuse it into ``model``'s weights in place.

    Args:
        model: root module whose submodules receive the LoRA deltas.
        lora_config: checkpoint path string, or a config object exposing
            ``download_if_necessary()`` and ``.path``.
        alpha: scaling factor applied to each delta ``B @ A``.
        hotload: accepted for API compatibility; not used by this
            weight-fusing implementation.

    Returns:
        Number of layers that received a LoRA delta (previously computed
        but never exposed).
    """
    if isinstance(lora_config, str):
        path = lora_config
    else:
        lora_config.download_if_necessary()
        path = lora_config.path

    state_dict = load_state_dict(path, torch_dtype=self.torch_dtype, device="cpu")
    loader = self.lora_loader(torch_dtype=self.torch_dtype, device=self.device)
    state_dict = loader.convert_state_dict(state_dict)

    loaded_count = 0
    for key in tqdm(state_dict, desc="Applying LoRA"):
        if ".lora_A." not in key:
            continue
        layer_name = key.split(".lora_A.")[0]

        # Walk the dotted path down to the target submodule; numeric path
        # components index into container modules (ModuleList/Sequential).
        module = model
        try:
            for part in layer_name.split("."):
                if part.isdigit():
                    module = module[int(part)]
                else:
                    module = getattr(module, part)
        except (AttributeError, IndexError, KeyError, TypeError):
            # BUGFIX: only AttributeError was caught before, so a stale
            # numeric index (or indexing a non-container) crashed the
            # whole load instead of skipping the unmatched entry.
            continue

        w_a = state_dict[key].to(device=module.weight.device, dtype=module.weight.dtype)
        w_b_key = key.replace("lora_A", "lora_B")
        if w_b_key not in state_dict:
            continue
        w_b = state_dict[w_b_key].to(device=module.weight.device, dtype=module.weight.dtype)

        # Fuse the low-rank update: W += alpha * (B @ A).
        module.weight.data += torch.mm(w_b, w_a) * alpha
        loaded_count += 1

    return loaded_count
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
torch_dtype: torch.dtype = torch.bfloat16,
|
torch_dtype: torch.dtype = torch.bfloat16,
|
||||||
|
|||||||
@@ -18,12 +18,10 @@ pipe.enable_lora_magic()
|
|||||||
# Fuse two LoRAs into the DiT backbone before sampling.
pipe.load_lora(
    pipe.dit,
    ModelConfig(model_id="cancel13/cxsk", origin_file_pattern="30.safetensors"),
    hotload=True,
)
pipe.load_lora(
    pipe.dit,
    ModelConfig(model_id="DiffSynth-Studio/ArtAug-lora-FLUX.1dev-v1", origin_file_pattern="merged_lora.safetensors"),
    hotload=True,
)
# Generate with a fixed seed and save the image with both LoRAs applied.
image = pipe(prompt="a cat", seed=0)
image.save("image_fused.jpg")
|
||||||
|
|||||||
Reference in New Issue
Block a user