mirror of
https://github.com/modelscope/DiffSynth-Studio.git
synced 2026-03-22 16:50:47 +00:00
support flux-fp8
This commit is contained in:
@@ -404,6 +404,77 @@ class FluxDiT(torch.nn.Module):
|
||||
hidden_states = self.unpatchify(hidden_states, height, width)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def quantize(self):
|
||||
def cast_to(weight, dtype=None, device=None, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight)
|
||||
return r
|
||||
|
||||
def cast_weight(s, input=None, dtype=None, device=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
return weight
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
bias = None
|
||||
weight = cast_to(s.weight, dtype, device)
|
||||
bias = cast_to(s.bias, bias_dtype, device)
|
||||
return weight, bias
|
||||
|
||||
class quantized_layer:
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,input,**kwargs):
|
||||
weight,bias= cast_bias_weight(self.module,input)
|
||||
return torch.nn.functional.linear(input,weight,bias)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, module):
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self,hidden_states,**kwargs):
|
||||
weight= cast_weight(self.module,hidden_states)
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps)
|
||||
hidden_states = hidden_states.to(input_dtype) * weight
|
||||
return hidden_states
|
||||
|
||||
def replace_layer(model):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
new_layer = quantized_layer.Linear(module)
|
||||
setattr(model, name, new_layer)
|
||||
elif isinstance(module, RMSNorm):
|
||||
new_layer = quantized_layer.RMSNorm(module)
|
||||
setattr(model, name, new_layer)
|
||||
else:
|
||||
replace_layer(module)
|
||||
|
||||
replace_layer(self)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user