From 3fdba19e02531cf7049b6cd30f553d4f818d3a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Furkan=20G=C3=B6z=C3=BCkara?= Date: Wed, 12 Mar 2025 15:49:57 +0300 Subject: [PATCH] Fixes high RAM usage Wan 2.1 Fixes high RAM usage Wan 2.1 --- diffsynth/models/lora.py | 148 +++++++++++++++++++++++++++------------ 1 file changed, 103 insertions(+), 45 deletions(-) diff --git a/diffsynth/models/lora.py b/diffsynth/models/lora.py index da8302a..2315e96 100644 --- a/diffsynth/models/lora.py +++ b/diffsynth/models/lora.py @@ -200,81 +200,139 @@ class GeneralLoRAFromPeft: def __init__(self): self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel] + def fetch_device_dtype_from_state_dict(self, target_param): + """Get device and dtype from a parameter""" + return target_param.device, target_param.dtype - def fetch_device_dtype_from_state_dict(self, state_dict): - device, torch_dtype = None, None - for name, param in state_dict.items(): - device, torch_dtype = param.device, param.dtype - break - return device, torch_dtype - + def _get_target_name(self, key): + """Extract target parameter name from LoRA key""" + keys = key.split(".") + if len(keys) > keys.index("lora_B") + 2: + keys.pop(keys.index("lora_B") + 1) + keys.pop(keys.index("lora_B")) + target_name = ".".join(keys) + if target_name.startswith("diffusion_model."): + target_name = target_name[len("diffusion_model."):] + return target_name def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}): - device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict) + """Original method kept for compatibility with match method""" + device, torch_dtype = None, None + for name, param in target_state_dict.items(): + device, torch_dtype = param.device, param.dtype + break + if torch_dtype == torch.float8_e4m3fn: torch_dtype = torch.float32 + state_dict_ = {} for key in state_dict: if ".lora_B." not in key: continue + weight_up = state_dict[key].to(device=device, dtype=torch_dtype) weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype) + if len(weight_up.shape) == 4: weight_up = weight_up.squeeze(3).squeeze(2) weight_down = weight_down.squeeze(3).squeeze(2) lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) else: lora_weight = alpha * torch.mm(weight_up, weight_down) - keys = key.split(".") - if len(keys) > keys.index("lora_B") + 2: - keys.pop(keys.index("lora_B") + 1) - keys.pop(keys.index("lora_B")) - target_name = ".".join(keys) - if target_name.startswith("diffusion_model."): - target_name = target_name[len("diffusion_model."):] + + target_name = self._get_target_name(key) + if target_name not in target_state_dict: return {} + state_dict_[target_name] = lora_weight.cpu() + return state_dict_ - def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""): - state_dict_model = model.state_dict() - state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model) - if len(state_dict_lora) > 0: - print(f" {len(state_dict_lora)} tensors are updated.") - for name in state_dict_lora: - if state_dict_model[name].dtype == torch.float8_e4m3fn: - weight = state_dict_model[name].to(torch.float32) - lora_weight = state_dict_lora[name].to( - dtype=torch.float32, - device=state_dict_model[name].device - ) - state_dict_model[name] = (weight + lora_weight).to( - dtype=state_dict_model[name].dtype, - device=state_dict_model[name].device - ) - else: - state_dict_model[name] += state_dict_lora[name].to( - dtype=state_dict_model[name].dtype, - device=state_dict_model[name].device - ) - model.load_state_dict(state_dict_model) - + """Apply LoRA weights directly to model parameters without loading entire state dict""" + # Create parameter name mapping for faster lookup + param_dict = {} + for name, param in model.named_parameters(): + param_dict[name] = param + + # Process each LoRA parameter pair + modified_count = 0 + for key in state_dict_lora: + if ".lora_B." not in key: + continue + + # Get target parameter name and make sure the parameter exists + target_name = self._get_target_name(key) + if target_name not in param_dict: + continue + + # Get the target parameter + param = param_dict[target_name] + + # Calculate LoRA weight update + device, dtype = param.device, param.dtype + dtype_for_calc = torch.float32 if dtype == torch.float8_e4m3fn else dtype + + # Process weights and calculate LoRA update + weight_b = state_dict_lora[key].to(device=device, dtype=dtype_for_calc) + weight_a = state_dict_lora[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=dtype_for_calc) + + if len(weight_b.shape) == 4: + weight_b = weight_b.squeeze(3).squeeze(2) + weight_a = weight_a.squeeze(3).squeeze(2) + lora_weight = alpha * torch.mm(weight_b, weight_a).unsqueeze(2).unsqueeze(3) + else: + lora_weight = alpha * torch.mm(weight_b, weight_a) + + # Apply update to parameter + if dtype == torch.float8_e4m3fn: + param_float = param.to(torch.float32) + param.data = (param_float + lora_weight).to(dtype) + del param_float + else: + param.data += lora_weight + + # Clean up temporary tensors + del weight_a, weight_b, lora_weight + modified_count += 1 + + print(f" {modified_count} tensors are updated.") def match(self, model, state_dict_lora): + """Check if LoRA parameters match model parameters without loading full state dict""" for model_class in self.supported_model_classes: if not isinstance(model, model_class): continue - state_dict_model = model.state_dict() - try: - state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model) - if len(state_dict_lora_) > 0: + + # Create set of parameter names + param_names = set() + for name, _ in model.named_parameters(): + param_names.add(name) + + # Check if a sample of LoRA keys map to model parameters + matched_count = 0 + checked_count = 0 + + for key in state_dict_lora: + if ".lora_B." not in key: + continue + + target_name = self._get_target_name(key) + if target_name in param_names: + matched_count += 1 + + checked_count += 1 + if matched_count >= 5: # Found enough matches return "", "" - except: - pass + if checked_count >= 50 and matched_count == 0: # Checked enough without matches + break + + if matched_count > 0: + return "", "" + return None - + class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai): def __init__(self):