Skip to content

Commit

Permalink
clean up to_shape_strides (tinygrad#2402)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Nov 23, 2023
1 parent e4026dc commit 64aa2f4
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions tinygrad/shape/shapetracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> Tuple[Tuple[int, int], ...]:
assert len(shape) == len(strides)
ret = [(shape[0], strides[0])] if shape else []
for i in range(1, len(shape)):
if ret[-1][1] == shape[i]*strides[i] or ret[-1][0] == 1:
ret[-1] = (ret[-1][0] * shape[i], strides[i])
elif shape[i] == 1:
continue
else:
ret.append((shape[i], strides[i]))
for s,st in zip(shape[1:], strides[1:]):
ps,pst = ret[-1]
if pst == s*st or ps == 1: ret[-1] = (ps*s, st)
elif s != 1: ret.append((s, st))
return tuple(ret)

def expr_node_mask(view:View, idx:Node, valid:Optional[Node]=None) -> Node:
Expand Down

0 comments on commit 64aa2f4

Please sign in to comment.