Add inductor triton to the benchmark
Summary: As the title: register a pt2_triton_matmul benchmark for the addmm operator that compiles torch.addmm with Inductor max-autotune restricted to the Triton GEMM backend, with the ATen autotune fallback disabled.

Reviewed By: xuzhao9

Differential Revision: D58307371

fbshipit-source-id: be3fefe497524f01e3e7ecba2a33b4bf7200f642
sijiac authored and facebook-github-bot committed Jun 7, 2024
1 parent 9cd6e86 commit 3a514a5
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions torchbenchmark/operators/addmm/operator.py
@@ -5,6 +5,7 @@

 import numpy
 import torch
+import torch._inductor.config as inductor_config
 import triton
 from hammer.ops.triton.triton_hstu_linear import _addmm_fwd, triton_addmm

@@ -87,6 +88,19 @@ def triton_addmm(self, a, mat1, mat2) -> Callable:
     def aten_addmm(self, a, mat1, mat2) -> Callable:
         return lambda: torch.addmm(a, mat1, mat2)

+    @register_benchmark()
+    def pt2_triton_matmul(self, a, mat1, mat2) -> Callable:
+        torch._dynamo.reset()
+        with inductor_config.patch(
+            max_autotune=True,
+            max_autotune_gemm_backends="TRITON",
+            autotune_fallback_to_aten=False,
+        ):
+            f = lambda a, mat1, mat2: torch.addmm(a, mat1, mat2)
+            compiled = torch.compile(f, dynamic=False)
+            compiled(a, mat1, mat2)
+        return lambda: compiled(a, mat1, mat2)
+
     @register_metric()
     def gbps(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
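For readers skimming the diff, the sketch below is a minimal standalone version of the pattern the new pt2_triton_matmul benchmark exercises: compile torch.addmm with Inductor max-autotune restricted to the Triton GEMM backend, warm it up once, then reuse the compiled call. The shapes, dtype, and correctness check are illustrative assumptions, not part of this commit, and running it requires a CUDA device with Triton installed.

    import torch
    import torch._dynamo
    import torch._inductor.config as inductor_config

    # Illustrative shapes (not from the benchmark suite): out = a + mat1 @ mat2,
    # where a is a broadcast bias vector.
    a = torch.randn(1024, device="cuda", dtype=torch.float16)
    mat1 = torch.randn(2048, 4096, device="cuda", dtype=torch.float16)
    mat2 = torch.randn(4096, 1024, device="cuda", dtype=torch.float16)

    torch._dynamo.reset()  # drop any previously compiled graphs
    with inductor_config.patch(
        max_autotune=True,                    # autotune candidate kernels
        max_autotune_gemm_backends="TRITON",  # consider only Triton GEMM templates
        autotune_fallback_to_aten=False,      # fail rather than fall back to ATen
    ):
        compiled = torch.compile(lambda a, m1, m2: torch.addmm(a, m1, m2), dynamic=False)
        compiled(a, mat1, mat2)  # first call compiles and autotunes under the patched config

    out = compiled(a, mat1, mat2)  # later calls reuse the tuned Triton kernel
    torch.testing.assert_close(out, torch.addmm(a, mat1, mat2), rtol=1e-2, atol=1e-2)

Note that the benchmark compiles and warms up inside the inductor_config.patch(...) context, so the autotuning choices are fixed at compile time; the lambda it returns then measures only the steady-state compiled call. Assuming the standard TritonBench runner in this repository, an invocation along the lines of python run_benchmark.py triton --op addmm would compare pt2_triton_matmul against the existing aten_addmm and triton_addmm entries (the exact command is an assumption, not stated in this commit).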
