
Commit

that had bugs, force an order (tinygrad#2411)
geohot committed Nov 23, 2023
1 parent 65f4e69 commit 193be14
Showing 5 changed files with 11 additions and 14 deletions.
4 changes: 2 additions & 2 deletions test/test_symbolic_jit.py
@@ -55,7 +55,7 @@ def f(a, b):
symbolic = jf(a.reshape(3, vi), b.reshape(vi, 5)).numpy()
expected = f(a, b).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert len(jf.jit_cache) == 2
+assert len(jf.jit_cache) == 2 or getattr(Device[Device.DEFAULT], "graph", None)

def test_attention(self):
def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).realize()
@@ -68,7 +68,7 @@ def f(q, k, v): return Tensor.scaled_dot_product_attention(q.transpose(1, 2), k.
symbolic = jf(q, k.reshape(2, vi, 4, 8), v.reshape(2, vi, 4, 8)).reshape(2, 4, 1, 8).numpy()
expected = f(q, k, v).numpy()
np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
-assert len(jf.jit_cache) == 6
+assert len(jf.jit_cache) == 6 or getattr(Device[Device.DEFAULT], "graph", None)

def test_cat_dim0(self):
def f(a, b): return a.cat(b, dim=0).realize()
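The relaxed assertions above accept either the expected kernel count or the presence of a device graph executor, since a graph backend may batch JIT cache entries differently. A minimal, self-contained sketch of that pattern (FakeDevice and check_jit_cache are hypothetical names, not tinygrad APIs):

class FakeDevice:
    graph = object()  # stands in for a device-provided graph executor

def check_jit_cache(jit_cache, expected, device):
    # mirrors: assert len(jf.jit_cache) == expected or getattr(Device[Device.DEFAULT], "graph", None)
    assert len(jit_cache) == expected or getattr(device, "graph", None)

check_jit_cache([None, None], 2, FakeDevice())  # exact count: passes
check_jit_cache([None], 2, FakeDevice())        # count differs, but a graph executor is present, so it still passes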
2 changes: 1 addition & 1 deletion tinygrad/codegen/linearizer.py
@@ -169,7 +169,7 @@ def linearize(self):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
# add var vals
-for var in sorted(vars_from_ast(self.ast)):
+for var in vars_from_ast(self.ast):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers
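With vars_from_ast already returning one fixed order, the loop above can register a uop per symbolic variable without re-sorting. A simplified sketch of the idea (define_var_uops and the tuple entries are stand-ins, not the real Linearizer.uop machinery):

def define_var_uops(canonical_vars):
    loop_uops = {}
    for expr in canonical_vars:                    # already in canonical order
        loop_uops[expr] = ("DEFINE_GLOBAL", expr)  # stand-in for self.uop(UOps.DEFINE_GLOBAL, ...)
    return loop_uops

print(list(define_var_uops(["i", "j", "start_pos"]).keys()))  # ['i', 'j', 'start_pos'] on every run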
3 changes: 2 additions & 1 deletion tinygrad/lazy.py
@@ -79,7 +79,8 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)

-def vars_from_ast(ast:LazyOp) -> Set[Variable]: return set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set())
+# NOTE: this is the canonical order
+def vars_from_ast(ast:LazyOp) -> List[Variable]: return sorted(set.union(*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()), key=lambda x: str(x.expr))

lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, base:Optional[LazyBuffer]=None):
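Sorting by the string form of each variable's expression gives every consumer of vars_from_ast the same deterministic order, regardless of how Python happens to iterate the underlying set. A minimal sketch with a stand-in Var class (not tinygrad's Variable):

class Var:
    def __init__(self, expr): self.expr = expr
    def __repr__(self): return f"Var({self.expr})"

def vars_canonical(per_buffer_vars):
    # union the per-buffer sets, then force one order by str(expr), as vars_from_ast now does
    return sorted(set.union(*per_buffer_vars, set()), key=lambda v: str(v.expr))

i, j, start_pos = Var("i"), Var("j"), Var("start_pos")
print(vars_canonical([{j, i}, {start_pos, i}]))  # always [Var(i), Var(j), Var(start_pos)]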
14 changes: 5 additions & 9 deletions tinygrad/ops.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib, time, re
from enum import Enum, auto
-from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, Set
+from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer, sint
@@ -184,10 +184,8 @@ def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]):

def exec_ast(self, ast:LazyOp, output:LazyBuffer, inputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], **kwargs):
if ast not in self.method_cache: self.method_cache[ast] = get_interpreted_fxn(self.fxn_for_op, ast)
-rawbufs = [output.realized if output.realized is not None else output.output_buffer] + [x.realized for x in inputs]
-if rawbufs[0] is None: rawbufs[0] = self.buffer.__new__(self.buffer)
-self.method_cache[ast].exec(rawbufs, var_vals)
-output.realized = rawbufs[0]
+output.realized = output.output_buffer if output.output_buffer is not None else self.buffer.__new__(self.buffer)
+self.method_cache[ast].exec([output.realized] + [x.realized for x in inputs], var_vals)

def get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
if DEBUG >= 3:
@@ -236,7 +234,7 @@ def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = \
name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
-self.vars: Set[Variable] = set()
+self.vars: List[Variable] = []
if ast:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
@@ -255,8 +253,6 @@ def launch_dims(self, var_vals):
return global_size, local_size

def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
-# filter the var_vals
-var_vals = {k:var_vals[k] for k in sorted(self.vars)}
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
@@ -266,7 +262,7 @@ def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=F
lra = self.runtime_args.copy()
if global_size: lra['global_size'] = global_size
if local_size and 'local_size' not in lra: lra['local_size'] = local_size
-et = self.clprg(*rawbufs, *var_vals.values(), **lra, wait=wait or DEBUG>=2)
+et = self.clprg(*rawbufs, *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2)
update_stats(self.display_name if self.display_name is not None else self.name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra)
return et

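The __call__ change passes variable values positionally in the kernel's own recorded order (self.vars) instead of rebuilding and sorting a filtered dict. A hedged sketch of that call pattern (launch and fake_prg are illustrative names only):

def launch(clprg, rawbufs, kernel_vars, var_vals):
    # kernel_vars is the ordered list captured when the kernel was built;
    # entries of var_vals that the kernel does not use are simply ignored
    return clprg(*rawbufs, *[var_vals[k] for k in kernel_vars])

def fake_prg(*args): return args  # stands in for the compiled program

print(launch(fake_prg, ["bufA", "bufB"], ["i", "j"], {"j": 7, "i": 3, "unused": 9}))
# ('bufA', 'bufB', 3, 7)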
2 changes: 1 addition & 1 deletion tinygrad/runtime/ops_metal.py
@@ -114,7 +114,7 @@ def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[RawBuffer],
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
-var_vals_keys = sorted(var_vals.keys())
+var_vals_keys = list(var_vals.keys())
for i,v in enumerate(prg.vars):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = prg.launch_dims(var_vals)
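In the Metal graph path, every variable value lives in one shared int32 buffer, and a kernel reads each of its variables at a byte offset equal to that variable's position in var_vals times 4. A small sketch of just the offset arithmetic (int_buf_offsets is a hypothetical helper):

def int_buf_offsets(var_vals, kernel_vars):
    keys = list(var_vals.keys())               # insertion order, as in the updated code
    return {v: keys.index(v) * 4 for v in kernel_vars}

var_vals = {"i": 3, "j": 7, "start_pos": 12}
print(int_buf_offsets(var_vals, ["j", "start_pos"]))  # {'j': 4, 'start_pos': 8}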
