HSTU addmm shapes
Reviewed By: xuzhao9, chenyang78

Differential Revision: D58097783

fbshipit-source-id: 3affeb633bc3d6206c22c9f9f4a6a77568f2f7cc
sijiac authored and facebook-github-bot committed Jun 7, 2024
1 parent 9f2ab74 commit 1a986c5
Showing 2 changed files with 72 additions and 16 deletions.
torchbenchmark/operators/addmm/data_io.py (10 changes: 5 additions & 5 deletions)
@@ -1,16 +1,16 @@
 import argparse
-import os
 import csv
+import os
 from typing import List

 from torchbenchmark import REPO_PATH


 def parse_args(args: List[str]) -> argparse.Namespace:
     parser = argparse.ArgumentParser(description="TorchBench Addmm operator Benchmark")
-    parser.add_argument("--m", default=8, type=int)
-    parser.add_argument("--k", default=8, type=int)
-    parser.add_argument("--n", default=8, type=int)
-    parser.add_argument("--input", default=None, type=str)
+    parser.add_argument("--m", type=int)
+    parser.add_argument("--k", type=int)
+    parser.add_argument("--n", type=int)
+    parser.add_argument("--input", type=str)
     args = parser.parse_args(args)
     return args
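
A note on the data_io.py change above: with the defaults removed, any dimension not passed on the command line now parses as None instead of 8, which lets the caller tell "no shape requested" apart from a real 8x8x8 problem. A minimal sketch of how the parsed args might be consumed; the fallback logic is an assumption, since the operator's __init__ is collapsed in this diff:

```python
from typing import List, Tuple

def resolve_shapes(args, builtin_shapes: List[Tuple[int, int, int]]) -> List[Tuple[int, int, int]]:
    # Hypothetical helper, not part of the commit: use an explicit shape only if
    # all three dimensions were given, otherwise fall back to the built-in list.
    if args.m is not None and args.k is not None and args.n is not None:
        return [(args.m, args.k, args.n)]
    return builtin_shapes
```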
torchbenchmark/operators/addmm/operator.py (78 changes: 67 additions & 11 deletions)
@@ -1,28 +1,73 @@
 import csv
 import os
 import statistics
-from typing import Any, Callable, Generator, List, Optional
+from typing import Any, Callable, Generator, List, Optional, Tuple

 import numpy
 import torch
 import triton
-from hammer.ops.triton.triton_hstu_linear import triton_addmm
+from hammer.ops.triton.triton_hstu_linear import _addmm_fwd, triton_addmm

 from torchbenchmark.util.triton_op import (
     BenchmarkOperator,
     BenchmarkOperatorMetrics,
+    dump_autotuner_best_config,
     register_benchmark,
     register_metric,
+    register_x_val,
 )

 from .data_io import parse_args


-BUILDIN_SHAPES = [(M * 128, 512, N) for N in [1536, 512] for M in range(4, 17)]
+BUILDIN_SHAPES = [
+    (20120, 1536, 512),
+    (34579, 1536, 512),
+    (34839, 1536, 512),
+    (35561, 1536, 512),
+    (35916, 1536, 512),
+    (19735, 1536, 512),
+    (34533, 1536, 512),
+    (35791, 1536, 512),
+    (35844, 1536, 512),
+    (20116, 1536, 512),
+    (33887, 1536, 512),
+    (20203, 1536, 512),
+    (33961, 1536, 512),
+    (19747, 1536, 512),
+    (34181, 1536, 512),
+    (35541, 1536, 512),
+    (36032, 1536, 512),
+    (15168, 1536, 512),
+    (35249, 1536, 512),
+    (33894, 1536, 512),
+    (20067, 1536, 512),
+    (27456, 1536, 512),
+    (19410, 1536, 512),
+    (35884, 1536, 512),
+    (35917, 1536, 512),
+    (19632, 1536, 512),
+    (35656, 1536, 512),
+    (35405, 1536, 512),
+    (35503, 1536, 512),
+    (35504, 1536, 512),
+    (35605, 1536, 512),
+    (34238, 1536, 512),
+    (33660, 1536, 512),
+    (35410, 1536, 512),
+    (20211, 1536, 512),
+    (34308, 1536, 512),
+    (34516, 1536, 512),
+    (20224, 1536, 512),
+    (35678, 1536, 512),
+    (35380, 1536, 512),
+    (35901, 1536, 512),
+    (20068, 1536, 512),
+]


 class Operator(BenchmarkOperator):
-    DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
+    DEFAULT_METRICS = ["tflops"]
     DEFAULT_PRECISION = "bf16"

     def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
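
For orientation, each (M, K, N) entry in the new BUILDIN_SHAPES describes one addmm problem taken from HSTU workloads. A small sketch of how the three sizes map onto torch.addmm operands in the benchmark's bf16 precision; the exact tensor construction is an assumption, since get_input_iter is only partially shown at the end of this diff:

```python
import torch

m, k, n = 20120, 1536, 512  # first entry of BUILDIN_SHAPES

# torch.addmm computes out = a + mat1 @ mat2.
a = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)     # bias/input, broadcastable to (M, N)
mat1 = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)  # (M, K)
mat2 = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)  # (K, N)

out = torch.addmm(a, mat1, mat2)  # result shape (M, N)
```

A 1-D bias of shape (N,) would also be accepted by addmm; a full (M, N) tensor is used here only to keep the shapes explicit.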
@@ -41,12 +86,6 @@ def triton_addmm(self, a, mat1, mat2) -> Callable:
     def aten_addmm(self, a, mat1, mat2) -> Callable:
         return lambda: torch.addmm(a, mat1, mat2)

-    def get_x_val(self, example_inputs) -> float:
-        _, mat1, mat2 = example_inputs
-        m, k = mat1.size()
-        k, n = mat2.size()
-        return f"{m}-{k}-{n}"
-
     @register_metric()
     def gbps(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
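
The get_x_val removed in this hunk is not dropped for good; the next hunk re-adds it under the @register_x_val decorator, returning a structured (M, N, K) tuple instead of an ad-hoc string key. Side by side, for the first built-in shape:

```python
m, k, n = 20120, 1536, 512

x_old = f"{m}-{k}-{n}"  # previous x value: an opaque string, "20120-1536-512"
x_new = (m, n, k)       # new x value: a tuple labeled "(M, N, K)", i.e. (20120, 512, 1536)
```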
@@ -65,13 +104,30 @@ def gbps(
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> float:
+    ) -> List[float]:
         _, mat1, mat2 = example_inputs
         m, k = mat1.size()
         k, n = mat2.size()
         flops = m * k * 2 * n
         return [flops / x / 1e12 * 1e3 for x in metrics.latency]

+    @register_x_val(label="(M, N, K)")
+    def get_x_val(self, example_inputs) -> Tuple[int, int, int]:
+        # x-value: computation intensity
+        a, mat1, mat2 = example_inputs
+        m, k = mat1.size()
+        k, n = mat2.size()
+        return (m, n, k)
+
+    @register_metric(skip_baseline=True)
+    def best_config(
+        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
+    ) -> str:
+        if "triton_addmm" in str(fn_name):
+            return dump_autotuner_best_config(_addmm_fwd)
+        else:
+            return ""

     def get_input_iter(self) -> Generator:
         for shape in self.shapes:
             m, k, n = shape
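
A quick sanity check on the tflops metric above: flops = m * k * 2 * n counts the multiply-adds of the (M, K) x (K, N) matmul, and the * 1e3 in the conversion suggests metrics.latency is recorded in milliseconds. Worked through for the first built-in shape, with an illustrative (not measured) latency:

```python
m, k, n = 20120, 1536, 512

flops = m * k * 2 * n       # 31,646,023,680, i.e. about 3.16e10 FLOPs per addmm
latency_ms = 1.0            # made-up value purely for illustration

tflops = flops / latency_ms / 1e12 * 1e3
print(round(tflops, 1))     # ~31.6 TFLOP/s if one call took 1 ms
```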
