LazyOp out of Linearizer (tinygrad#1908)
* loadop buffer on cpu

* works for GPU

* sort of working

* has bugs

* gpu tests pass

* fix some tests

* fix tensor cores

* fix test linearizer

* fix symbolic

* fix has_variable_shape

* non symbolic size

* disable weird test

* simple cache fix

* fix custom function

* fix kopt

* cleanups

* a bit broken on the assign

* contig check

* only buffer

* need that order

* idx

* dedup buffers

* hmm, bugfix

* fix tensor cores

* opts device
geohot committed Sep 24, 2023
1 parent 2201b46 commit 7ff7aac
Showing 15 changed files with 158 additions and 145 deletions.
19 changes: 8 additions & 11 deletions docs/abstractions.py
@@ -263,14 +263,13 @@ def linearize(self): pass
uops: List[UOp]

from tinygrad.tensor import Tensor
from tinygrad.helpers import prod
result = Tensor(2).realize() + Tensor(3).realize()
result.lazydata.realized = Device[Device.DEFAULT].buffer(prod(result.shape), result.dtype)

# use the real Linearizer to linearize 2+3
from tinygrad.lazy import _replace_loadops
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions
linearizer = Linearizer(result.lazydata.op, result.lazydata, LinearizerOptions())
op, _ = _replace_loadops(result.lazydata.op)
linearizer = Linearizer(op)
linearizer.linearize()

# print the uops
@@ -279,13 +278,11 @@
# output:
"""
0 UOps.DEFINE_GLOBAL : ptr.dtypes.float [] ('data0', dtypes.float)
1 UOps.LOOP : [] ([], 'global')
2 UOps.LOOP : [] ([], 'local')
3 UOps.CONST : dtypes.float [] 2.0
4 UOps.CONST : dtypes.float [] 3.0
5 UOps.ALU : dtypes.float [3, 4] BinaryOps.ADD
6 UOps.STORE : [5] MemOp(name='data0', idx=<0>, local=False, memory_dtype=dtypes.float, valid=<1>, invalid_value=0.0)
7 UOps.ENDLOOP : [] ([], 'global+local')
1 UOps.CONST : dtypes.float [] 2.0
2 UOps.CONST : dtypes.float [] 3.0
3 UOps.ALU : dtypes.float [1, 2] BinaryOps.ADD
4 UOps.CONST : dtypes.int [] 0
5 UOps.STORE : [0, 4, 3] None
"""

# %%
8 changes: 4 additions & 4 deletions test/models/test_real_world.py
@@ -14,17 +14,17 @@
from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
from examples.stable_diffusion import UNetModel

def kopt_search_hook(k, create_k, to_prg, baseline):
def kopt_search_hook(k, create_k, to_prg, baseline, bufs):
import nevergrad as ng
wanna_output = k.bufs[0].toCPU().copy()
wanna_output = bufs[0].toCPU().copy()
def check_opt(x):
try:
k = create_k()
k.process()
k.apply_auto_opt(x)
prg = to_prg(k)
first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
first_tm = prg.exec(bufs, force_wait=True, optimizing=True)
np.testing.assert_allclose(wanna_output, bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
return first_tm
except Exception:
return 10000_000 # 10000 seconds is infinity
2 changes: 1 addition & 1 deletion test/test_custom_function.py
@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret, a, b])
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
return ret.realized

def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
39 changes: 21 additions & 18 deletions test/test_linearizer.py
@@ -5,6 +5,7 @@
from tinygrad.ops import Compiled, Device, MovementOps, LazyOp
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
from tinygrad.lazy import _replace_loadops

class TestLinearizer(unittest.TestCase):
def test_arg_dedup(self):
@@ -30,7 +31,7 @@ def test_load_dedup(self):
r = a[:-1] + a[1:]
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@@ -48,7 +49,7 @@ def test_upcast_cse(self):
r = a.expand([2]) + b.expand([2])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@@ -63,7 +64,7 @@ def test_zero_fold(self):
r = Tensor.stack([a, b])
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.upcast()
k.linearize()
@@ -79,7 +80,7 @@ def test_constant_fold(self):
r = a * b
ast = r.lazydata.op
r = r.realize() # realize an output buffer
k = Linearizer(ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(_replace_loadops(ast)[0], Device[Device.DEFAULT].linearizer_opts)
k.process()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
@@ -88,12 +89,14 @@ def test_constant_fold(self):
def helper_linearizer_opt(r:Tensor, opts=[]):
wanna_output = None
realized_ast = None
real_bufs = None

# HACK to get real ast.
real_dev_exec_ast = Device[Device.DEFAULT].exec_ast
def fake_exec_ast(ast, output=None, **kwargs):
nonlocal realized_ast
x = real_dev_exec_ast(ast, output, **kwargs)
def fake_exec_ast(ast, output=None, inputs=None, **kwargs):
nonlocal realized_ast, real_bufs
x = real_dev_exec_ast(ast, output, inputs, **kwargs)
real_bufs = [output.realized] + inputs
if not(ast.op in MovementOps and ast.src[0].__class__ is not LazyOp and ast.src[0].realized): realized_ast = ast # get last executed
return x
Device[Device.DEFAULT].exec_ast = fake_exec_ast
@@ -106,26 +109,26 @@ def check_opt(x, create_k, to_prg):
k.process()
k.apply_auto_opt(x)
prg = to_prg(k)
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(k.bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(real_bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)

# Get baseline, which is not optimized at all.
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
k.process()
prg = Device[Device.DEFAULT].to_program(k)
prg.exec(k.bufs, force_wait=True)
wanna_output = k.bufs[0].toCPU().copy()
prg.exec(real_bufs, force_wait=True)
wanna_output = real_bufs[0].toCPU().copy()

# Check correctness of hand-coded optimizations.
k = Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts)
k = Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts)
k.hand_coded_optimizations()
prg = Device[Device.DEFAULT].to_program(k)
k.bufs[0].realized = k.bufs[0].realized.fromCPU(np.zeros(k.bufs[0].shape, dtype=k.bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(k.bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, k.bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
real_bufs[0] = real_bufs[0].fromCPU(np.zeros((real_bufs[0].size, ), dtype=real_bufs[0].dtype.np)) # Zero to check that all values are filled
prg.exec(real_bufs, force_wait=True)
np.testing.assert_allclose(wanna_output, real_bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
for x in opts: # Check custom transformations if any.
check_opt(x, lambda: Linearizer(realized_ast, r.lazydata, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)
check_opt(x, lambda: Linearizer(realized_ast, Device[Device.DEFAULT].linearizer_opts), Device[Device.DEFAULT].to_program)

class TestLinearizerOpts(unittest.TestCase):
def test_local_and_grouped_reduce(self):
15 changes: 5 additions & 10 deletions test/unit/test_flopcounter.py
@@ -1,18 +1,13 @@
#!/usr/bin/env python
import unittest
from typing import NamedTuple, Tuple
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info
from tinygrad.helpers import DType, dtypes

class TestBuffer(NamedTuple):
__test__ = False # To prevent pytest from collecting this as a test
shape: Tuple[int, ...]
dtype: DType
from tinygrad.ops import LazyOp, BinaryOps, ReduceOps, get_lazyop_info, LoadOps, MemBuffer
from tinygrad.shape.view import View
from tinygrad.helpers import dtypes

class TestFlopCounter(unittest.TestCase):
def setUp(self):
self.buf0 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf1 = TestBuffer(shape=(4,), dtype=dtypes.float32)
self.buf0 = LazyOp(LoadOps.BUFFER, (), MemBuffer(1, dtypes.float32, (View.create((4,)),)))
self.buf1 = LazyOp(LoadOps.BUFFER, (), MemBuffer(2, dtypes.float32, (View.create((4,)),)))

def test_flops_add(self):
op0 = LazyOp(BinaryOps.ADD, (self.buf0,self.buf1,), None)
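# (editor's sketch, assuming the usual FlopCounter semantics) an elementwise ADD over a
# (4,)-shaped buffer should count one flop per element:
#   info = get_lazyop_info(op0)
#   self.assertEqual(info.flops, 4)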
42 changes: 16 additions & 26 deletions tinygrad/codegen/kernel.py
@@ -1,12 +1,10 @@
from typing import NamedTuple, Optional, List, Tuple, cast, Dict
import itertools
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
from tinygrad.runtime.lib import buf_is_kernel_arg
from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, LoadOps, MemBuffer
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.shape.view import strides_for_shape, View

class LocalBuffer(NamedTuple):
name: str
@@ -16,6 +14,7 @@ class LocalBuffer(NamedTuple):
def __str__(self): return f"localbuffer<{self.name}[{self.size}]>"

class LinearizerOptions(NamedTuple):
device: str = ""
# TODO: make this generic with a list of supported types
supports_float4: bool = True
supports_float4_alu: bool = True
@@ -26,41 +25,32 @@ class LinearizerOptions(NamedTuple):
local_max: Optional[List[int]] = None

class Kernel:
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:Optional[LinearizerOptions]=None):
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
def __init__(self, ast:LazyOp, opts:Optional[LinearizerOptions]=None, var_vals=None):
self.opts = opts if opts else LinearizerOptions()

# get the output buffers
self.bufs = [output_buffer] + dedup(ast.buffers)
self.arg_bufs = {x:f"data{i}" for i,x in enumerate(dedup([x.realized for x in self.bufs if buf_is_kernel_arg(x)]))}

# key for lookup in cache (can change, str might not be right)
# bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
self.key = (ast.map_buffers({x:self.arg_bufs.get(x.realized,x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))
self.ast = ast
self.var_vals = var_vals
self.key = (ast, tuple(var_vals.keys())) if var_vals else ast
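# (editor's note, inferred from this commit's "dedup buffers" change) buffer identity is
# now baked into the AST itself as MemBuffer.idx, so f(x) = x + x (idx 1 used twice) and
# f(x, y) = x + y (idx 1 and idx 2) no longer share an AST, and the ast (plus the names
# of any symbolic variables) is enough to key the kernel cache.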

def process(self) -> None:
if hasattr(self, "sts"): return # already processed

# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())

# there's only allowed to be one reduceop
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None

# get earlybufs, before the one reduce op
self.earlybufs = dedup(self.reduceop.buffers) if self.reduceop else []

# create new shapetrackers inside this kernel, we will permute them
self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
self.bufs = [MemBuffer(0, self.info.dtype, (View.create(self.info.shape),))] + dedup([x.arg for x in self.ast.get_lazyops() if x.op in LoadOps])
self.sts: List[ShapeTracker] = [ShapeTracker(x.views[-1].shape, views=list(x.views)) for x in self.bufs]
for st in self.sts: st.simplify()

# make the output buffer shape correct in here
self.sts[0].reshape(self.info.shape)
self.mem_estimate: int = sum(x.dtype.itemsize*x.views[-1].size() for x in self.bufs)

# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.get_lazyops() if x.op in LoadOps] if self.reduceop else []
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

# parameters
@@ -77,7 +67,7 @@ def process(self) -> None:

def has_variable_shape(self) -> bool:
for b in self.bufs:
if any(not isinstance(x, int) for x in b.st.shape): return True
if not all_int(b.views[-1].shape): return True
return False

def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
@@ -147,6 +137,6 @@ def colors(self) -> List[str]:
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
def printbufs(self, prefix=""):
for i,st in enumerate(self.sts):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", st.views)
print(prefix, f"{i:3d} {str(self.bufs[i]):47s}", st.views)
print(self.colored_shape())

35 changes: 14 additions & 21 deletions tinygrad/codegen/linearizer.py
@@ -5,9 +5,8 @@
from enum import Enum, auto

from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, prod, PtrDType, all_same
from tinygrad.ops import LazyOp, UnaryOps
from tinygrad.ops import LazyOp, UnaryOps, LoadOps, ConstBuffer, MemBuffer
from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename
from tinygrad.codegen.optimizer import OptimizedKernel
@@ -128,11 +127,6 @@ def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]

class Linearizer(OptimizedKernel):
def get_buffer_name(self, i):
if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops
return self.arg_bufs[self.bufs[i].realized]

def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32):
render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx))
return self.uop(UOps.ALU, dtype, (a, render_b), op)
@@ -147,7 +141,7 @@ def const(self, b:Union[int,float], dtype=dtypes.int32) -> UOp: return self.uop(
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) }

def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[UOp]:
const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc
const = self.bufs[i].val if isinstance(self.bufs[i], ConstBuffer) else acc

expanded_nodes = [idx.expand() for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
@@ -176,7 +170,7 @@ def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[UOp
idx, valid = g_idx.substitute(substitute), g_valid.substitute(substitute)
localtype = dtypes.float32
this_const, idx, valid = (invalid_value, Variable.num(0), Variable.num(1)) if valid.max == 0 else (const, idx, valid)
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else self.get_buffer_name(i)}{idx.render()}{valid.render()}"
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (self.bufs[i].idx if isinstance(self.bufs[i], MemBuffer) else self.bufs[i].name)}{idx.render()}{valid.render()}"
if key not in self.load_cache:
if acc is not None:
assert valid.min == 1
@@ -253,15 +247,13 @@ def linearize(self):
self.loop_uops: Dict[str, UOp] = {}

# add global buffers
arg_bufs = {}
for buf,name in self.arg_bufs.items():
arg_bufs[buf] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (name, buf.dtype))
for i,b in enumerate(self.bufs):
if b.realized in arg_bufs: self.buf_uops[i] = arg_bufs[b.realized]
# add variables from symbolic shapes
for var in sorted(set(v for buf in self.ast.buffers for v in buf.var_vals), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
for i,buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype) if not isinstance(buf.dtype, ImageDType) else buf.dtype, (), (f"data{buf.idx}", buf.dtype))
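# (editor's note) e.g. for the 2+3 kernel in docs/abstractions.py only the output
# MemBuffer (idx=0) remains, so a single DEFINE_GLOBAL ('data0', dtypes.float) is
# emitted; ConstBuffers never become kernel arguments.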
if self.var_vals:
for var in sorted(set(self.var_vals), key=lambda k: k.key):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32))
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size()))
@@ -368,7 +360,7 @@ def end_loop(xx:List[Variable]):

# copy in any global buffers
if self.use_tensor_cores:
if self.bufs[0].device == "METAL":
if self.opts.device == "METAL":
if 2 * len(acc) == len(locals_to_store[0][2]) * len(locals_to_store[1][2]):
i = 0
for y0,y1 in zip(locals_to_store[1][2][::2], locals_to_store[1][2][1::2]):
@@ -380,7 +372,7 @@ def end_loop(xx:List[Variable]):
for i in range(0, len(acc), 2):
for y0,y1,x0,x1 in zip(locals_to_store[1][2][:k], locals_to_store[1][2][k:], locals_to_store[0][2][k*i:], locals_to_store[0][2][k*i+k:]):
self.uop(UOps.WMMA, None, (x0, x1, y0, y1, acc[i], acc[i+1]), "METAL")
elif self.bufs[0].device == "HIP":
elif self.opts.device == "HIP":
i = 0
for y in range(0, len(locals_to_store[1][2]), 0x10):
for x in range(0, len(locals_to_store[0][2]), 0x10):
@@ -491,7 +483,7 @@ def uop(self, uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None
return self.uops[-1]

def ast_parse(self, x, acc, loaded_buffers, do_reduce=False) -> List[UOp]:
if x.__class__ is not LazyOp: return loaded_buffers[x]
if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
if x.op in [LoadOps.BUFFER, LoadOps.CONST]: return loaded_buffers[x.arg]
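# (editor's note) loaded_buffers is now keyed by the MemBuffer/ConstBuffer arg of the
# load op rather than by the LazyBuffer itself, hence the x.arg lookup above.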
if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers) # cast isn't an ALU op
if x.op in ReduceOps and not do_reduce: return acc
# MULACC fusion. TODO: this is copied from Interpreted
