Skip to content

Commit

Permalink
Col-major for mat2
Browse files Browse the repository at this point in the history
Summary: as the title

Reviewed By: xuzhao9

Differential Revision: D58097785

fbshipit-source-id: 056e78bf04388416302ea0e6835d353fb3ff469e
  • Loading branch information
sijiac authored and facebook-github-bot committed Jun 7, 2024
1 parent 1a986c5 commit 9cd6e86
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions torchbenchmark/operators/addmm/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ def parse_args(args: List[str]) -> argparse.Namespace:
parser.add_argument("--k", type=int)
parser.add_argument("--n", type=int)
parser.add_argument("--input", type=str)
parser.add_argument("--col-major", type=bool, default=False)
args = parser.parse_args(args)
return args
3 changes: 3 additions & 0 deletions torchbenchmark/operators/addmm/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self, mode: str, device: str, extra_args: Optional[List[str]] = Non
self.shapes = [(addmm_args.m, addmm_args.k, addmm_args.n)]
else:
self.shapes = BUILDIN_SHAPES
self.col_major = addmm_args.col_major

@register_benchmark()
def triton_addmm(self, a, mat1, mat2) -> Callable:
Expand Down Expand Up @@ -140,6 +141,8 @@ def get_input_iter(self) -> Generator:
mat2 = torch.randn(
(k, n), device=self.device, dtype=self.dtype
).requires_grad_(False)
if self.col_major:
mat2 = mat2.T.contiguous().T
yield a, mat1, mat2

def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
Expand Down

0 comments on commit 9cd6e86

Please sign in to comment.