Mirror of https://github.com/modelscope/DiffSynth-Studio.git (synced 2026-03-22 16:50:47 +00:00)

Commit: support customized lora forward
@@ -70,6 +70,52 @@ class AutoWrappedLinear(torch.nn.Linear):
        bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
        return torch.nn.functional.linear(x, weight, bias)


class AutoLoRALinear(torch.nn.Linear):
    def __init__(self, name='', in_features=1, out_features=2, bias=True, device=None, dtype=None):
        super().__init__(in_features, out_features, bias, device, dtype)
        # Fully qualified module name, used to look up this layer's LoRA weights.
        self.name = name

    def forward(self, x, lora_state_dicts=(), lora_alphas=(1.0, 1.0), **kwargs):
        # Base projection first, then add the LoRA delta from each supplied state dict.
        out = torch.nn.functional.linear(x, self.weight, self.bias)
        lora_a_name = f'{self.name}.lora_A.weight'
        lora_b_name = f'{self.name}.lora_B.weight'

        for i, lora_state_dict in enumerate(lora_state_dicts):
            if lora_state_dict is None:
                # A None entry terminates the list of active LoRAs.
                break
            if lora_a_name in lora_state_dict and lora_b_name in lora_state_dict:
                lora_A = lora_state_dict[lora_a_name].to(dtype=self.weight.dtype, device=self.weight.device)
                lora_B = lora_state_dict[lora_b_name].to(dtype=self.weight.dtype, device=self.weight.device)
                # Low-rank update: x @ A^T @ B^T, scaled by the per-LoRA alpha.
                out_lora = x @ lora_A.T @ lora_B.T
                out = out + out_lora * lora_alphas[i]
        return out
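For each supplied LoRA, the forward pass adds the standard low-rank update, i.e. out = x @ W.T + sum_i alpha_i * (x @ A_i.T @ B_i.T). A minimal sketch of calling the layer directly is below; the module name, rank, and shapes are illustrative assumptions, not values from this commit:

import torch

# Hypothetical layer name and sizes, chosen only for illustration.
layer = AutoLoRALinear(name='blocks.0.attn.to_q', in_features=64, out_features=64, bias=True)

rank = 4
lora_state_dict = {
    # Factors in the common PEFT naming convention; lora_B @ lora_A has shape (out, in).
    'blocks.0.attn.to_q.lora_A.weight': torch.randn(rank, 64) * 0.01,  # (rank, in_features)
    'blocks.0.attn.to_q.lora_B.weight': torch.zeros(64, rank),         # (out_features, rank)
}

x = torch.randn(2, 64)
out = layer(x, lora_state_dicts=[lora_state_dict], lora_alphas=[1.0])
# With lora_B initialized to zero the LoRA branch contributes nothing,
# so the output equals the plain linear projection.
assert torch.allclose(out, torch.nn.functional.linear(x, layer.weight, layer.bias))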

def enable_auto_lora(model: torch.nn.Module, module_map: dict, name_prefix=''):
    # targets[0] is treated as an already-wrapped class to skip;
    # targets[1] is the class whose instances get replaced by AutoLoRALinear.
    targets = list(module_map.keys())
    for name, module in model.named_children():
        if name_prefix != '':
            full_name = name_prefix + '.' + name
        else:
            full_name = name
        if isinstance(module, targets[1]):
            # Replace the plain linear layer with an AutoLoRALinear that keeps
            # the same shape, dtype, device, and weights.
            new_module = AutoLoRALinear(
                name=full_name,
                in_features=module.in_features,
                out_features=module.out_features,
                bias=module.bias is not None,
                device=module.weight.device,
                dtype=module.weight.dtype)
            new_module.weight.data.copy_(module.weight.data)
            if module.bias is not None:
                new_module.bias.data.copy_(module.bias.data)
            setattr(model, name, new_module)
        elif isinstance(module, targets[0]):
            pass
        else:
            enable_auto_lora(module, module_map, full_name)
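A sketch of how enable_auto_lora might be invoked. The function indexes list(module_map.keys()), so module_map is assumed to be an insertion-ordered dict whose first key is a class to skip and whose second key is the Linear class to replace; this ordering contract is inferred from the indexing, not stated in the commit:

import torch

# Hypothetical map: targets[0] (skipped) is the file's AutoWrappedLinear,
# targets[1] (replaced) is plain torch.nn.Linear. The dict values are not
# read by enable_auto_lora itself.
module_map = {
    AutoWrappedLinear: None,
    torch.nn.Linear: AutoLoRALinear,
}

model = torch.nn.Sequential(
    torch.nn.Linear(16, 32),
    torch.nn.ReLU(),
    torch.nn.Linear(32, 16),
)
enable_auto_lora(model, module_map)
# Both Linear layers are now AutoLoRALinear instances named '0' and '2',
# so LoRA state dicts keyed '0.lora_A.weight' etc. would apply to them.
print(model)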

def enable_vram_management_recursively(model: torch.nn.Module, module_map: dict, module_config: dict, max_num_param=None, overflow_module_config: dict = None, total_num_param=0):
    for name, module in model.named_children():