Skip to content

Commit

Permalink
Scale GEMM inputs for better accuracy results
Browse files Browse the repository at this point in the history
Summary:
This provides more numerically stable inputs for GEMMs. The +1
eliminates very small values that could result in denormals, and the
scale (which should be set to K in an M*N*K GEMM) means that the
elements of the output tensor are close to integer values. bwasti came
up with this.

Reviewed By: bertmaher

Differential Revision: D57084742

fbshipit-source-id: 13ea8b72f0a502db0e37bd5a92fc12fd3e7688d6
  • Loading branch information
int3 authored and facebook-github-bot committed May 10, 2024
1 parent 06eb98c commit 19b0bf6
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,20 +49,12 @@
(1408, 1408, 1408, None),
(1536, 1536, 1536, None),
(1664, 1664, 1664, None),
# FIXME: triton_matmul failed with accuracy check for fp16 inputs on A100:
# Mismatched elements: 882 / 3211264 (0.0%)
# Greatest absolute difference: 0.03125 at index (169, 218) (up to 0.01 allowed)
# Greatest relative difference: 35.21875 at index (1169, 1720) (up to 0.01 allowed)
# (1792, 1792, 1792, None),
(1792, 1792, 1792, None),
(1920, 1920, 1920, None),
(2048, 2048, 2048, None),
(2176, 2176, 2176, None),
(2304, 2304, 2304, None),
# FIXME: triton_matmul failed with accuracy check for fp16 inputs on A100:
# Mismatched elements: 2479 / 5914624 (0.0%)
# Greatest absolute difference: 0.03173828125 at index (171, 1067) (up to 0.01 allowed)
# Greatest relative difference: 95.875 at index (2423, 2312) (up to 0.01 allowed)
# (2432, 2432, 2432, None),
(2432, 2432, 2432, None),
(2560, 2560, 2560, None),
(2688, 2688, 2688, None),
(2816, 2816, 2816, None),
Expand All @@ -72,11 +64,7 @@
(3328, 3328, 3328, None),
(3456, 3456, 3456, None),
(3584, 3584, 3584, None),
# FIXME: triton_matmul failed with accuracy check for fp16 inputs on A100:
# Mismatched elements: 619 / 13778944 (0.0%)
# Greatest absolute difference: 0.06005859375 at index (622, 69) (up to 0.02 allowed)
# Greatest relative difference: 20.546875 at index (3609, 685) (up to 0.02 allowed)
# (3712, 3712, 3712, None),
(3712, 3712, 3712, None),
(3840, 3840, 3840, None),
(3968, 3968, 3968, None),
(4096, 4096, 4096, None),
Expand Down Expand Up @@ -193,15 +181,26 @@ def tflops(
flops = m * k * 2 * n
return [flops / x / 1e12 * 1e3 for x in metrics.latency]

@staticmethod
def _scaled_randn(*args, scale: float, **kwargs) -> torch.Tensor:
    """Return ``(torch.randn(...) + 1) / scale`` as a GEMM benchmark input.

    Every positional argument and every keyword argument other than
    ``scale`` is forwarded unchanged to ``torch.randn``.

    Why the inputs are shaped this way:

    * The ``+ 1`` shift is meant to get rid of very small values that
      could end up as denormals.
    * Dividing by ``scale`` (callers should pass K for an M*N*K GEMM) is
      meant to shrink the absolute error. For one output element the
      accumulated error is eps * 2 * K, with eps the smallest precision
      the dtype can represent; scaling by K keeps that error from growing
      with the tensor size.

    Args:
        *args: Positional arguments for ``torch.randn`` (e.g. the shape).
        scale: Keyword-only divisor applied after the shift.
        **kwargs: Keyword arguments for ``torch.randn`` (e.g. ``device``,
            ``dtype``).

    Returns:
        The shifted and scaled random tensor.
    """
    shifted = torch.randn(*args, **kwargs) + 1
    return shifted / scale

def get_input_iter(self) -> Generator:
for shape in self.shapes:
m, k, n, bias = shape
a = torch.randn(
(m, k), device=self.device, dtype=self.dtype
).requires_grad_(False)
w = torch.randn(
(k, n), device=self.device, dtype=self.dtype
).requires_grad_(False)
a = self._scaled_randn((m, k), scale=k, device=self.device, dtype=self.dtype)
w = self._scaled_randn((k, n), scale=k, device=self.device, dtype=self.dtype)
if not bias == None:
bias = torch.randn(
(bias), device=self.device, dtype=self.dtype
Expand Down

0 comments on commit 19b0bf6

Please sign in to comment.