mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-24 10:18:12 +00:00
@@ -195,85 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
|
|||||||
"txt.mod": "txt_mod",
|
"txt.mod": "txt_mod",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class GeneralLoRAFromPeft:
|
class GeneralLoRAFromPeft:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
|
||||||
|
|
||||||
|
|
||||||
def fetch_device_dtype_from_state_dict(self, state_dict):
|
def get_name_dict(self, lora_state_dict):
|
||||||
device, torch_dtype = None, None
|
lora_name_dict = {}
|
||||||
for name, param in state_dict.items():
|
for key in lora_state_dict:
|
||||||
device, torch_dtype = param.device, param.dtype
|
|
||||||
break
|
|
||||||
return device, torch_dtype
|
|
||||||
|
|
||||||
|
|
||||||
def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
|
|
||||||
device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
|
|
||||||
if torch_dtype == torch.float8_e4m3fn:
|
|
||||||
torch_dtype = torch.float32
|
|
||||||
state_dict_ = {}
|
|
||||||
for key in state_dict:
|
|
||||||
if ".lora_B." not in key:
|
if ".lora_B." not in key:
|
||||||
continue
|
continue
|
||||||
weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
|
|
||||||
weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
|
|
||||||
if len(weight_up.shape) == 4:
|
|
||||||
weight_up = weight_up.squeeze(3).squeeze(2)
|
|
||||||
weight_down = weight_down.squeeze(3).squeeze(2)
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
|
||||||
else:
|
|
||||||
lora_weight = alpha * torch.mm(weight_up, weight_down)
|
|
||||||
keys = key.split(".")
|
keys = key.split(".")
|
||||||
if len(keys) > keys.index("lora_B") + 2:
|
if len(keys) > keys.index("lora_B") + 2:
|
||||||
keys.pop(keys.index("lora_B") + 1)
|
keys.pop(keys.index("lora_B") + 1)
|
||||||
keys.pop(keys.index("lora_B"))
|
keys.pop(keys.index("lora_B"))
|
||||||
|
if keys[0] == "diffusion_model":
|
||||||
|
keys.pop(0)
|
||||||
target_name = ".".join(keys)
|
target_name = ".".join(keys)
|
||||||
if target_name.startswith("diffusion_model."):
|
lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
|
||||||
target_name = target_name[len("diffusion_model."):]
|
return lora_name_dict
|
||||||
if target_name not in target_state_dict:
|
|
||||||
return {}
|
|
||||||
state_dict_[target_name] = lora_weight.cpu()
|
|
||||||
return state_dict_
|
|
||||||
|
|
||||||
|
|
||||||
|
def match(self, model: torch.nn.Module, state_dict_lora):
|
||||||
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
|
model_name_dict = {name: None for name, _ in model.named_parameters()}
|
||||||
|
matched_num = sum([i in model_name_dict for i in lora_name_dict])
|
||||||
|
if matched_num == len(lora_name_dict):
|
||||||
|
return "", ""
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_device_and_dtype(self, state_dict):
|
||||||
|
device, dtype = None, None
|
||||||
|
for name, param in state_dict.items():
|
||||||
|
device, dtype = param.device, param.dtype
|
||||||
|
break
|
||||||
|
computation_device = device
|
||||||
|
computation_dtype = dtype
|
||||||
|
if computation_device == torch.device("cpu"):
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
computation_device = torch.device("cuda")
|
||||||
|
if computation_dtype == torch.float8_e4m3fn:
|
||||||
|
computation_dtype = torch.float32
|
||||||
|
return device, dtype, computation_device, computation_dtype
|
||||||
|
|
||||||
|
|
||||||
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
|
||||||
state_dict_model = model.state_dict()
|
state_dict_model = model.state_dict()
|
||||||
state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
|
device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
|
||||||
if len(state_dict_lora) > 0:
|
lora_name_dict = self.get_name_dict(state_dict_lora)
|
||||||
print(f" {len(state_dict_lora)} tensors are updated.")
|
for name in lora_name_dict:
|
||||||
for name in state_dict_lora:
|
weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
|
||||||
if state_dict_model[name].dtype == torch.float8_e4m3fn:
|
weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
|
||||||
weight = state_dict_model[name].to(torch.float32)
|
if len(weight_up.shape) == 4:
|
||||||
lora_weight = state_dict_lora[name].to(
|
weight_up = weight_up.squeeze(3).squeeze(2)
|
||||||
dtype=torch.float32,
|
weight_down = weight_down.squeeze(3).squeeze(2)
|
||||||
device=state_dict_model[name].device
|
weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||||
)
|
else:
|
||||||
state_dict_model[name] = (weight + lora_weight).to(
|
weight_lora = alpha * torch.mm(weight_up, weight_down)
|
||||||
dtype=state_dict_model[name].dtype,
|
weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
|
||||||
device=state_dict_model[name].device
|
weight_patched = weight_model + weight_lora
|
||||||
)
|
state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
|
||||||
else:
|
print(f" {len(lora_name_dict)} tensors are updated.")
|
||||||
state_dict_model[name] += state_dict_lora[name].to(
|
model.load_state_dict(state_dict_model)
|
||||||
dtype=state_dict_model[name].dtype,
|
|
||||||
device=state_dict_model[name].device
|
|
||||||
)
|
|
||||||
model.load_state_dict(state_dict_model)
|
|
||||||
|
|
||||||
|
|
||||||
def match(self, model, state_dict_lora):
|
|
||||||
for model_class in self.supported_model_classes:
|
|
||||||
if not isinstance(model, model_class):
|
|
||||||
continue
|
|
||||||
state_dict_model = model.state_dict()
|
|
||||||
try:
|
|
||||||
state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
|
|
||||||
if len(state_dict_lora_) > 0:
|
|
||||||
return "", ""
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
|
||||||
|
|||||||
@@ -376,6 +376,7 @@ class ModelManager:
|
|||||||
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
|
||||||
else:
|
else:
|
||||||
print(f"Loading LoRA models from file: {file_path}")
|
print(f"Loading LoRA models from file: {file_path}")
|
||||||
|
is_loaded = False
|
||||||
if len(state_dict) == 0:
|
if len(state_dict) == 0:
|
||||||
state_dict = load_state_dict(file_path)
|
state_dict = load_state_dict(file_path)
|
||||||
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
|
||||||
@@ -385,7 +386,10 @@ class ModelManager:
|
|||||||
print(f" Adding LoRA to {model_name} ({model_path}).")
|
print(f" Adding LoRA to {model_name} ({model_path}).")
|
||||||
lora_prefix, model_resource = match_results
|
lora_prefix, model_resource = match_results
|
||||||
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
|
||||||
|
is_loaded = True
|
||||||
break
|
break
|
||||||
|
if not is_loaded:
|
||||||
|
print(f" Cannot load LoRA: {file_path}")
|
||||||
|
|
||||||
|
|
||||||
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):
|
||||||
|
|||||||
Reference in New Issue
Block a user