Flux fp8 lora training (#221)

* flux fp8 lora training

---------

Co-authored-by: tc2000731 <tc2000731@163.com>
Author: Zhongjie Duan
Date: 2024-09-24 11:12:32 +08:00
Committed by: GitHub
Parent: 7f899dcfca
Commit: d91c603875
2 changed files with 29 additions and 9 deletions


@@ -441,13 +441,12 @@ class FluxDiT(torch.nn.Module):
             return weight, bias
 
         class quantized_layer:
-            class Linear(torch.nn.Module):
-                def __init__(self, module):
-                    super().__init__()
-                    self.module = module
+            class Linear(torch.nn.Linear):
+                def __init__(self, *args, **kwargs):
+                    super().__init__(*args, **kwargs)
 
                 def forward(self,input,**kwargs):
-                    weight,bias= cast_bias_weight(self.module,input)
+                    weight,bias= cast_bias_weight(self,input)
                     return torch.nn.functional.linear(input,weight,bias)
 
             class RMSNorm(torch.nn.Module):
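The substantive change in this hunk: quantized_layer.Linear used to be a plain torch.nn.Module that wrapped the original layer in self.module; it is now a direct subclass of torch.nn.Linear, so the fp8 weight and bias are registered on the layer itself, and isinstance(module, torch.nn.Linear) checks and state_dict keys stay compatible with LoRA tooling. Below is a minimal self-contained sketch of the pattern, assuming PyTorch 2.1+ for torch.float8_e4m3fn; QuantLinear and this simplified cast_bias_weight are illustrative stand-ins, not the repository's exact implementation:

import torch

def cast_bias_weight(layer, x):
    # Illustrative stand-in: cast fp8 storage up to the activation's dtype for the matmul.
    weight = layer.weight.to(dtype=x.dtype, device=x.device)
    bias = None if layer.bias is None else layer.bias.to(dtype=x.dtype, device=x.device)
    return weight, bias

class QuantLinear(torch.nn.Linear):
    # Subclassing torch.nn.Linear keeps weight/bias registered on this module,
    # so parameter names match the original checkpoint layout.
    def forward(self, x, **kwargs):
        weight, bias = cast_bias_weight(self, x)
        return torch.nn.functional.linear(x, weight, bias)

layer = QuantLinear(64, 64)
layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)  # fp8 storage
layer.bias.data = layer.bias.data.to(torch.float8_e4m3fn)
out = layer(torch.randn(2, 64, dtype=torch.bfloat16))
print(out.dtype)  # torch.bfloat16: compute happens in the activation dtype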
@@ -466,7 +465,11 @@ class FluxDiT(torch.nn.Module):
         def replace_layer(model):
             for name, module in model.named_children():
                 if isinstance(module, torch.nn.Linear):
-                    new_layer = quantized_layer.Linear(module)
+                    new_layer = quantized_layer.Linear(module.in_features,module.out_features)
+                    new_layer.weight.data = module.weight.data
+                    if module.bias is not None:
+                        new_layer.bias.data = module.bias.data
+                    # del module
                     setattr(model, name, new_layer)
                 elif isinstance(module, RMSNorm):
                     new_layer = quantized_layer.RMSNorm(module)
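With the subclass taking (in_features, out_features) instead of the wrapped module, replace_layer now has to copy the trained parameters across by hand, since the constructor allocates fresh ones. A runnable sketch of the same swap-in-place recursion under those assumptions (QuantLinear and replace_linear are illustrative names; unlike the diff, the sketch passes bias= explicitly so bias-free layers stay bias-free):

import torch

class QuantLinear(torch.nn.Linear):
    def forward(self, x, **kwargs):
        # Cast storage dtype to the activation dtype at call time.
        weight = self.weight.to(x.dtype)
        bias = None if self.bias is None else self.bias.to(x.dtype)
        return torch.nn.functional.linear(x, weight, bias)

def replace_linear(model):
    # Recursively swap every torch.nn.Linear for the subclass, carrying the
    # trained parameters over since the constructor reallocates them.
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            new_layer = QuantLinear(module.in_features, module.out_features,
                                    bias=module.bias is not None)
            new_layer.weight.data = module.weight.data
            if module.bias is not None:
                new_layer.bias.data = module.bias.data
            setattr(model, name, new_layer)
        else:
            replace_linear(module)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
replace_linear(model)
print(model)  # Sequential with the Linear replaced by QuantLinear

Because the replacement layer is still an instance of torch.nn.Linear, LoRA injection that matches on layer type keeps finding these modules after quantization, which appears to be what enables fp8 LoRA training in this commit.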