symbolic codegen and exec (tinygrad#1552)
* symbolic codegen and exec

* fix and add test

* no sketchy

* merge_dicts type

* dtypes._arg_int32
chenyuxyz committed Aug 16, 2023
1 parent 1e1d48b commit 11dd9b1
Showing 16 changed files with 202 additions and 44 deletions.
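At a glance, this change lets a kernel be generated against a symbolic dimension (a Variable with bounds) and then executed for any concrete value of that dimension, with the value passed to the kernel at launch time. A minimal sketch of the usage pattern, distilled from test/test_symbolic_ops.py below (it assumes one of the supported backends: GPU, METAL, or CLANG):

from tinygrad.tensor import Tensor
from tinygrad.shape.symbolic import Variable
import numpy as np

vi = Variable("i", 1, 10)  # symbolic dimension with bounds 1..10
for i in range(1, 5):
  a, b = Tensor.rand(3, i), Tensor.rand(3, i)
  # reshape the concrete axis to the Variable, run the op, then reshape back to the real size
  symbolic = (a.reshape(3, vi) + b.reshape(3, vi)).realize().reshape(3, i).cpu().numpy()
  np.testing.assert_allclose(symbolic, (a + b).cpu().numpy(), atol=1e-6, rtol=1e-6)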
2 changes: 2 additions & 0 deletions .github/workflows/test.yml
@@ -188,6 +188,8 @@ jobs:
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_ops.py
- name: Run JIT test
run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_jit.py
- name: Run symbolic shapetracker test
run: METAL=1 python -m pytest -n=auto test/test_symbolic_shapetracker.py test/test_symbolic_ops.py
- name: Check Device.DEFAULT
run: WEBGPU=1 python -c "from tinygrad.lazy import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
#- name: Run webgpu pytest
14 changes: 13 additions & 1 deletion test/test_helpers.py
@@ -1,5 +1,5 @@
import unittest
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import Context, ContextVar, merge_dicts

VARIABLE = ContextVar("VARIABLE", 0)

@@ -106,5 +106,17 @@ def test_context_exit_reverts_updated_values(self):
      ...
    assert D.value == 2, f"Expected D to be 2, but was {D.value}. Indicates that Context.__exit__ did not restore to the correct value."

class TestMergeDicts(unittest.TestCase):
  def test_merge_dicts(self):
    a = {"a": 1, "b": 2}
    b = {"a": 1, "c": 3}
    c = {}
    d = {"a": 2, "b": 2}
    assert merge_dicts([a, b]) == {"a": 1, "b": 2, "c": 3}
    assert merge_dicts([a, c]) == a
    assert merge_dicts([a, b, c]) == {"a": 1, "b": 2, "c": 3}
    with self.assertRaises(AssertionError):
      merge_dicts([a, d])

if __name__ == '__main__':
  unittest.main()
113 changes: 113 additions & 0 deletions test/test_symbolic_ops.py
@@ -0,0 +1,113 @@
import unittest
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv, CI
from tinygrad.tensor import Tensor, Device
import numpy as np

@unittest.skipIf(getenv("ARM64"), "ARM64 is not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG"], f"{Device.DEFAULT} is not supported")
class TestSymbolicOps(unittest.TestCase):
  def test_plus1(self):
    def f(a): return (a+1).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      symbolic = f(a.reshape(3, vi)).reshape(3, i).cpu().numpy()
      expected = f(a).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_add(self):
    def f(a, b): return (a+b).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(3, i)
      symbolic = f(a.reshape(3, vi), b.reshape(3, vi)).reshape(3, i).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_matmul(self):
    def f(a, b): return (a@b).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(i, 5)
      symbolic = f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_matmul_same_var_different_val(self):
    def f(a, b): return (a@b).realize()
    vi = Variable("i", 1, 10)
    a = Tensor.rand(3, 4)
    b = Tensor.rand(7, 5)
    with self.assertRaises(AssertionError):
      f(a.reshape(3, vi), b.reshape(vi, 5)).cpu().numpy()

  @unittest.skipIf(Device.DEFAULT == "CLANG" and CI, "broken on CLANG CI")
  def test_attention(self):
    def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      q = Tensor.rand(2, 1, 4, 8)
      k = Tensor.rand(2, i, 4, 8)
      v = Tensor.rand(2, i, 4, 8)
      symbolic = f(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).cpu().numpy()
      expected = f(q, k, v).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_cat_dim0(self):
    def f(a, b): return a.cat(b, dim=0).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(i, 3)
      b = Tensor.rand(2, 3)
      symbolic = f(a.reshape(vi, 3), b).reshape(i+2, 3).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_cat_dim1(self):
    def f(a, b): return a.cat(b, dim=1).realize()
    vi = Variable("i", 1, 10)
    for i in range(1, 5):
      a = Tensor.rand(3, i)
      b = Tensor.rand(3, 2)
      symbolic = f(a.reshape(3, vi), b).reshape(3, i+2).cpu().numpy()
      expected = f(a, b).cpu().numpy()
      np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_cat_dim0_two_vars(self):
    def f(a, b): return a.cat(b, dim=0).realize()
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(i, 3)
        b = Tensor.rand(j, 3)
        symbolic = f(a.reshape(vi, 3), b.reshape(vj, 3)).reshape(i+j, 3).cpu().numpy()
        expected = f(a, b).cpu().numpy()
        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_cat_dim1_two_vars(self):
    def f(a, b): return a.cat(b, dim=1).realize()
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(3, i)
        b = Tensor.rand(3, j)
        symbolic = f(a.reshape(3, vi), b.reshape(3, vj)).reshape(3, i+j).cpu().numpy()
        expected = f(a, b).cpu().numpy()
        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

  def test_two_vars_plus1(self):
    def f(a, b): return (a@b+1).realize()
    vi = Variable("i", 1, 10)
    vj = Variable("j", 1, 10)
    for i in range(1, 5):
      for j in range(1, 5):
        a = Tensor.rand(i, 3)
        b = Tensor.rand(3, j)
        symbolic = f(a.reshape(vi, 3), b.reshape(3, vj)).reshape(i, j).cpu().numpy()
        expected = f(a, b).cpu().numpy()
        np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
15 changes: 8 additions & 7 deletions test/test_symbolic_shapetracker.py
@@ -26,17 +26,18 @@ def test_cat_strides(self):
    i = Variable("i", 1, 5)
    j = Variable("j", 1, 5)
    k = Variable("k", 1, 5)
    t1 = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
    st = t1.lazydata.st
    t = Tensor.rand(3, 4).reshape(i, 4).cat(Tensor.rand(3, 4).reshape(j, 4), dim=0).cat(Tensor.rand(3, 4).reshape(k, 4), dim=0)
    st = t.lazydata.st
    assert st.shape == (i+j+k, 4)
    assert st.real_strides() == (4, 1)
    i = Variable("i", 1, 5)
    j = Variable("j", 1, 5)
    k = Variable("k", 1, 5)
    t1 = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
    st = t1.lazydata.st
    t = Tensor.rand(3, 4).reshape(3, i).cat(Tensor.rand(3, 4).reshape(3, j), dim=1).cat(Tensor.rand(3, 4).reshape(3, k), dim=1)
    st = t.lazydata.st
    assert st.shape == (3, i+j+k)
    assert st.real_strides() == (i+j+k, 1)
    t = Tensor.rand(i, 3).reshape(i, 3).cat(Tensor.rand(3, 3).reshape(i, 3), dim=0).cat(Tensor.rand(3, 3), dim=0)
    st = t.lazydata.st
    assert st.shape == (2*i+3, 3)
    assert st.real_strides() == (3, 1)

class TestSymbolicReshape(unittest.TestCase):
  def test_reshape_into_symbols_simple(self):
1 change: 1 addition & 0 deletions test/unit/test_symbolic.py
@@ -283,6 +283,7 @@ def test_node_div_node(self):
    assert NumNode(0) // (Variable("i", 1, 10)*128) == 0
    assert NumNode(127) // (Variable("i", 1, 10)*128) == 0
    assert idx0 // (i*3) == 0
    assert i // i == 1

  def test_node_mod_node(self):
    i = Variable("i", 1, 10)
9 changes: 6 additions & 3 deletions tinygrad/codegen/linearizer.py
@@ -9,7 +9,7 @@
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst, buf_is_kernel_arg
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape, View
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename
VariableOrNum = Union[Variable, NumNode, Node]

# bottom ones are asm only
@@ -301,6 +301,9 @@ def linearize(self):
# add global buffers
for buf,name in self.arg_bufs.items():
self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype))
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key):
self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32))

# add a local buffer for multistage reduce
if len(self.group_for_reduce):
@@ -317,7 +320,7 @@
if DEBUG >= 3: self.printbufs()

# kernel name (before late upcast)
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) if isinstance(x, int) else sym_rename(x) for x in self.full_shape])
self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])

# parse AST
@@ -548,7 +551,7 @@ def colors(self) -> List[str]:
assert len(colors) == self.shape_len, "colors size mismatch"
return colors

def colored_shape(self) -> str: return ' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors()))
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
def printbufs(self, prefix=""):
for i in range(len(self.sts)):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
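The kernel name can now contain symbolic dimensions, so sym_rename (imported above) is used to turn each symbolic expression into something legal inside a function name. Its exact behavior is not shown in this diff; the following is a hypothetical standalone sketch of the idea, not tinygrad's actual implementation:

import functools

@functools.lru_cache(maxsize=None)
def rename_sym(expr: str) -> str:
  # assign the next short alias (s0, s1, ...) the first time an expression is seen
  return f"s{rename_sym.cache_info().currsize}"

print(rename_sym("(i+j)"))  # s0
print(rename_sym("i"))      # s1
print(rename_sym("(i+j)"))  # s0 again: the alias is cached and stable within a process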
8 changes: 4 additions & 4 deletions tinygrad/codegen/optimizer.py
@@ -162,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer):
# early exit
return

if k.opts.has_local:
if k.opts.has_local and all(isinstance(s, int) for s in k.sts[0].shape[:k.first_reduce]):
# are we grouping? (requires local shape support)
if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
@@ -204,8 +204,8 @@ def hand_coded_optimizations(k:Linearizer):
while prod(k.sts[0].shape[:k.first_reduce]) >= 1024:
xb_choices = []
for axis, upcast_amount in itertools.product(range(k.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and any(k.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index in range(len(k.sts))):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts), sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
if len(xb_choices):
xb_choices = sorted(xb_choices)
@@ -219,7 +219,7 @@ def hand_coded_optimizations(k:Linearizer):

# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if k.first_reduce < (k.shape_len-k.upcasted) and (len(list(k.shape_offsets(k.full_buf_index))) <= 4 or not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))):
if (s:=k.full_unupcasted_shape[-1]) <= 32:
if (s:=k.full_unupcasted_shape[-1]) <= 32 and isinstance(s, int): # NOTE: cannot loop unroll symbolic axis
k.upcast()
# if it's small, upcast a second reduce dimension too
if k.first_reduce < (k.shape_len-k.upcasted) and s <= 3 and k.full_unupcasted_shape[-1] <= 3: k.upcast()
7 changes: 6 additions & 1 deletion tinygrad/helpers.py
@@ -3,7 +3,7 @@
from weakref import KeyedRef, ref
from _weakref import _remove_dead_weakref # type: ignore
import numpy as np
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any
from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Callable, Any, Iterable
from math import prod # noqa: F401 # pylint:disable=unused-import

ShapeType = Tuple[int, ...]
@@ -22,6 +22,10 @@ def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (
def flatten(l:Iterator): return [item for sublist in l for item in sublist]
def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def merge_dicts(ds:Iterable[Dict]) -> Dict:
  kvs = set([(k,v) for d in ds for k,v in d.items()])
  assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
  return {k:v for k,v in kvs}

@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
@@ -115,6 +119,7 @@ def fields() -> Dict[str, DType]: return DTYPES_DICT
  _half4: Final[DType] = DType(0, 2*4, "half4", None, 4)
  _float2: Final[DType] = DType(4, 4*2, "float2", None, 2)
  _float4: Final[DType] = DType(4, 4*4, "float4", None, 4)
  _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)

# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {k: v for k, v in dtypes.__dict__.items() if not k.startswith('__') and not callable(v) and not v.__class__ == staticmethod}
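merge_dicts is what later lets ops.py combine the var_vals (Variable-to-value bindings) coming from several source buffers into a single dictionary for the kernel launch, refusing to merge when the same Variable is bound to two different values. A small sketch of that use, assuming Variable objects are hashable dict keys as the ops.py change below implies:

from tinygrad.helpers import merge_dicts
from tinygrad.shape.symbolic import Variable

vi, vj = Variable("i", 1, 10), Variable("j", 1, 10)
merged = merge_dicts([{vi: 3}, {vi: 3, vj: 4}])  # ok: the values for vi agree, result has both keys
try:
  merge_dicts([{vi: 3}, {vi: 4}])                # same Variable bound to two different values
except AssertionError:
  print("conflicting var_vals are rejected")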
35 changes: 21 additions & 14 deletions tinygrad/ops.py
@@ -2,8 +2,9 @@
import functools, time
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
from tinygrad.shape.shapetracker import MovementOps
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
@@ -131,20 +132,24 @@ def build(self, runtime):
self.clprg = runtime(self.name, self.prg, **self.runtime_args)
return self

def exec(self, bufs, force_wait=False, optimizing=False) -> Optional[float]:
def exec(self, bufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
rawbufs = dedup([x.realized for x in bufs if buf_is_kernel_arg(x)])
if GlobalCounters.cache is not None and not optimizing: GlobalCounters.cache.append((self, rawbufs))
return self(rawbufs, force_wait=force_wait)

def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Optional[float]:
if et := self.clprg((self.global_size + [1]*(3-len(self.global_size))) if self.global_size is not None else None,
(self.local_size + [1]*(3-len(self.local_size))) if self.local_size is not None else None,
*rawbufs, wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
return self(rawbufs, var_vals, force_wait=force_wait)

def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
if var_vals is None: var_vals = {}
global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size
local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size
if et := self.clprg((global_size + [1]*(3-len(global_size))) if global_size is not None else None,
(local_size + [1]*(3-len(local_size))) if local_size is not None else None,
*rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=1): GlobalCounters.time_sum_s += et
op_estimate = sym_infer(self.op_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(33-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {self.mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += 1
GlobalCounters.global_ops += self.op_estimate
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += self.mem_estimate
if getenv("EARLY_STOPPING") and GlobalCounters.kernel_count == getenv("EARLY_STOPPING"): exit(0)
return et
@@ -178,9 +183,11 @@ def exec_ast(self, ast:LazyOp, output, **kwargs):
output.realized = None
break

# we don't have an output buffer, we have to create it
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if not output.realized:
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)
output.realized = self.buffer(prod((s if isinstance(s, int) else s.max for s in output.shape)), output.dtype, **kwargs)
# update the output var_vals from src
output.st.var_vals = dict(sorted(merge_dicts([buf.st.var_vals for buf in ast.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))

from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, output, self.linearizer_opts)
@@ -200,5 +207,5 @@ def get_program():

if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)

prg.exec(k.bufs)
prg.exec(k.bufs, var_vals=output.st.var_vals)
return output.realized
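At execution time the symbolic pieces are resolved: __call__ uses sym_infer to turn any symbolic global/local launch sizes (and the op estimate) into concrete integers using var_vals, and the bound values themselves are appended after the raw buffers as extra kernel arguments, matching the int32 arguments the linearizer declared above. A small illustration of the resolution step, assuming sym_infer substitutes var_vals into a symbolic Node and passes plain ints through unchanged:

from tinygrad.shape.symbolic import Variable, sym_infer

vi = Variable("i", 1, 10)
var_vals = {vi: 3}
global_size = [vi * 4, 16, 1]  # a launch size left partly symbolic by codegen
print([sym_infer(sz, var_vals) for sz in global_size])  # [12, 16, 1]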