fix backends for new style (tinygrad#1443)
* fix backends for new style

* fix method cache

* fix fakeless

* llvm blacklist

* fix kernel optimizer
geohot committed Aug 5, 2023
1 parent 67781fc commit 84c4303
Showing 18 changed files with 136 additions and 167 deletions.
9 changes: 4 additions & 5 deletions docs/abstractions.py
@@ -217,7 +217,7 @@ def toCPU(self): return self._buf
# ClangProgram is the simplest runtime (in tinygrad/runtime/ops_clang.py, code 7/10)
# __init__ calls clang, and __call__ calls the function in the *.so outputted by clang
# in CLANG, global_size and local_size are ignored
from tinygrad.runtime.ops_clang import ClangProgram, ClangCodegen
from tinygrad.runtime.ops_clang import ClangProgram

# a concrete example looks like this, this adds two size 1 RawBuffer
# first we create two numpy buffers containing 2 and 3
@@ -229,7 +229,7 @@ def toCPU(self): return self._buf
output = RawMallocBuffer(1, dtypes.float32)

# compile the program, run it, and 2+3 does indeed equal 5
program = ClangProgram("add", f"{ClangCodegen.lang.kernel_prefix} void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
program = ClangProgram("add", f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size
print(output.toCPU())
assert output.toCPU()[0] == 5, "it's still 5"
@@ -270,9 +270,8 @@ def linearize(self): pass
result.lazydata.realized = Device[Device.DEFAULT].buffer(prod(result.shape), result.dtype)

# use the real Linearizer to linearize 2+3
from tinygrad.codegen.linearizer import Linearizer
linearizer = Linearizer(result.lazydata.op, result.lazydata)
linearizer.process()
from tinygrad.codegen.linearizer import Linearizer, LinearizerOptions
linearizer = Linearizer(result.lazydata.op, result.lazydata, LinearizerOptions())
linearizer.linearize()

# print the uops
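Tying the two hunks above together under the new style: once linearize() has produced uops, a plain renderer function (rather than a Codegen class) turns them into C source that ClangProgram can compile. A minimal sketch, assuming the default CStyleLanguage() settings are acceptable for CLANG (the actual language settings in ops_clang.py are not shown in this diff):

from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

# render the linearized uops to C source; CLANG ignores global/local sizes
src, global_size, local_size = uops_to_cstyle(CStyleLanguage(), linearizer.function_name, linearizer.uops)
prg = ClangProgram(linearizer.function_name, src)  # compiles with clang and loads the resulting .so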
19 changes: 8 additions & 11 deletions test/external/external_test_speed_llama.py
@@ -7,16 +7,15 @@
from tinygrad.state import get_state_dict
from tinygrad.ops import Compiled

class FakeProgram:
def __init__(self, name:str, prg:str): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass

class TestLLaMASpeed(unittest.TestCase):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
def test_llama_compile(self):
# TODO: with default device
old_default = Device.DEFAULT
Device.DEFAULT = "FAKE"

# use the codegen from the real device
Device['fake'].codegen = Device[old_default].codegen
print("using", Device['fake'].codegen)
backup_program = Device[Device.DEFAULT].runtime
Device[Device.DEFAULT].runtime = FakeProgram

print("testing llama python run time")
model = Transformer(**MODEL_PARAMS[1]["7B"]["args"])
@@ -26,8 +25,7 @@ def test_llama_compile(self):
print("assigned empty tensors, doing warmup")

def run_llama(st, empty_method_cache=True):
#print(f"clearing {len(Device['fake'].method_cache)} from method cache")
if empty_method_cache: Device['fake'].method_cache.clear()
if empty_method_cache: Device[Device.DEFAULT].method_cache.clear()
tms = [time.perf_counter()]
for i in range(10):
model(Tensor([[2]]), i).realize()
@@ -42,8 +40,7 @@ def run_llama(st, empty_method_cache=True):
run_llama("profile")
stop_profile(pr, sort='time', frac=0.1)

# reset device
Device.DEFAULT = old_default
Device[Device.DEFAULT].runtime = backup_program

if __name__ == '__main__':
unittest.main()
28 changes: 19 additions & 9 deletions tinygrad/codegen/linearizer.py
@@ -134,13 +134,19 @@ class UOp(NamedTuple):
arg: Any
def __repr__(self): return f"{str(self.uop):20s}: {str(self.out) if self.out is not None else '':25s} {str(self.vin):32s} {self.arg}"

class Linearizer:
supports_float4: bool = False
supports_float4_alu: bool = False
class LinearizerOptions(NamedTuple):
# TODO: make this generic with a list of supported types
supports_float4: bool = True
supports_float4_alu: bool = True
has_local: bool = True
global_max: Optional[List[int]] = None
local_max: Optional[List[int]] = None

def __init__(self, ast:LazyOp, output_buffer:LazyBuffer):
class Linearizer:
def __init__(self, ast:LazyOp, output_buffer:LazyBuffer, opts:LinearizerOptions):
# NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast
self.opts = opts

# get the output buffers
self.bufs = [output_buffer] + dedup(ast.buffers)
@@ -151,8 +157,6 @@ def __init__(self, ast:LazyOp, output_buffer:LazyBuffer):
# mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
self.key = (ast.map_buffers({x:(self.arg_bufs[x.realized] if x.realized in self.arg_bufs else x) for x in self.bufs}).key, tuple([x.key for x in self.bufs]))

def codegen(self): raise NotImplementedError("must be implemented")

def get_buffer_name(self, i):
if self.bufs[i].__class__ == LocalBuffer: return self.bufs[i].name
assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops
@@ -217,7 +221,7 @@ def acc_offsets(self, i):
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]

def get_upcast_dim(self, i) -> List[int]:
should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
should_upcast = self.opts.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]

def global_load(self, i:int, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
@@ -283,6 +287,11 @@ def global_store(self, i, idxs:List[VariableOrNum], store:List[Token], ssa) -> N

kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self):
self.process()

# limit dims if we need to
if self.opts.global_max and self.opts.local_max: self.limit_global_dims(3, self.opts.global_max, self.opts.local_max)

# uops
self.uops: List[UOp] = []
self.saved_exprs: Dict[LazyOp, List[Token]] = dict()
@@ -458,6 +467,7 @@ def ssa(name, ltype=dtypes.float) -> Token:
Linearizer.kernel_cnt[self.function_name] += 1
suffix = f"{'n'+str(Linearizer.kernel_cnt[self.function_name]-1)}" if Linearizer.kernel_cnt[self.function_name] > 1 else ""
self.function_name, self.display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
return self

_OT = TypeVar("_OT")
def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
@@ -482,9 +492,9 @@ def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]
values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
if x.op in ops:
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), ops[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.opts.supports_float4_alu)]
else:
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op not in {BinaryOps.CMPEQ, TernaryOps.WHERE})]
ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(x.dtype == dtypes._float4 and x.offset is None for x in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.opts.supports_float4_alu and x.op not in {BinaryOps.CMPEQ, TernaryOps.WHERE})]
ordered_ret: List[Optional[Token]] = [None]*len(values[0])
# scatter
for i,j in ret:
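LinearizerOptions replaces the per-Codegen class attributes (supports_float4, and the hasattr(k, 'lang') checks in the optimizer) with a NamedTuple the backend passes in. A hedged sketch of two plausible configurations; apart from the defaults shown above, the concrete values here are assumptions rather than what any real backend necessarily uses:

from tinygrad.codegen.linearizer import Linearizer, LinearizerOptions

# GPU-style backend: keep the float4/local defaults, cap launch dimensions (illustrative numbers)
gpu_opts = LinearizerOptions(global_max=[65536, 65536, 65536], local_max=[256, 256, 64])

# CPU-style backend (e.g. Clang/LLVM): no local dimensions, no float4 vectorization
cpu_opts = LinearizerOptions(supports_float4=False, supports_float4_alu=False, has_local=False)

k = Linearizer(ast, output_buffer, cpu_opts)  # ast/output_buffer as in docs/abstractions.py above
k.linearize()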
22 changes: 12 additions & 10 deletions tinygrad/codegen/optimizer.py
@@ -1,10 +1,12 @@
from typing import Callable
import itertools, time
import itertools
from tinygrad.helpers import DEBUG, prod, getenv, ImageDType
from tinygrad.ops import ReduceOps, BinaryOps, LazyOp
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import LazyBuffer

# auto opt is disabled
import time
from typing import Callable
def apply_opt(k, x):
for axis, amt, typ in x:
if axis is None or amt == 1: continue
@@ -22,14 +24,14 @@ def apply_opt(k, x):

UPCASTS = [1,2,3,4,5,6,7,8]
LOCALS = [1,2,3,4,5,6,7,8,16,24,32]
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], runtime, baseline):
def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline):
import nevergrad as ng
def opt(x):
try:
k = create_k()
k.process()
apply_opt(k, x)
prg = k.codegen().build(runtime)
prg = to_prg(k)
first_tm = prg.exec(k.bufs, force_wait=True, optimizing=True)
if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
tm = min([first_tm]+[prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(2)])*1000
@@ -58,7 +60,7 @@ def opt(x):

# optimization
global_db = None
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg):
global global_db

k.process()
@@ -75,9 +77,9 @@ def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], runtime):
def get_baseline():
k = create_k()
hand_coded_optimizations(k)
prg = k.codegen().build(runtime)
prg = to_prg(k)
return min([prg.exec(k.bufs, force_wait=True, optimizing=True) for _ in range(5)])*1000
choice = kernel_optimize_search(k, create_k, runtime, get_baseline())
choice = kernel_optimize_search(k, create_k, to_prg, get_baseline())
if global_db is not None:
global_db[skey] = choice
global_db.sync()
@@ -108,7 +110,7 @@ def hand_coded_optimizations(k:Linearizer):
tensor_cores_allowed = getenv("TC", 1) != 0 and (getenv("TC", 1) == 2 or (k.bufs[0].device == "METAL" and getenv("CI", "") != "true"))
if tensor_cores_allowed and k.reduceop and k.reduceop.op == ReduceOps.SUM and \
isinstance(k.reduceop.src[0], LazyOp) and k.reduceop.src[0].op == BinaryOps.MUL and \
isinstance(k.reduceop.src[0].src[0], LazyBuffer) and isinstance(k.reduceop.src[0].src[1], LazyBuffer) and hasattr(k, 'lang') and len(k.lang.lid):
isinstance(k.reduceop.src[0].src[0], LazyBuffer) and isinstance(k.reduceop.src[0].src[1], LazyBuffer) and k.opts.has_local:
buf0 = k.bufs.index(k.reduceop.src[0].src[0])
buf1 = k.bufs.index(k.reduceop.src[0].src[1])
buf0_strides = k.sts[buf0].real_strides()
@@ -160,7 +162,7 @@ def hand_coded_optimizations(k:Linearizer):
# early exit
return

if hasattr(k, 'lang') and len(k.lang.smem_prefix):
if k.opts.has_local:
# are we grouping? (requires local shape support)
if not k.float4_axis(0) and k.first_reduce <= 2 and k.first_reduce + 1 <= k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
@@ -237,7 +239,7 @@ def hand_coded_optimizations(k:Linearizer):

# **** local groups ****

if hasattr(k, 'lang') and len(k.lang.lid):
if k.opts.has_local:
for axis in range(k.first_reduce - k.local_dims - 1, -1, -1):
local_size = prod(k.full_shape[k.first_reduce-k.local_dims:k.first_reduce])
if k.full_shape[axis] == 1: continue
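kernel_optimize and kernel_optimize_search now receive a to_prg callable instead of a raw runtime, so the optimizer builds candidate programs the same way the device does. A minimal sketch of the contract such a callable has to satisfy (it mirrors Compiled.to_program in tinygrad/ops.py below; renderer and runtime stand in for whatever the device provides):

from tinygrad.ops import ASTRunner

def to_prg(k):  # k is a Linearizer
    k.linearize()                                                     # produce uops and a function name
    src, global_size, local_size = renderer(k.function_name, k.uops)  # device-specific renderer
    return ASTRunner(k.function_name, src, global_size, local_size).build(runtime)  # device runtime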
2 changes: 1 addition & 1 deletion tinygrad/lazy.py
@@ -367,7 +367,7 @@ def _realize_rand(buffer: LazyBuffer) -> None:
buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=buffer.shape, dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args()) # type: ignore

def _realize_const(buffer: LazyBuffer) -> None:
if hasattr(Device[buffer.device].codegen, 'supports_constant_folding'):
if isinstance(Device[buffer.device], Compiled) and buffer.device not in ["LLVM"]: # consts are broken in LLVM in NaN/inf
buffer.realized = RawConst(1, buffer.dtype, float(buffer.op.arg))
else:
buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
34 changes: 20 additions & 14 deletions tinygrad/ops.py
@@ -152,9 +152,16 @@ def __call__(self, rawbufs:List[RawBuffer], jit=False, force_wait=False) -> Opti
return et

class Compiled:
def __init__(self, buffer: Type[RawBuffer], codegen, runtime, synchronize=lambda: None):
self.buffer, self.codegen, self.runtime, self.synchronize = buffer, codegen, runtime, synchronize
self.method_cache: Dict[str, ASTRunner] = {}
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None):
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, runtime, synchronize
self.method_cache: Dict[Any, ASTRunner] = {}

def to_program(self, k):
k.linearize()
src, global_size, local_size = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, global_size, local_size,
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name).build(self.runtime)

def exec_ast(self, ast:LazyOp, output, **kwargs):
# all movementops do nothing in a Compiled buffer!
@@ -175,22 +182,21 @@ def exec_ast(self, ast:LazyOp, output, **kwargs):
if not output.realized:
output.realized = self.buffer(prod(output.shape), output.dtype, **kwargs)

from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, output, self.linearizer_opts)

# compilation time
k = self.codegen(ast, output)
def get_program():
from tinygrad.codegen.optimizer import kernel_optimize, hand_coded_optimizations
if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, output, self.linearizer_opts), self.to_program)
elif not getenv("NOOPT"): hand_coded_optimizations(k)
return self.to_program(k)

# this is the default now
if hasattr(k, 'key') and getenv("ENABLE_METHOD_CACHE", 1):
from tinygrad.codegen.optimizer import kernel_optimize, hand_coded_optimizations
if k.key not in self.method_cache:
if getenv("KOPT"):
kernel_optimize(k, lambda: self.codegen(ast, output), self.runtime)
elif not getenv("NOOPT"):
hand_coded_optimizations(k)
self.method_cache[k.key] = k.codegen().build(self.runtime)
elif DEBUG >= 5: print(f"method cache hit : {k.key}")
if k.key not in self.method_cache: self.method_cache[k.key] = get_program()
prg = self.method_cache[k.key]
else:
prg = k.codegen().build(self.runtime)
prg = get_program()

if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)

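The hunks above are the core of the "new style": a backend no longer subclasses a Codegen, it hands Compiled a RawBuffer type, a LinearizerOptions, a renderer function, and a runtime. A hedged sketch of how a CPU backend might be registered under this scheme; the buffer class, option flags, and import paths are assumptions, since the updated ops_clang.py is not part of the diff shown here:

import functools
from tinygrad.ops import Compiled
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
from tinygrad.runtime.lib import RawMallocBuffer      # assumed location of RawMallocBuffer
from tinygrad.runtime.ops_clang import ClangProgram

ClangDevice = Compiled(
  RawMallocBuffer,                                             # how raw device buffers get allocated
  LinearizerOptions(supports_float4=False, has_local=False),   # assumed flags for a CPU target
  functools.partial(uops_to_cstyle, CStyleLanguage()),         # renderer(function_name, uops) -> (src, global_size, local_size)
  ClangProgram)                                                # runtime: __init__ compiles, __call__ runs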
22 changes: 4 additions & 18 deletions tinygrad/codegen/cstyle.py → tinygrad/renderer/cstyle.py
@@ -1,7 +1,7 @@
from typing import Dict, ClassVar, List, Optional, NamedTuple, Tuple, Union
from typing import Dict, List, Optional, NamedTuple, Tuple, Union
import math
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, MemOp, ConstOp
from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, TernaryOps
from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable

@@ -110,7 +110,7 @@ def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:Li
local_size.append(var.max+1)
return "{" if isinstance(var, NumNode) else f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"

def uops_to_cstyle(function_name:str, uops:List[UOp], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tuple[str, List[int], List[int]]:
global_size: List[int] = []
local_size: List[int] = []
kernel,prekernel = [],[]
@@ -183,17 +183,3 @@ def kk(s): kernel.append(" "*depth+s)
raise RuntimeError(f"failed to render {uop}")

return lang.render_kernel(function_name, kernel, bufs, global_size, local_size, prekernel)

class CStyleCodegen(Linearizer):
lang: ClassVar[CStyleLanguage] = CStyleLanguage()
supports_constant_folding: bool = True
supports_float4: bool = True
supports_float4_alu: bool = True

def codegen(self):
self.process()
if self.lang.global_max: self.limit_global_dims(len(self.lang.gid), self.lang.global_max, self.lang.local_max) # NOTE: this is optional now
self.linearize()

return ASTRunner(self.function_name, *uops_to_cstyle(self.function_name, self.uops, self.lang),
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
15 changes: 4 additions & 11 deletions tinygrad/codegen/llvmir.py → tinygrad/renderer/llvmir.py
@@ -1,9 +1,9 @@
from typing import Final, Dict, Callable, Any, List, Optional, Tuple
import functools
from llvmlite import ir # type: ignore
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, Token, MemOp, ConstOp
from tinygrad.codegen.linearizer import UOps, UOp, Token, MemOp, ConstOp
from tinygrad.helpers import dtypes
from tinygrad.ops import Op, ASTRunner, UnaryOps, BinaryOps, TernaryOps
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps

from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
def int_const(x): return ir.Constant(ir.IntType(64), x)
@@ -32,7 +32,7 @@ def int_const(x): return ir.Constant(ir.IntType(64), x)
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}

def uops_to_llvm_ir(uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)

@@ -43,7 +43,7 @@ def uops_to_llvm_ir(uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[
# create llvm function
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name='exec')
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)

# force llvmlite to allow us to add function attribute then add the attribute
func.attributes._known = func.attributes._known.union(frozenset(['"no-nans-fp-math"="true"']))
@@ -132,10 +132,3 @@ def uops_to_llvm_ir(uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[

bb[-1].ret_void()
return str(module), None, None

class LLVMIRCodegen(Linearizer):
def codegen(self):
self.process()
# no optimize, this doesn't support local
self.linearize()
return ASTRunner('exec', *uops_to_llvm_ir(self.uops), op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=self.display_name)
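Since uops_to_llvm_ir now takes the function name, its signature already matches what Compiled.to_program expects of self.renderer, so the LLVM backend can pass the function in directly without a partial. A sketch under the same caveats as the Clang example above (LLVMProgram and the buffer class are assumed names; note also that the lazy.py hunk above excludes LLVM from constant folding):

from tinygrad.ops import Compiled
from tinygrad.codegen.linearizer import LinearizerOptions
from tinygrad.renderer.llvmir import uops_to_llvm_ir
from tinygrad.runtime.lib import RawMallocBuffer      # assumed
from tinygrad.runtime.ops_llvm import LLVMProgram     # assumed

LLVMDevice = Compiled(RawMallocBuffer,
                      LinearizerOptions(supports_float4=False, has_local=False),
                      uops_to_llvm_ir, LLVMProgram)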
4 changes: 2 additions & 2 deletions tinygrad/codegen/wgsl.py → tinygrad/renderer/wgsl.py
@@ -1,6 +1,6 @@
from tinygrad.codegen.cstyle import render_cl
from tinygrad.renderer.cstyle import render_cl
from tinygrad.helpers import dtypes, DType
from tinygrad.codegen.cstyle import CStyleLanguage
from tinygrad.renderer.cstyle import CStyleLanguage
from typing import List, Union
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
import math