Skip to content

Commit

Permalink
Fix up mm benchmark option handling
Browse files Browse the repository at this point in the history
Summary:
D56609656 inadvertently made it such that when we ran the gemm benchmark
without m/n/k arguments, it would default to using `m == n == k == 8`, instead
of the values in `BUILDIN_SHAPES`. This diff fixes that. I also removed
`Operator.tbargs`, because its superclass already defines `self.tb_args` and
having two near-identically named attributes is confusing.

Along the way, I also noticed that the addmm benchmark was not handling the
m/n/k arguments correctly either, so I fixed that too.

Reviewed By: xuzhao9

Differential Revision: D56842980

fbshipit-source-id: bf65c1a31d9db35a228593eae70e72d6d216c194
  • Loading branch information
int3 authored and facebook-github-bot committed May 1, 2024
1 parent 254c5a7 commit 2ec8267
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 16 deletions.
9 changes: 5 additions & 4 deletions torchbenchmark/operators/addmm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
DEFAULT_PRECISION = "bf16"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
if not self.extra_args:
self.shapes = BUILDIN_SHAPES
addmm_args = parse_args(self.extra_args)
if addmm_args.m and addmm_args.n and addmm_args.k:
self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
else:
self.shapes = [(self.tb_args.m, self.tbargs.k, self.tbargs.n)]
self.shapes = BUILDIN_SHAPES

@register_benchmark()
def triton_addmm(self, a, mat1, mat2) -> Callable:
Expand Down
10 changes: 5 additions & 5 deletions torchbenchmark/operators/gemm/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@

def parse_args(args: List[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="TorchBench Gemm operator Benchmark")
parser.add_argument("--m", default=8, type=int)
parser.add_argument("--k", default=8, type=int)
parser.add_argument("--n", default=8, type=int)
parser.add_argument("--bias", default=None, type=int)
parser.add_argument("--input", default=None, type=str)
parser.add_argument("--m", type=int)
parser.add_argument("--k", type=int)
parser.add_argument("--n", type=int)
parser.add_argument("--bias", type=int)
parser.add_argument("--input", type=str)
parser.add_argument("--splitk", action="store_true", default=False)
args = parser.parse_args(args)
return args
Expand Down
14 changes: 7 additions & 7 deletions torchbenchmark/operators/gemm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,16 @@ class Operator(BenchmarkOperator):
DEFAULT_METRICS = ["latency", "speedup", "accuracy"]
DEFAULT_PRECISION = "fp16"

def __init__(self, mode: str, device: str, extra_args: Optional[List[str]]=None):
def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = None):
super().__init__(mode=mode, device=device, extra_args=extra_args)
self.tbargs = parse_args(self.extra_args)
if self.tbargs.input:
self.shapes = read_shapes_from_csv(self.tbargs.input)
elif self.tbargs.splitk:
gemm_args = parse_args(self.extra_args)
if gemm_args.input:
self.shapes = read_shapes_from_csv(gemm_args.input)
elif gemm_args.splitk:
self.shapes = SPLIT_K_SHAPES
elif self.tbargs.m and self.tbargs.k and self.tbargs.n:
elif gemm_args.m and gemm_args.k and gemm_args.n:
self.shapes = [
(self.tbargs.m, self.tbargs.k, self.tbargs.n, self.tbargs.bias)
(gemm_args.m, gemm_args.k, gemm_args.n, gemm_args.bias)
]
else:
self.shapes = BUILDIN_SHAPES
Expand Down

0 comments on commit 2ec8267

Please sign in to comment.