Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-19 06:32:27 +00:00)
Fixes high RAM usage Wan 2.1

Applies LoRA weights directly to model parameters instead of converting the LoRA into a full state-dict update and reloading the model, reducing the peak system RAM needed when loading a Wan 2.1 LoRA.
@@ -200,81 +200,139 @@ class GeneralLoRAFromPeft:
    def __init__(self):
        self.supported_model_classes = [SDUNet, SDXLUNet, SD3DiT, HunyuanDiT, FluxDiT, CogDiT, WanModel]

    def fetch_device_dtype_from_state_dict(self, target_param):
        """Get device and dtype from a parameter"""
        return target_param.device, target_param.dtype

    def fetch_device_dtype_from_state_dict(self, state_dict):
        device, torch_dtype = None, None
        for name, param in state_dict.items():
            device, torch_dtype = param.device, param.dtype
            break
        return device, torch_dtype

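    # The two fetch_device_dtype_from_state_dict variants above do the same job at different
    # granularities: one reads device/dtype straight off a single parameter tensor, the other
    # peeks at the first entry of a whole state dict (an illustrative, hypothetical call would be
    # self.fetch_device_dtype_from_state_dict(model.state_dict())).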
    def _get_target_name(self, key):
        """Extract target parameter name from LoRA key"""
        keys = key.split(".")
        if len(keys) > keys.index("lora_B") + 2:
            keys.pop(keys.index("lora_B") + 1)
        keys.pop(keys.index("lora_B"))
        target_name = ".".join(keys)
        if target_name.startswith("diffusion_model."):
            target_name = target_name[len("diffusion_model."):]
        return target_name
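
    # Worked example for _get_target_name (the key below is hypothetical, for illustration only):
    # a PEFT-style key such as "diffusion_model.blocks.0.self_attn.q.lora_B.default.weight"
    # drops the adapter segment after "lora_B" ("default") and the "lora_B" segment itself,
    # then strips the "diffusion_model." prefix, yielding "blocks.0.self_attn.q.weight".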

    def convert_state_dict(self, state_dict, alpha=1.0, target_state_dict={}):
        device, torch_dtype = self.fetch_device_dtype_from_state_dict(target_state_dict)
        """Original method kept for compatibility with match method"""
        device, torch_dtype = None, None
        for name, param in target_state_dict.items():
            device, torch_dtype = param.device, param.dtype
            break

        if torch_dtype == torch.float8_e4m3fn:
            torch_dtype = torch.float32

        state_dict_ = {}
        for key in state_dict:
            if ".lora_B." not in key:
                continue

            weight_up = state_dict[key].to(device=device, dtype=torch_dtype)
            weight_down = state_dict[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=torch_dtype)

            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze(3).squeeze(2)
                weight_down = weight_down.squeeze(3).squeeze(2)
                lora_weight = alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                lora_weight = alpha * torch.mm(weight_up, weight_down)
            keys = key.split(".")
            if len(keys) > keys.index("lora_B") + 2:
                keys.pop(keys.index("lora_B") + 1)
            keys.pop(keys.index("lora_B"))
            target_name = ".".join(keys)
            if target_name.startswith("diffusion_model."):
                target_name = target_name[len("diffusion_model."):]

            target_name = self._get_target_name(key)

            if target_name not in target_state_dict:
                return {}

            state_dict_[target_name] = lora_weight.cpu()

        return state_dict_
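
    # Note on convert_state_dict above: it still materializes one full-size delta tensor per
    # target on the CPU (state_dict_[target_name] = lora_weight.cpu()). The rewritten load path
    # further below applies each update to the parameter in place instead, which appears to be
    # the RAM saving this commit targets.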

    def load(self, model, state_dict_lora, lora_prefix="", alpha=1.0, model_resource=""):
        state_dict_model = model.state_dict()
        state_dict_lora = self.convert_state_dict(state_dict_lora, alpha=alpha, target_state_dict=state_dict_model)
        if len(state_dict_lora) > 0:
            print(f" {len(state_dict_lora)} tensors are updated.")
        for name in state_dict_lora:
            if state_dict_model[name].dtype == torch.float8_e4m3fn:
                weight = state_dict_model[name].to(torch.float32)
                lora_weight = state_dict_lora[name].to(
                    dtype=torch.float32,
                    device=state_dict_model[name].device
                )
                state_dict_model[name] = (weight + lora_weight).to(
                    dtype=state_dict_model[name].dtype,
                    device=state_dict_model[name].device
                )
            else:
                state_dict_model[name] += state_dict_lora[name].to(
                    dtype=state_dict_model[name].dtype,
                    device=state_dict_model[name].device
                )
        model.load_state_dict(state_dict_model)

"""Apply LoRA weights directly to model parameters without loading entire state dict"""
|
||||
# Create parameter name mapping for faster lookup
|
||||
param_dict = {}
|
||||
for name, param in model.named_parameters():
|
||||
param_dict[name] = param
|
||||
|
||||
# Process each LoRA parameter pair
|
||||
modified_count = 0
|
||||
for key in state_dict_lora:
|
||||
if ".lora_B." not in key:
|
||||
continue
|
||||
|
||||
# Get target parameter name and make sure the parameter exists
|
||||
target_name = self._get_target_name(key)
|
||||
if target_name not in param_dict:
|
||||
continue
|
||||
|
||||
# Get the target parameter
|
||||
param = param_dict[target_name]
|
||||
|
||||
# Calculate LoRA weight update
|
||||
device, dtype = param.device, param.dtype
|
||||
dtype_for_calc = torch.float32 if dtype == torch.float8_e4m3fn else dtype
|
||||
|
||||
# Process weights and calculate LoRA update
|
||||
weight_b = state_dict_lora[key].to(device=device, dtype=dtype_for_calc)
|
||||
weight_a = state_dict_lora[key.replace(".lora_B.", ".lora_A.")].to(device=device, dtype=dtype_for_calc)
|
||||
|
||||
if len(weight_b.shape) == 4:
|
||||
weight_b = weight_b.squeeze(3).squeeze(2)
|
||||
weight_a = weight_a.squeeze(3).squeeze(2)
|
||||
lora_weight = alpha * torch.mm(weight_b, weight_a).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
lora_weight = alpha * torch.mm(weight_b, weight_a)
|
||||
|
||||
# Apply update to parameter
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
param_float = param.to(torch.float32)
|
||||
param.data = (param_float + lora_weight).to(dtype)
|
||||
del param_float
|
||||
else:
|
||||
param.data += lora_weight
|
||||
|
||||
# Clean up temporary tensors
|
||||
del weight_a, weight_b, lora_weight
|
||||
modified_count += 1
|
||||
|
||||
print(f" {modified_count} tensors are updated.")
|
||||
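
    # Shape note for the LoRA update computed in load above (assuming the usual PEFT layout): for
    # a Linear weight of shape [out_features, in_features], lora_B is [out_features, rank] and
    # lora_A is [rank, in_features], so torch.mm(weight_b, weight_a) has the weight's own shape
    # and is scaled by alpha; 4-D conv weights with 1x1 kernels are squeezed to 2-D for the
    # matmul and unsqueezed back afterwards.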

    def match(self, model, state_dict_lora):
        """Check if LoRA parameters match model parameters without loading full state dict"""
        for model_class in self.supported_model_classes:
            if not isinstance(model, model_class):
                continue
            state_dict_model = model.state_dict()
            try:
                state_dict_lora_ = self.convert_state_dict(state_dict_lora, alpha=1.0, target_state_dict=state_dict_model)
                if len(state_dict_lora_) > 0:

            # Create set of parameter names
            param_names = set()
            for name, _ in model.named_parameters():
                param_names.add(name)

            # Check if a sample of LoRA keys map to model parameters
            matched_count = 0
            checked_count = 0

            for key in state_dict_lora:
                if ".lora_B." not in key:
                    continue

                target_name = self._get_target_name(key)
                if target_name in param_names:
                    matched_count += 1

                checked_count += 1
                if matched_count >= 5:  # Found enough matches
                    return "", ""
            except:
                pass
                if checked_count >= 50 and matched_count == 0:  # Checked enough without matches
                    break

            if matched_count > 0:
                return "", ""

        return None
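
    # Design note on match above: instead of converting the whole LoRA inside try/except as the
    # earlier code path did, the sampled check accepts the LoRA once 5 keys resolve to model
    # parameter names and gives up after 50 unmatched keys, so matching never has to build the
    # full set of delta tensors.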
class HunyuanVideoLoRAFromCivitai(LoRAFromCivitai):
    def __init__(self):
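
For orientation, a minimal usage sketch of the loader shown in this diff (the file path, model variable, and calling pipeline are hypothetical; only the GeneralLoRAFromPeft API itself comes from the code above):

from safetensors.torch import load_file

state_dict_lora = load_file("wan_lora.safetensors")   # PEFT-style LoRA weights (hypothetical path)
loader = GeneralLoRAFromPeft()
if loader.match(model, state_dict_lora) is not None:  # model: an already-constructed WanModel; match returns ("", "") on success, None otherwise
    loader.load(model, state_dict_lora, alpha=1.0)    # updates matched parameters in place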