Accelerate model loading

This commit is contained in:
tc2000731
2024-10-18 15:29:50 +08:00
parent 7d7d72dcfe
commit dfbf43e463
3 changed files with 58 additions and 6 deletions

View File

@@ -2,6 +2,7 @@ import torch
from .sd3_dit import TimestepEmbeddings, AdaLayerNorm
from einops import rearrange
from .tiler import TileWorker
from .utils import init_weights_on_device
@@ -466,10 +467,11 @@ class FluxDiT(torch.nn.Module):
def replace_layer(model):
for name, module in model.named_children():
if isinstance(module, torch.nn.Linear):
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
new_layer.weight.data = module.weight.data
with init_weights_on_device():
new_layer = quantized_layer.Linear(module.in_features,module.out_features)
new_layer.weight = module.weight
if module.bias is not None:
new_layer.bias.data = module.bias.data
new_layer.bias = module.bias
# del module
setattr(model, name, new_layer)
elif isinstance(module, RMSNorm):