Drop support for more than 4D tensors since it turns out to be a pain to implement

norabelrose committed Apr 25, 2021
1 parent 410f66d commit 85f970b
Showing 1 changed file with 16 additions and 24 deletions.
40 changes: 16 additions & 24 deletions python/triton/ops/blocksparse/matmul.py
@@ -660,8 +660,7 @@ def __call__(self, a, b):
a, b, self.trans_a, self.trans_b, False, self.mode, self.spdims, self.block, c_lut, c_num_locks, c_width,
c_packs, da_lut, da_num_locks, da_width, da_packs, db_lut, db_num_locks, db_width, db_packs
)
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input;
# also, if we flattened any extra leading dimensions (i.e. the inputs were more than 4D) this unflattens them
# This removes any leading singleton dimensions we may have added to the tensor that weren't in the input
return c.view(*leading_dims, *c.shape[-self.c_trailing_dims:])
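
As a reading aid, here is a minimal standalone sketch of the reshaping this return statement performs. It uses plain PyTorch with a dense torch.matmul as a stand-in for the blocksparse kernel; the 3D input shapes, the two trailing (matrix) dimensions, and the point where leading_dims is captured are illustrative assumptions rather than details taken from this file.

import torch

def pad_to_4d(x):
    # Mirrors add_extra_dims below: prepend singleton dims until the tensor is 4D
    return x.view(*([1] * (4 - x.ndim)), *x.shape)

a = torch.randn(3, 64, 32)        # caller passes a 3D tensor (batch, M, K)
b = torch.randn(3, 32, 48)        # caller passes a 3D tensor (batch, K, N)
leading_dims = a.shape[:-2]       # the caller's original leading dims, here (3,)

c = torch.matmul(pad_to_4d(a), pad_to_4d(b))   # stand-in for the kernel; shape (1, 3, 64, 48)

# Drop the singleton dim that was only added for the kernel, as the comment above describes
c = c.view(*leading_dims, *c.shape[-2:])       # back to (3, 64, 48)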

def _validate_inputs(self, a, b):
@@ -700,6 +699,21 @@ def _validate_inputs(self, a, b):
raise ValueError(f"Expected tensor of shape {self.sparse_shape} for argument {sparse_name}, "
f"got {sparse.shape}")

def add_extra_dims(x):
    # Add extra leading singleton dimensions if needed
    dims_needed = 4 - x.ndim
    if dims_needed > 0:
        singletons = [1] * dims_needed
        x = x.view(*singletons, *x.shape)
    elif dims_needed < 0:
        raise ValueError("Tensors with more than 4 dimensions are not currently supported")

    return x

# Pad shapes with leading singleton dimensions
a = add_extra_dims(a)
b = add_extra_dims(b)

leading_dims_a = a.shape[:-self.a_trailing_dims] # Batch, attention head, etc. dims
leading_dims_b = b.shape[:-self.b_trailing_dims] # Batch, attention head, etc. dims
if leading_dims_a != leading_dims_b:
@@ -708,31 +722,9 @@ def _validate_inputs(self, a, b):
    except RuntimeError:
        raise ValueError(f"Tensors A and B should be broadcastable along leading (non-matrix) dimensions; "
                         f"got {leading_dims_a} for A and {leading_dims_b} for B.")
    else:
        a = a.expand(*leading_dims_c, *a.shape[-self.a_trailing_dims:])
        b = b.expand(*leading_dims_c, *b.shape[-self.b_trailing_dims:])
else:
    leading_dims_c = leading_dims_a

def normalize_leading_dims(x):
    extra_dims = x.ndim - 4

    # Flatten any extra leading dimensions into the batch dimension; this
    # will involve copying if we just broadcasted A and B
    if extra_dims > 0:
        x = x.flatten(0, extra_dims)

    # Add extra leading singleton dimensions if needed
    elif extra_dims < 0:
        singletons = [1] * -extra_dims
        x = x.view(*singletons, *x.shape)

    return x

# Either pad shapes with leading singleton dimensions or flatten extra leading dimensions as needed
a = normalize_leading_dims(a)
b = normalize_leading_dims(b)

return a, b, leading_dims_c
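
Putting the validation logic together: both inputs are padded up to exactly 4D, anything larger is rejected, and the leading (batch, head, etc.) shapes are broadcast against each other. Below is a rough standalone sketch of that flow under a few assumptions: torch.broadcast_shapes stands in for the collapsed try: body above, and each tensor is taken to have two trailing matrix dimensions.

import torch

def pad_or_reject(x):
    # Same idea as add_extra_dims: pad up to 4D with leading singletons, refuse anything larger
    if x.ndim > 4:
        raise ValueError("Tensors with more than 4 dimensions are not currently supported")
    return x.view(*([1] * (4 - x.ndim)), *x.shape)

a = pad_or_reject(torch.randn(8, 1, 64, 32))      # (batch, 1, M, K), already 4D
b = pad_or_reject(torch.randn(4, 32, 48))         # (heads, K, N) becomes (1, 4, 32, 48)

lead_a, lead_b = a.shape[:-2], b.shape[:-2]       # assumed two trailing matrix dims
if lead_a != lead_b:
    try:
        lead_c = torch.broadcast_shapes(lead_a, lead_b)   # (8, 4); assumed stand-in for the hidden try: body
    except RuntimeError:
        raise ValueError(f"Leading dims {lead_a} and {lead_b} are not broadcastable")
    else:
        a = a.expand(*lead_c, *a.shape[-2:])      # (8, 4, 64, 32)
        b = b.expand(*lead_c, *b.shape[-2:])      # (8, 4, 32, 48)
else:
    lead_c = lead_a

Note that a 5D input now hits the ValueError in the real counterpart of pad_or_reject, add_extra_dims, instead of being flattened into the batch dimension as the removed normalize_leading_dims used to do.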

# Copied from the PyTorch 1.8 implementation so that we can support older PyTorch versions
