Skip to content

Commit

Permalink
Log best_config for all triton kernels
Browse files Browse the repository at this point in the history
Reviewed By: chenyang78

Differential Revision: D57196022

fbshipit-source-id: 16dcb60da7160df224a916c6f84692b6c84948a3
  • Loading branch information
int3 authored and facebook-github-bot committed May 10, 2024
1 parent 8a8c1fc commit 3f8bb6a
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,20 @@ def best_config(
self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
) -> float:
if "triton_tutorial_matmul" in str(fn_name):
bconfig = triton_matmul_kernel.best_config
kwargs = deepcopy(bconfig.kwargs)
kwargs["num_stages"] = bconfig.num_stages
kwargs["num_warps"] = bconfig.num_warps
dumped_str = json.dumps(kwargs)
return dumped_str
return ""
kernel = triton_matmul_kernel
elif "triton_ops_matmul" in str(fn_name):
kernel = triton.ops._matmul.kernel
elif "hstu_triton_matmul" in str(fn_name):
import hammer
kernel = hammer.ops.triton.triton_matmul._epilogue_mm
else:
return ""

bconfig = kernel.best_config
kwargs = deepcopy(bconfig.kwargs)
kwargs["num_stages"] = bconfig.num_stages
kwargs["num_warps"] = bconfig.num_warps
return json.dumps(kwargs)

@register_metric()
def tflops(
Expand Down

0 comments on commit 3f8bb6a

Please sign in to comment.