Skip to content

Commit

Permalink
Add triton.ops.matmul
Browse files Browse the repository at this point in the history
Summary:
This is the more "official" version in the repo, versus the tutorial
one that's checked in.  It's worth having both, though (since if the tutorial
one regresses, it should be fixed).

Reviewed By: chenyang78, sijiac

Differential Revision: D56158787

fbshipit-source-id: f187085f95d8a35f6d558a043d183a1912375023
  • Loading branch information
bertmaher authored and facebook-github-bot committed Apr 19, 2024
1 parent 7795b06 commit 1d39d82
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy
import torch
import triton
import triton.ops

from torchbenchmark.util.triton_op import (
BenchmarkOperator,
Expand Down Expand Up @@ -97,12 +98,18 @@ def __init__(self, mode: str, device: str, extra_args: List[str] = []):
self.DEFAULT_NUM_BATCH = len(self.shapes)

@register_benchmark()
def triton_matmul(self, a, b, bias) -> Callable:
def triton_tutorial_matmul(self, a, b, bias) -> Callable:
if not bias == None:
return lambda: triton_matmul(a, b) + bias
else:
return lambda: triton_matmul(a, b)

@register_benchmark(enabled=torch.version.cuda is not None)
def triton_ops_matmul(self, a, b, bias) -> Callable:
if bias is None:
return lambda: triton.ops.matmul(a, b)
return lambda: triton.ops.matmul(a, b, bias)

@register_benchmark(baseline=True)
def aten_matmul(self, a, b, bias) -> Callable:
if not bias == None:
Expand Down Expand Up @@ -194,15 +201,17 @@ def plot(self):
line_arg="provider", # argument name whose value corresponds to a different line in the plot
line_vals=[
"aten_matmul",
"triton_matmul",
"triton_tutorial_matmul",
"triton_ops_matmul",
"hstu_triton_matmul",
], # possible values for `line_arg``
line_names=[
"ATen GEMM",
"Triton GEMM",
"Triton Tutorial GEMM",
"triton.ops.matmul",
"HSTU Triton GEMM",
], # label name for the lines
styles=[("blue", "-"), ("green", "-"), ("red", "-")], # line styles
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("yellow", "-")], # line styles
ylabel="tflops", # label name for the y-axis
plot_name="gemm-performance", # name for the plot. Used also as a file name for saving the plot.
args={}, # values for function arguments not in `x_names` and `y_name`
Expand Down

0 comments on commit 1d39d82

Please sign in to comment.