diff --git a/diffsynth/vram_management/layers.py b/diffsynth/vram_management/layers.py
index 0ebb054..fdd39c7 100644
--- a/diffsynth/vram_management/layers.py
+++ b/diffsynth/vram_management/layers.py
@@ -110,8 +110,47 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.lora_A_weights = []
         self.lora_B_weights = []
         self.lora_merger = None
+        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+
+    def fp8_linear(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        device = input.device
+        origin_dtype = input.dtype
+        origin_shape = input.shape
+        input = input.reshape(-1, origin_shape[-1])
+
+        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+        fp8_max = 448.0
+        # float8_e4m3fnuz has a smaller maximum representable value (240) than
+        # float8_e4m3fn (448). To avoid overflow during the FP8 computation, we
+        # conservatively halve the FP8 ceiling, which scales the input down further
+        # before quantization; the extra scaling is absorbed into scale_a and
+        # compensated when the result is rescaled.
+        if self.computation_dtype == torch.float8_e4m3fnuz:
+            fp8_max = fp8_max / 2.0
+        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
+        input = input / (scale_a + 1e-8)
+        input = input.to(self.computation_dtype)
+        weight = weight.to(self.computation_dtype)
+
+        result = torch._scaled_mm(
+            input,
+            weight.T,
+            scale_a=scale_a,
+            scale_b=scale_b.T,
+            bias=bias,
+            out_dtype=origin_dtype,
+        )
+        new_shape = origin_shape[:-1] + result.shape[-1:]
+        result = result.reshape(new_shape)
+        return result

     def forward(self, x, *args, **kwargs):
+        # VRAM management
         if self.state == 2:
             weight, bias = self.weight, self.bias
         else:
@@ -123,8 +162,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
             else:
                 weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
                 bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        out = torch.nn.functional.linear(x, weight, bias)
+        # Linear forward
+        if self.enable_fp8:
+            out = self.fp8_linear(x, weight, bias)
+        else:
+            out = torch.nn.functional.linear(x, weight, bias)
+
+        # LoRA
         if len(self.lora_A_weights) == 0:
             # No LoRA
             return out
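
For reference, here is a minimal, self-contained sketch of the dynamic per-row scaling that the new `fp8_linear` method performs, stripped of the VRAM-management wrapper. It assumes a recent PyTorch build with row-wise `torch._scaled_mm` support and an FP8-capable GPU; the helper name `fp8_linear_sketch` and its defaults are illustrative only and are not part of the diffsynth API.

```python
import torch


def fp8_linear_sketch(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None = None,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """Illustrative FP8 linear layer with per-row dynamic activation scaling."""
    origin_dtype, origin_shape = x.dtype, x.shape
    x2d = x.reshape(-1, origin_shape[-1])

    # Per-row absolute maximum -> a dynamic scale that keeps each activation row
    # inside the representable FP8 range before quantization.
    fp8_max = float(torch.finfo(fp8_dtype).max)
    x_max = x2d.abs().amax(dim=-1, keepdim=True)
    scale_a = torch.clamp(x_max / fp8_max, min=1.0).float()  # shape (M, 1)
    scale_b = torch.ones((1, weight.shape[0]), device=x.device, dtype=torch.float32)  # (1, N), weights unscaled

    x_fp8 = (x2d / scale_a).to(fp8_dtype)
    w_fp8 = weight.to(fp8_dtype)

    # weight.T of a contiguous (N, K) weight is column-major, as _scaled_mm expects
    # for its second operand; scale_a / scale_b undo the quantization scaling.
    out = torch._scaled_mm(
        x_fp8, w_fp8.T, scale_a=scale_a, scale_b=scale_b, out_dtype=origin_dtype
    )
    if bias is not None:
        out = out + bias.to(origin_dtype)  # bias applied in high precision in this sketch
    return out.reshape(origin_shape[:-1] + out.shape[-1:])
```

The sketch only rescales the activations (per token), leaves the weights at unit scale, and adds the bias in the original dtype after the matmul, which sidesteps `_scaled_mm` bias-dtype constraints; the diff above instead forwards `bias` directly and folds the e4m3fnuz headroom into `scale_a`.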