Add autocast support for sparse matmul op; throw error if dtypes differ and autocast is off
norabelrose committed Apr 24, 2021
1 parent c64dddb commit 99999fc
Showing 1 changed file with 15 additions and 9 deletions.
python/triton/ops/blocksparse/matmul.py (15 additions, 9 deletions)
@@ -622,7 +622,8 @@ def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
         if layout_dim == 2:
             layout = layout.unsqueeze(0)
 
-        self.layout = layout.long()  # Above code assumes the layout tensor is an integral type
+        layout = layout.long()  # Above code assumes the layout tensor is an integral type
+        self.layout = layout
         self.spdims = layout.shape
 
         if not mode == 'sdd':
@@ -641,9 +642,9 @@ def __call__(self, a, b):
         da_lut, da_num_locks, da_width, da_packs,\
         db_lut, db_num_locks, db_width, db_packs = self.make_lut(a.dtype, a.device)
 
-        # If we don't check for invalid shapes & devices here, they will lead to undefined behavior
+        # If we don't check for invalid shapes, devices, & dtypes here, they will lead to undefined behavior
         # and potential illegal memory accesses
-        leading_dims = self._validate_inputs(a, b)
+        a, b, leading_dims = self._validate_inputs(a, b)
 
         # Either pad shapes with leading singleton dimensions or flatten extra leading dimensions as needed
         a = matmul._normalize_leading_dims(a)
@@ -675,13 +676,18 @@ def _normalize_leading_dims(x):
         return x
 
     def _validate_inputs(self, a, b):
-        device_a, device_b = a.device, b.device
-        if device_a != device_b:
-            raise ValueError(f"Inputs must be on the same device; got {device_a} for tensor A "
-                             f"and {device_b} for tensor B")
-        if device_a.type != 'cuda':
+        if a.device != b.device:
+            raise ValueError(f"Inputs must be on the same device; got {a.device} for tensor A "
+                             f"and {b.device} for tensor B")
+        if not a.is_cuda:
             raise ValueError("Only GPU devices are supported for now")
 
+        # When autocast is enabled, torch.matmul autocasts to float16, so we do the same here
+        if torch.is_autocast_enabled():
+            a, b = a.half(), b.half()
+        elif a.dtype != b.dtype:
+            raise ValueError(f"Inputs must be the same dtype; got {a.dtype} for A and {b.dtype} for B")
+
         mode, trans_a, trans_b = self.mode, self.trans_a, self.trans_b
         if mode == 'sdd':
             # Both inputs are dense and the output is sparse
@@ -711,7 +717,7 @@ def _validate_inputs(self, a, b):
             raise ValueError(f"Tensors A and B should be broacastable along leading (non-matrix) dimensions; "
                              f"got {leading_dims_a} for A and {leading_dims_b} for B.")
 
-        return leading_dims_a
+        return a, b, leading_dims_a
 
 
 # Standard PyTorch broadcasting semantics- see <https://pytorch.org/docs/stable/notes/broadcasting.html>
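For context, the user-visible effect of this commit can be sketched as follows. This is a minimal, hypothetical example rather than code from the commit: the block size, sparsity layout, and tensor shapes are illustrative assumptions about the blocksparse matmul interface defined in this file. Under autocast, both operands are cast to float16, mirroring torch.matmul's autocast policy as the new comment in _validate_inputs notes; with autocast off, mismatched dtypes now raise a ValueError instead of leading to undefined behavior in the kernel.

import torch
from triton.ops.blocksparse.matmul import matmul as blocksparse_matmul

# Illustrative assumptions: one head, 64x64 matrices, 16x16 blocks, fully dense layout.
block = 16
layout = torch.ones(1, 4, 4, dtype=torch.long)
op = blocksparse_matmul(layout, block, 'sdd', trans_a=False, trans_b=False)

a = torch.randn(1, 1, 64, 64, device='cuda', dtype=torch.float32)
b = torch.randn(1, 1, 64, 64, device='cuda', dtype=torch.float32)

# Autocast enabled: _validate_inputs casts both operands to float16.
with torch.cuda.amp.autocast():
    c = op(a, b)
print(c.dtype)  # torch.float16

# Autocast disabled: mismatched dtypes now raise instead of silently
# producing undefined behavior.
try:
    op(a.half(), b)
except ValueError as err:
    print(err)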
