fix the variable arg order (tinygrad#2382)
chenyuxyz committed Nov 21, 2023
1 parent c5f429a commit 9eeba96
Showing 5 changed files with 34 additions and 6 deletions.
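Why the sort matters: a kernel with symbolic shapes declares one int32 argument per Variable, and the runtime must supply the bound values in the order the kernel expects. Before this commit both sides used whatever iteration order they happened to encounter, so binding the variables in the opposite order (j before i) could route values to the wrong argument slots. The commit makes the linearizer, the program runners, and the Metal batched runner all sort the variables, so compile-time and launch-time order always agree. A standalone sketch of that invariant, in plain Python with hypothetical helper names and strings standing in for Variable objects (not tinygrad code):

# Sketch: compile-time argument order and launch-time value order must agree;
# sorting makes both independent of the order the caller bound the variables.
def compile_arg_order(ast_vars):
  # linearizer side: one int arg per symbolic variable, in canonical (sorted) order
  return sorted(ast_vars)

def pack_values(kernel_vars, var_vals):
  # runtime side: pack the bound values following the same canonical order
  return [var_vals[v] for v in sorted(kernel_vars)]

kernel_vars = {"i", "j"}
bound_ij = {"i": 2, "j": 5}   # caller bound i first
bound_ji = {"j": 5, "i": 2}   # caller bound j first
# both bind orders feed the same values into the same slots
assert pack_values(kernel_vars, bound_ij) == pack_values(kernel_vars, bound_ji) == [2, 5]
assert compile_arg_order(kernel_vars) == ["i", "j"]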
16 changes: 15 additions & 1 deletion test/test_symbolic_jit.py
@@ -122,7 +122,7 @@ def f(a, b): return a.cat(b, dim=1).realize()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1

-def test_two_vars_plus1(self):
+def test_two_vars_plus1_ij(self):
def f(a, b): return (a@b+1).realize()
jf = TinyJit(f)
for i in range(1, 5):
@@ -136,6 +136,20 @@ def f(a, b): return (a@b+1).realize()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
assert len(jf.jit_cache) == 1

+def test_two_vars_plus1_ji(self):
+def f(a, b): return (a@b+1).realize()
+jf = TinyJit(f)
+for i in range(1, 5):
+for j in range(1, 5):
+vi = Variable("i", 1, 10).bind(i)
+vj = Variable("j", 1, 10).bind(j)
+a = Tensor.rand(j, 3)
+b = Tensor.rand(3, i)
+symbolic = jf(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
+expected = f(a, b).numpy()
+np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
+assert len(jf.jit_cache) == 1

def test_jit_symbolic_shape_mismatch(self):
@TinyJit
def add(a, b): return (a+b).realize()
15 changes: 14 additions & 1 deletion test/test_symbolic_ops.py
@@ -98,7 +98,7 @@ def f(a, b): return a.cat(b, dim=1).realize()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

-def test_two_vars_plus1(self):
+def test_two_vars_plus1_ij(self):
def f(a, b): return (a@b+1).realize()
for i in range(1, 5):
for j in range(1, 5):
@@ -110,6 +110,19 @@ def f(a, b): return (a@b+1).realize()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

+def test_two_vars_plus1_ji(self):
+# reverse the order of variables
+def f(a, b): return (a@b+1).realize()
+for i in range(1, 5):
+for j in range(1, 5):
+vi = Variable("i", 1, 10).bind(i)
+vj = Variable("j", 1, 10).bind(j)
+a = Tensor.rand(j, 3)
+b = Tensor.rand(3, i)
+symbolic = f(a.reshape(vj, 3), b.reshape(3, vi)).reshape(j, i).numpy()
+expected = f(a, b).numpy()
+np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)

def test_shrink(self):
for i in range(1, 5):
vi = Variable("i", 1, 10).bind(i)
2 changes: 1 addition & 1 deletion tinygrad/codegen/linearizer.py
@@ -170,7 +170,7 @@ def linearize(self):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
# add var vals
-for var in vars_from_ast(self.ast):
+for var in sorted(vars_from_ast(self.ast)):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers
5 changes: 3 additions & 2 deletions tinygrad/ops.py
@@ -202,6 +202,7 @@ def __init__(self, ast:LazyOp, fxn:Callable):
super().__init__(ast)

def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> float:
+var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
st = time.perf_counter()
ret: RawBuffer = self.fxn(rawbufs[1:], var_vals)
et = time.perf_counter() - st
@@ -286,8 +287,8 @@ def launch_dims(self, var_vals):
return global_size, local_size

def __call__(self, rawbufs:List[Optional[RawBuffer]], var_vals:Optional[Dict[Variable, int]]=None, jit=False, force_wait=False) -> Optional[float]:
-if var_vals is None: var_vals = {}
-var_vals = {k:var_vals[k] for k in self.vars} # filter the var_vals
+# filter the var_vals
+var_vals = {k:var_vals[k] for k in sorted(self.vars)} if var_vals is not None else {}
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
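The rewritten __call__ above now both filters and orders the bound values in one expression: variables the program does not use are dropped, the remaining keys are taken in sorted order, and a missing var_vals falls back to an empty dict. A small standalone sketch of that behaviour, with strings standing in for Variable objects (assuming Variables sort consistently, which plain strings do here):

# Sketch: filter to the program's own vars and fix their order,
# mirroring the dict comprehension in the diff above.
program_vars = {"j", "i"}            # variables this program actually uses
var_vals = {"k": 7, "j": 5, "i": 2}  # caller may pass extras, in any order
filtered = {k: var_vals[k] for k in sorted(program_vars)} if var_vals is not None else {}
assert filtered == {"i": 2, "j": 5} and list(filtered) == ["i", "j"]  # extras dropped, canonical order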
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_metal.py
@@ -113,7 +113,7 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, s
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
-var_vals_keys = list(var_vals.keys())
+var_vals_keys = sorted(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = prg.launch_dims(var_vals)
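In the Metal batched path each bound variable gets a fixed 4-byte int32 slot in one shared buffer (self.int_buf), and every kernel argument is wired to its slot via var_vals_keys.index(v)*4. Sorting the keys keeps those offsets consistent with the per-program variable order, which is now also sorted. A tiny sketch of the offset computation, with strings in place of Variable and no Metal calls:

# Sketch: slot offsets into the shared int32 buffer, stable regardless of
# the order in which the caller bound the variables.
var_vals = {"j": 5, "i": 2}              # insertion order happens to be j, i
var_vals_keys = sorted(var_vals.keys())  # ["i", "j"], the canonical order
prg_vars = ["i", "j"]                    # program's variables, also sorted
offsets = {v: var_vals_keys.index(v) * 4 for v in prg_vars}
assert offsets == {"i": 0, "j": 4}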
