Sparse matmul op should pass tests again
norabelrose committed Apr 25, 2021
1 parent 99999fc commit a2e4d80
Showing 1 changed file with 37 additions and 30 deletions.
67 changes: 37 additions & 30 deletions python/triton/ops/blocksparse/matmul.py
@@ -636,6 +636,15 @@ def __init__(self, layout, block, mode, trans_a=False, trans_b=False):
self.dense_inner_size = layout.shape[sparse_inner] * block
# Expected shape for sparse inputs
self.sparse_shape = (layout.shape[0], layout.sum().item(), block, block)
self.a_trailing_dims = 3 if mode == 'dsd' else 2
self.b_trailing_dims = 2 if mode == 'dsd' else 3
self.c_trailing_dims = 2
else:
# How many of the dims of A, B, and C should be considered non-batch dims? Block sparse tensors
# have shape [(...,) blocks, block height, block width]
self.a_trailing_dims = 2
self.b_trailing_dims = 2
self.c_trailing_dims = 3
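
For orientation, here is a hedged sketch (the layout and sizes below are made up, not from the commit) of the operand shapes this trailing-dim bookkeeping implies:

```python
import torch

# Hypothetical setup: 1 head, a 2x2 block grid with 2 nonzero blocks, block=16
block = 16
layout = torch.tensor([[[1, 0], [0, 1]]])
sparse_shape = (layout.shape[0], layout.sum().item(), block, block)  # (1, 2, 16, 16)

# In 'dsd' mode A is the sparse operand: its 3 trailing non-batch dims are
# [blocks, block height, block width], so its leading (broadcast) dims are a.shape[:-3]
a = torch.randn(*sparse_shape)
print(a.shape[:-3])  # torch.Size([1])

# The dense operand keeps the usual 2 trailing matrix dims [rows, cols]
b = torch.randn(8, 1, 2 * block, 2 * block)  # e.g. batch=8, head=1
print(b.shape[:-2])  # torch.Size([8, 1])
```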

def __call__(self, a, b):
c_lut, c_num_locks, c_width, c_packs,\
@@ -646,34 +655,14 @@ def __call__(self, a, b):
# and potential illegal memory accesses
a, b, leading_dims = self._validate_inputs(a, b)

# Either pad shapes with leading singleton dimensions or flatten extra leading dimensions as needed
a = matmul._normalize_leading_dims(a)
b = matmul._normalize_leading_dims(b)

# execute
c = _matmul.apply(
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
)
        # This removes any leading singleton dimensions we may have added to the tensor that weren't in the input;
        # also, if we flattened any extra leading dimensions (i.e. the inputs were more than 4D) this unflattens them
        return c.view(*leading_dims, *c.shape[-2:])
        return c.view(*leading_dims, *c.shape[-self.c_trailing_dims:])

    # Kernel assumes that all tensors are 4 dimensional
    @staticmethod
    def _normalize_leading_dims(x):
        extra_dims = x.ndim - 4

        # Flatten any extra leading dimensions into the batch dimension
        if extra_dims > 0:
            x = x.flatten(0, extra_dims)

        # Add extra leading singleton dimensions if needed
        elif extra_dims < 0:
            singletons = [1] * -extra_dims
            x = x.view(*singletons, *x.shape)

        return x
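
As a quick illustration of the unflattening step in the new return (shapes below are hypothetical, not from the commit):

```python
import torch

# Suppose the inputs carried leading dims (8, 4) that were flattened into a
# single batch dim of 32 before the kernel ran
leading_dims = (8, 4)
c = torch.randn(32, 64, 64)  # hypothetical flattened kernel output

# The final view restores the caller's leading dims around the matrix dims
out = c.view(*leading_dims, *c.shape[-2:])
print(out.shape)  # torch.Size([8, 4, 64, 64])
```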

def _validate_inputs(self, a, b):
if a.device != b.device:
@@ -707,20 +696,38 @@ def _validate_inputs(self, a, b):
raise ValueError(f"Expected tensor {dense_name} to have size {self.dense_inner_size} at dim "
f"{self.dense_inner_dim % dense.ndim}, got {dense_inner}.")

if not broadcastable_shapes(sparse.shape, self.sparse_shape):
if sparse.shape != self.sparse_shape:
raise ValueError(f"Expected tensor of shape {self.sparse_shape} for argument {sparse_name}, "
f"got {sparse.shape}")

leading_dims_a = a.shape[:-2] # Batch, attention head, etc. dims
leading_dims_b = b.shape[:-2] # Batch, attention head, etc. dims
if not broadcastable_shapes(leading_dims_a, leading_dims_b):
leading_dims_a = a.shape[:-self.a_trailing_dims] # Batch, attention head, etc. dims
leading_dims_b = b.shape[:-self.b_trailing_dims] # Batch, attention head, etc. dims
try:
leading_dims_c = torch.broadcast_shapes(leading_dims_a, leading_dims_b)
except RuntimeError:
raise ValueError(f"Tensors A and B should be broacastable along leading (non-matrix) dimensions; "
f"got {leading_dims_a} for A and {leading_dims_b} for B.")
else:
a = a.expand(*leading_dims_c, *a.shape[-self.a_trailing_dims:])
b = b.expand(*leading_dims_c, *b.shape[-self.b_trailing_dims:])

            def normalize_leading_dims(x):
                extra_dims = x.ndim - 4

                # Flatten any extra leading dimensions into the batch dimension; this
                # will involve copying if we just broadcasted A and B
                if extra_dims > 0:
                    x = x.flatten(0, extra_dims)

                # Add extra leading singleton dimensions if needed
                elif extra_dims < 0:
                    singletons = [1] * -extra_dims
                    x = x.view(*singletons, *x.shape)

                return x

            # Either pad shapes with leading singleton dimensions or flatten extra leading dimensions as needed
            a = normalize_leading_dims(a)
            b = normalize_leading_dims(b)

            return a, b, leading_dims_c

        return a, b, leading_dims_a

# Standard PyTorch broadcasting semantics; see <https://pytorch.org/docs/stable/notes/broadcasting.html>
def broadcastable_shapes(shape1, shape2):
    dim_iter = zip(reversed(shape1), reversed(shape2))
    return all(a_dim == b_dim or a_dim == 1 or b_dim == 1 for a_dim, b_dim in dim_iter)
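
The broadcast-then-flatten pipeline this commit introduces (swapping the hand-rolled broadcastable_shapes check for torch.broadcast_shapes plus expand) can be previewed in isolation; a minimal sketch with made-up shapes:

```python
import torch

a = torch.randn(4, 32, 32)     # leading dims (4,), e.g. heads only
b = torch.randn(8, 4, 32, 32)  # leading dims (8, 4), e.g. batch and heads

# torch.broadcast_shapes raises RuntimeError on incompatible shapes, which is
# what the new try/except converts into a friendlier ValueError
lead = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])  # (8, 4)

# expand materializes no data; flattening afterwards may copy, as the
# comment in the diff notes
a = a.expand(*lead, *a.shape[-2:])
b = b.expand(*lead, *b.shape[-2:])
print(a.shape, b.shape)  # torch.Size([8, 4, 32, 32]) twice
```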
