diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py
index da8302a..7d4f52d 100644
--- a/diffsynth/models/lora.py
+++ b/diffsynth/models/lora.py
@@ -195,85 +195,73 @@ class FluxLoRAFromCivitai(LoRAFromCivitai):
             "txt.mod": "txt_mod",
         }
 
-
+
+
 class GeneralLoRAFromPeft:
     def __init__(self):
         self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]
-
-
-    def fetch_device_dtype_from_state_dict(self, state_dict):
-        device, torch_dtype = None, None
-        for name, param in state_dict.items():
-            device, torch_dtype = param.device, param.dtype
-            break
-        return device, torch_dtype
-
-
-    def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
-        device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
-        if torch_dtype == torch.float8_e4m3fn:
-            torch_dtype = torch.float32
-        state_dict_ = {}
-        for key in state_dict:
+
+
+    def get_name_dict(self, lora_state_dict):
+        lora_name_dict = {}
+        for key in lora_state_dict:
             if ".lora_B." not in key:
                 continue
-            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
-            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)
-            if len(weight_up.shape) == 4:
-                weight_up = weight_up.squeeze(3).squeeze(2)
-                weight_down = weight_down.squeeze(3).squeeze(2)
-                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
-            else:
-                lora_weight = alpha * torch.mm(weight_up, weight_down)
             keys = key.split(".")
             if len(keys) > keys.index("lora_B") + 2:
                 keys.pop(keys.index("lora_B") + 1)
             keys.pop(keys.index("lora_B"))
+            if keys[0] == "diffusion_model":
+                keys.pop(0)
             target_name = ".".join(keys)
-            if target_name.startswith("diffusion_model."):
-                target_name = target_name[len("diffusion_model."):]
-            if target_name not in target_state_dict:
-                return {}
-            state_dict_[target_name] = lora_weight.cpu()
-        return state_dict_
+            lora_name_dict[target_name] = (key, key.replace(".lora_B.", ".lora_A."))
+        return lora_name_dict
+
+
+    def match(self, model: torch.nn.Module, state_dict_lora):
+        lora_name_dict = self.get_name_dict(state_dict_lora)
+        model_name_dict = {name: None for name, _ in model.named_parameters()}
+        matched_num = sum([i in model_name_dict for i in lora_name_dict])
+        if matched_num == len(lora_name_dict):
+            return "", ""
+        else:
+            return None
+
+
+    def fetch_device_and_dtype(self, state_dict):
+        device, dtype = None, None
+        for name, param in state_dict.items():
+            device, dtype = param.device, param.dtype
+            break
+        computation_device = device
+        computation_dtype = dtype
+        if computation_device == torch.device("cpu"):
+            if torch.cuda.is_available():
+                computation_device = torch.device("cuda")
+        if computation_dtype == torch.float8_e4m3fn:
+            computation_dtype = torch.float32
+        return device, dtype, computation_device, computation_dtype
+
 
 
     def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
         state_dict_model = model.state_dict()
-        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
-        if len(state_dict_lora) > 0:
-            print(f"    {len(state_dict_lora)} tensors are updated.")
-            for name in state_dict_lora:
-                if state_dict_model[name].dtype == torch.float8_e4m3fn:
-                    weight = state_dict_model[name].to(torch.float32)
-                    lora_weight = state_dict_lora[name].to(
-                        dtype=torch.float32,
-                        device=state_dict_model[name].device
-                    )
-                    state_dict_model[name] = (weight + lora_weight).to(
-                        dtype=state_dict_model[name].dtype,
-                        device=state_dict_model[name].device
-                    )
-                else:
-                    state_dict_model[name] += state_dict_lora[name].to(
-                        dtype=state_dict_model[name].dtype,
-                        device=state_dict_model[name].device
-                    )
-            model.load_state_dict(state_dict_model)
+        device, dtype, computation_device, computation_dtype = self.fetch_device_and_dtype(state_dict_model)
+        lora_name_dict = self.get_name_dict(state_dict_lora)
+        for name in lora_name_dict:
+            weight_up = state_dict_lora[lora_name_dict[name][0]].to(device=computation_device, dtype=computation_dtype)
+            weight_down = state_dict_lora[lora_name_dict[name][1]].to(device=computation_device, dtype=computation_dtype)
+            if len(weight_up.shape) == 4:
+                weight_up = weight_up.squeeze(3).squeeze(2)
+                weight_down = weight_down.squeeze(3).squeeze(2)
+                weight_lora = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+            else:
+                weight_lora = alpha * torch.mm(weight_up, weight_down)
+            weight_model = state_dict_model[name].to(device=computation_device, dtype=computation_dtype)
+            weight_patched = weight_model + weight_lora
+            state_dict_model[name] = weight_patched.to(device=device, dtype=dtype)
+        print(f"    {len(lora_name_dict)} tensors are updated.")
+        model.load_state_dict(state_dict_model)
-
-
-    def match(self, model, state_dict_lora):
-        for model_class in self.supported_model_classes:
-            if not isinstance(model, model_class):
-                continue
-            state_dict_model = model.state_dict()
-            try:
-                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
-                if len(state_dict_lora_) > 0:
-                    return "", ""
-            except:
-                pass
-        return None
 
 
 class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
diff --git a/diffsynth/models/model_manager.py b/diffsynth/models/model_manager.py
index 775fafe..7ae3c50 100644
--- a/diffsynth/models/model_manager.py
+++ b/diffsynth/models/model_manager.py
@@ -376,6 +376,7 @@ class ModelManager:
             self.load_lora(file_path_, state_dict=state_dict, lora_alpha=lora_alpha)
         else:
             print(f"Loading LoRA models from file: {file_path}")
+            is_loaded = False
             if len(state_dict) == 0:
                 state_dict = load_state_dict(file_path)
             for model_name, model, model_path in zip(self.model_name, self.model, self.model_path):
@@ -385,7 +386,10 @@ class ModelManager:
                         print(f"    Adding LoRA to {model_name} ({model_path}).")
                         lora_prefix, model_resource = match_results
                         lora.load(model, state_dict, lora_prefix, alpha=lora_alpha, model_resource=model_resource)
+                        is_loaded = True
                         break
+            if not is_loaded:
+                print(f"    Cannot load LoRA: {file_path}")
 
 
     def load_model(self, file_path, model_names=None, device=None, torch_dtype=None):