mirror of https://github.com/modelscope/DiffSynth-Studio.git, synced 2026-03-19 06:48:12 +00:00

Compare commits: vram-bugfi...qwen-image (1 commit, 2abc97fc0f)
```diff
@@ -110,8 +110,47 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
         self.lora_A_weights = []
         self.lora_B_weights = []
         self.lora_merger = None
+        self.enable_fp8 = computation_dtype in [torch.float8_e4m3fn, torch.float8_e4m3fnuz]
+
+    def fp8_linear(
+        self,
+        input: torch.Tensor,
+        weight: torch.Tensor,
+        bias: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        device = input.device
+        origin_dtype = input.dtype
+        origin_shape = input.shape
+        input = input.reshape(-1, origin_shape[-1])
+
+        x_max = torch.max(torch.abs(input), dim=-1, keepdim=True).values
+        fp8_max = 448.0
+        # For float8_e4m3fnuz, the maximum representable value is half of that of e4m3fn.
+        # To avoid overflow and ensure numerical compatibility during FP8 computation,
+        # we scale down the input by 2.0 in advance.
+        # This scaling will be compensated later during the final result scaling.
+        if self.computation_dtype == torch.float8_e4m3fnuz:
+            fp8_max = fp8_max / 2.0
+        scale_a = torch.clamp(x_max / fp8_max, min=1.0).float().to(device=device)
+        scale_b = torch.ones((weight.shape[0], 1)).to(device=device)
+        input = input / (scale_a + 1e-8)
+        input = input.to(self.computation_dtype)
+        weight = weight.to(self.computation_dtype)
+
+        result = torch._scaled_mm(
+            input,
+            weight.T,
+            scale_a=scale_a,
+            scale_b=scale_b.T,
+            bias=bias,
+            out_dtype=origin_dtype,
+        )
+        new_shape = origin_shape[:-1] + result.shape[-1:]
+        result = result.reshape(new_shape)
+        return result
+
     def forward(self, x, *args, **kwargs):
         # VRAM management
         if self.state == 2:
             weight, bias = self.weight, self.bias
         else:
```
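The new `fp8_linear` quantizes dynamically: the input is flattened to 2D, each row is divided by its own scale so that its largest magnitude fits inside the FP8 e4m3 range (448.0, halved for `torch.float8_e4m3fnuz`), and the same per-row scales are passed to `torch._scaled_mm`, which folds them back into the output so the result matches an unscaled matmul up to FP8 rounding. Below is a minimal standalone sketch of that scale/quantize/dequantize round trip, assuming a recent PyTorch with FP8 dtypes; the helper name `quantize_rowwise_fp8` and the test tensor are illustrative and not part of this commit.

```python
import torch

def quantize_rowwise_fp8(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fn):
    # Illustrative helper (not part of the patch): dynamic per-row FP8 scaling.
    fp8_max = 448.0
    if fp8_dtype == torch.float8_e4m3fnuz:
        fp8_max = fp8_max / 2.0  # conservative bound, mirroring the patch
    x_max = x.abs().amax(dim=-1, keepdim=True)
    # scale >= 1.0, so rows that already fit into the FP8 range are left untouched.
    scale = torch.clamp(x_max / fp8_max, min=1.0).float()
    x_fp8 = (x / (scale + 1e-8)).to(fp8_dtype)
    return x_fp8, scale

x = torch.randn(4, 64, dtype=torch.bfloat16) * 100.0
x_fp8, scale = quantize_rowwise_fp8(x)
# Dequantizing recovers the input up to FP8 rounding error; in fp8_linear the
# rescaling by `scale` happens inside torch._scaled_mm instead.
x_rec = x_fp8.to(torch.bfloat16) * scale.to(torch.bfloat16)
print((x - x_rec).abs().max() / x.abs().max())
```

Clamping the scale at 1.0 means rows whose values already fit the FP8 range are cast unscaled; only rows whose maxima exceed the representable range are compressed.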
```diff
@@ -123,8 +162,14 @@ class AutoWrappedLinear(torch.nn.Linear, AutoTorchModule):
             else:
                 weight = cast_to(self.weight, self.computation_dtype, self.computation_device)
                 bias = None if self.bias is None else cast_to(self.bias, self.computation_dtype, self.computation_device)
-        out = torch.nn.functional.linear(x, weight, bias)
+
+        # Linear forward
+        if self.enable_fp8:
+            out = self.fp8_linear(x, weight, bias)
+        else:
+            out = torch.nn.functional.linear(x, weight, bias)
+
         # LoRA
         if len(self.lora_A_weights) == 0:
             # No LoRA
             return out
```
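In `forward`, the FP8 path is taken only when `enable_fp8` was set in `__init__`, i.e. when the computation dtype is `torch.float8_e4m3fn` or `torch.float8_e4m3fnuz`; any other dtype keeps going through `torch.nn.functional.linear`, and the LoRA handling afterwards is untouched. Below is a hedged end-to-end sketch comparing the same scaled FP8 matmul against a bfloat16 reference; it assumes a CUDA GPU with `torch._scaled_mm` support (compute capability 8.9+, and per-row scales may additionally require a recent PyTorch build), and the shapes are made up for illustration.

```python
import torch
import torch.nn.functional as F

# Hedged sketch (not part of the patch): per-row scaled FP8 matmul vs. a
# bfloat16 reference, using the e4m3fn bound of 448.0 as in fp8_linear.
if torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9):
    x = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda")
    w = torch.randn(2048, 4096, dtype=torch.bfloat16, device="cuda")

    ref = F.linear(x, w)  # reference result in bfloat16

    # Per-row scale for the input, unit scale for the weight (as in the patch).
    scale_a = torch.clamp(x.abs().amax(dim=-1, keepdim=True) / 448.0, min=1.0).float()
    scale_b = torch.ones((1, w.shape[0]), device="cuda")
    out = torch._scaled_mm(
        (x / (scale_a + 1e-8)).to(torch.float8_e4m3fn),
        w.to(torch.float8_e4m3fn).T,  # column-major view, as _scaled_mm expects
        scale_a=scale_a,
        scale_b=scale_b,
        bias=None,
        out_dtype=torch.bfloat16,
    )
    # The relative error should be small (FP8 quantization noise).
    print(((out - ref).abs().mean() / ref.abs().mean()).item())
```

The hardware guard here is only for the sketch; in the patch itself the choice of path is driven purely by `computation_dtype`.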