move all to compile api (tinygrad#2203)
* move metal+clang to compile api

* all to the new style

* remove binary arg

* fix triton

* fixup tests

* fix clang

* diskcache is generic

* __wrapped__

* compile_gpu

* fix thneed

* keep the src in the ASTRunner

* lib

* move compile_gpu

* compile_gpu in device

* put compiler in astrunner

* test reverts

* triton compiler

* ugh, that too
geohot committed Nov 2, 2023
1 parent 8932816 commit 03cf0af
Showing 18 changed files with 126 additions and 134 deletions.
4 changes: 2 additions & 2 deletions docs/abstractions.py
@@ -217,7 +217,7 @@ def toCPU(self): return self._buf
# ClangProgram is the simplest runtime (in tinygrad/runtime/ops_clang.py, code 7/10)
# __init__ calls clang, and __call__ calls the function in the *.so outputted by clang
# in CLANG, global_size and local_size are ignored
from tinygrad.runtime.ops_clang import ClangProgram
from tinygrad.runtime.ops_clang import ClangProgram, compile_clang

# a concrete example looks like this, this adds two size 1 RawBuffer
# first we create two numpy buffers containing 2 and 3
@@ -229,7 +229,7 @@ def toCPU(self): return self._buf
output = RawMallocBuffer(1, dtypes.float32)

# compile the program, run it, and 2+3 does indeed equal 5
program = ClangProgram("add", f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}")
program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size
print(output.toCPU())
assert output.toCPU()[0] == 5, "it's still 5"
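
Note the shape change in this example: the old ClangProgram compiled its source argument inside __init__, while now compile_clang returns the shared-object bytes (cached on disk via the diskcache decorator added in tinygrad/helpers.py below) and ClangProgram only loads them. A small sketch of what the split buys, reusing the imports from the snippet above:

lib = compile_clang("void add(float *a, float *b, float *c) { *a = *b + *c; }")  # str -> bytes of a .so, disk-cached
prg1 = ClangProgram("add", lib)  # loading is decoupled from compiling,
prg2 = ClangProgram("add", lib)  # so one compile result can back several program instances
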
20 changes: 5 additions & 15 deletions extra/thneed.py
@@ -4,7 +4,7 @@
import json
import traceback
import numpy as np
from tinygrad.runtime.ops_gpu import CLProgram
from tinygrad.runtime.ops_gpu import CLProgram, compile_gpu
from tinygrad.helpers import DEBUG, getenv
from collections import defaultdict
import pyopencl as cl
@@ -104,21 +104,11 @@ def load(self, input_fn):
if 'data' in o:
self.buffers_to_save.add(buf)

# load in the programs (this isn't used)
prgs = {}
for k,v in jdat['programs'].items():
print("building", k)
try:
prgs[k] = CLProgram(k, v, rename=False)
except Exception:
print("FAILED", k)
traceback.print_exc()
exit(0)

# load binaries
prgs = {}
for o in jdat['binaries']:
nptr = ptr + o['length']
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr], binary=True)
prgs[o['name']] = CLProgram(o['name'], weights[ptr:nptr])
ptr = nptr

# populate the cl_cache
Expand Down Expand Up @@ -208,15 +198,15 @@ def save(self, output_fn):
# zero out the buffer
cl.enqueue_copy(CL.cl_queue[0], buf, b'\x00'*buf.size, is_blocking=True)

CLProgram("from_image_strided", """
CLProgram("from_image_strided", compile_gpu("""
__kernel void from_image_strided(read_only image2d_t in, __global float4 *out, int row_pitch) {
const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
int2 l;
l.y = get_global_id(1);
l.x = get_global_id(0);
out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
}
""", argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
"""), argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))

# multiple of 32 isn't enough
jdat['objects'].append({
2 changes: 1 addition & 1 deletion test/external/external_test_speed_llama.py
@@ -11,7 +11,7 @@
from tinygrad.runtime.lib import RawBuffer

class FakeProgram:
def __init__(self, name:str, prg:str, binary:bool): pass
def __init__(self, name:str, prg:str): pass
def __call__(self, global_size, local_size, *bufs, wait=False): pass

class RawFakeBuffer(RawBuffer):
2 changes: 1 addition & 1 deletion test/test_custom_function.py
@@ -24,7 +24,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
}""", global_size=[prod(ret.shape)]).build(Device[ret.device].compiler, Device[ret.device].runtime).exec([ret.realized, a.realized, b.realized])
return ret.realized

def atan2_cpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
34 changes: 16 additions & 18 deletions test/test_kernel_cache.py
@@ -2,36 +2,34 @@
import unittest
import secrets
import string
import tempfile
import pathlib
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import cache_compiled
import tinygrad.runtime.ops_clang
from tinygrad.helpers import diskcache

def generate_random_string(length=16):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for _ in range(length))

class TestKernelCache(unittest.TestCase):
compile_call_count = 0
compile_call_count = 0

@cache_compiled
def __helper_test_compile(self, prg, output_file=pathlib.Path(tempfile.mktemp()), **kwargs):
self.compile_call_count += 1
return prg.encode()
@diskcache
def helper_test_compile(prg:str) -> bytes:
global compile_call_count
compile_call_count += 1
return prg.encode()

class TestKernelCache(unittest.TestCase):
def test_compile_cache(self):
prg1 = generate_random_string(64) + "a"
prg2 = generate_random_string(64) + "b"
cold_compile_res = self.__helper_test_compile(prg1)
warm_compile_res = self.__helper_test_compile(prg1)
cold_compile_res = helper_test_compile(prg1)
warm_compile_res = helper_test_compile(prg1)
assert cold_compile_res == warm_compile_res == prg1.encode()
assert self.compile_call_count == 1
assert compile_call_count == 1

prg2_res = self.__helper_test_compile(prg2)
prg2_res = helper_test_compile(prg2)
assert prg2_res == prg2.encode()
assert self.compile_call_count == 2
assert compile_call_count == 2

def test_kernel_cache_in_action(self):
if Device.DEFAULT not in ["CLANG"]:
@@ -42,15 +40,15 @@ def test_kernel_cache_in_action(self):
x = a + b
x.realize()

orig_compile_func = tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile
tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = None # making it not callable
orig_compile_func = Device['CLANG'].compiler
Device['CLANG'].compiler = None # making it not callable

a1 = Tensor.rand(4,4)
b1 = Tensor.rand(4,4)
x1 = a1 + b1
x1.realize() # Same kernel should be from cache.

tinygrad.runtime.ops_clang.ClangBuffer.runtime.compile = orig_compile_func
Device['CLANG'].compiler = orig_compile_func

if __name__ == "__main__":
unittest.main()
2 changes: 1 addition & 1 deletion test/test_uops.py
@@ -8,7 +8,7 @@

def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].runtime)
return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)

def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
20 changes: 9 additions & 11 deletions tinygrad/helpers.py
@@ -153,22 +153,12 @@ class GlobalCounters:
@staticmethod
def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count = 0,0,0.0,0

# *** compiled cache decorator ***

def cache_compiled(func):
if getenv("DISABLE_COMPILER_CACHE"): return func
def wrapper(self, prg:str, *args, **kwargs) -> bytes:
table, key = f"compiler_cache_{type(self).__name__}", hashlib.sha256(prg.encode()).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
return diskcache_put(table, key, func(self, prg, *args, **kwargs))
return wrapper

# *** universal database cache ***

CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
CACHELEVEL = getenv("CACHELEVEL", 2)

VERSION = 5
VERSION = 6
_db_connection = None
def db_connection():
global _db_connection
@@ -207,3 +197,11 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
conn.commit()
cur.close()
return val

def diskcache(func):
def wrapper(*args, **kwargs) -> bytes:
table, key = f"cache_{func.__name__}", hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
if (ret:=diskcache_get(table, key)): return ret
return diskcache_put(table, key, func(*args, **kwargs))
setattr(wrapper, "__wrapped__", func)
return wrapper
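
The new diskcache decorator generalizes the old cache_compiled: it keys on the pickled (args, kwargs), stores the returned bytes in the sqlite cache at CACHEDB, and exposes the undecorated function as __wrapped__ so ASTRunner.build can bypass the cache when DISABLE_COMPILER_CACHE is set. A toy sketch of the behavior (compile_dummy is a hypothetical stand-in for compile_clang / compile_gpu / compile_cuda):

from tinygrad.helpers import diskcache

@diskcache
def compile_dummy(src: str) -> bytes:
    print("compiling", src)   # only runs on a cache miss
    return src.encode()

compile_dummy("a+b")              # cache miss: runs the body and stores the bytes on disk
compile_dummy("a+b")              # cache hit: returns the stored bytes without printing
compile_dummy.__wrapped__("a+b")  # the raw function, always recompiles
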
15 changes: 8 additions & 7 deletions tinygrad/ops.py
@@ -194,8 +194,8 @@ def update_node(self, instid, jcid, prg, pargs, variables, updated_args=None): r
def exec_instance(self, instid): raise NotImplementedError("must be implemented")

class ASTRunner:
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
def __init__(self, name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4: print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}

def optimize_local_size(self, global_size:List[int], rawbufs:List[RawBuffer]) -> List[int]:
@@ -211,8 +211,9 @@ def try_exec(local_size):
return float('inf')
return min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]

def build(self, runtime, batch_exec=BasicBatchExecutor):
self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
def build(self, compiler, runtime, batch_exec=BasicBatchExecutor):
self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
self.clprg, self.batch_exec = runtime(self.name, self.lib, **self.runtime_args), batch_exec
return self

def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
@@ -243,16 +244,16 @@ def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int
return et

class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, runtime, synchronize, batch_exec
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_exec=BasicBatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_exec = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_exec
self.method_cache: Dict[LazyOp, ASTRunner] = {}

def to_program(self, k):
k.linearize()
src, runtime_args = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
display_name=k.display_name, runtime_args=runtime_args).build(self.runtime, self.batch_exec)
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime, self.batch_exec)

def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
# check if we can reuse the output buffer
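
With the compiler threaded through Compiled and ASTRunner.build, the same (compiler, runtime) pair the scheduler uses is reachable by hand. A minimal sketch against the CLANG backend (assumes clang is installed; buffer setup is omitted):

from tinygrad.ops import Device

clang = Device['CLANG']                                         # a Compiled instance
lib = clang.compiler("void setone(float *a) { a[0] = 1.0f; }")  # compile_clang: source -> .so bytes, disk-cached
prg = clang.runtime("setone", lib)                              # ClangProgram: load the bytes, resolve the symbol
# prg(None, None, some_raw_malloc_buffer) would run it; CLANG ignores global_size/local_size
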
2 changes: 1 addition & 1 deletion tinygrad/renderer/cstyle.py
@@ -209,4 +209,4 @@ def ssa(u, prefix="t"):
else:
raise RuntimeError(f"failed to render {uop}")

return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {"binary":False}
return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}
2 changes: 1 addition & 1 deletion tinygrad/renderer/llvmir.py
@@ -144,4 +144,4 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Dict]:
lvars[u] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])

bb[-1].ret_void()
return str(module), {"binary":False}
return str(module), {}
5 changes: 3 additions & 2 deletions tinygrad/renderer/triton.py
@@ -118,12 +118,13 @@ def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, f
for x in local_size: acc_local_size *= next_power_of_2(x)
local_size = [acc_local_size] + [1] * (len(local_size) - 1)

if DEBUG >=4: print(prg)
if DEBUG >= 4: print(prg)
getlines = linecache.getlines
linecache.getlines = lambda filename, module_globals=None: prg.splitlines(keepends=True) if "<triton>" == filename else getlines(filename, module_globals)
exec(compile(prg, "<triton>", "exec"), globals()) # pylint: disable=W0122\
compiled = triton_compile(globals()[function_name], signature=",".join(signatures), device_type="cuda", debug=False, cc=(35 if getenv("CUDACPU", 0) else None))
prg = remove_single_scalar_curly_braces(compiled.asm["ptx"].split(".file")[0].split(".visible .func")[0])
max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])
return prg, {"binary":True, "shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}

return prg, {"shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
24 changes: 11 additions & 13 deletions tinygrad/runtime/ops_clang.py
@@ -1,7 +1,7 @@
import time, ctypes, subprocess, platform, functools, pathlib, tempfile
from typing import Any
from tinygrad.ops import Compiled
from tinygrad.helpers import cache_compiled
from tinygrad.helpers import diskcache
from tinygrad.runtime.lib import RawMallocBuffer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
@@ -13,26 +13,24 @@

CLANG_PROGRAM_HEADER = '#include <math.h>\n#define max(x,y) ((x>y)?x:y)\n#define int64 long\n#define half __fp16\n#define uchar unsigned char\n#include <stdbool.h>\n'

class ClangProgram:
def __init__(self, name:str, prg:str, binary=False):
self.prg: bytes = prg if binary else self.compile(CLANG_PROGRAM_HEADER+prg)
@diskcache
def compile_clang(prg:str, header:str=CLANG_PROGRAM_HEADER) -> bytes:
# TODO: remove file write. sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=(header+prg).encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()

class ClangProgram:
def __init__(self, name:str, prg:bytes):
# write to disk so we can load it
with tempfile.NamedTemporaryFile(delete=True) as cached_file_path:
pathlib.Path(cached_file_path.name).write_bytes(self.prg)
pathlib.Path(cached_file_path.name).write_bytes(prg)
self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]

@cache_compiled
def compile(self, prg) -> bytes:
# TODO: sadly clang doesn't like the use of /dev/stdout here
with tempfile.NamedTemporaryFile(delete=True) as output_file:
subprocess.check_output(args=('clang -shared -O2 -Wall -Werror -x c '+args['cflags']+' - -o '+str(output_file.name)).split(), input=prg.encode('utf-8'))
return pathlib.Path(output_file.name).read_bytes()

def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
if wait: st = time.perf_counter()
self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
if wait: return time.perf_counter()-st

renderer = functools.partial(uops_to_cstyle, CStyleLanguage(buffer_suffix=" restrict", arg_int_prefix="const int"))
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, ClangProgram)
ClangBuffer = Compiled(RawMallocBuffer, LinearizerOptions(supports_float4=False, has_local=False), renderer, compile_clang, ClangProgram)
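
compile_clang keeps the old CLANG_PROGRAM_HEADER as a default argument, and because diskcache keys on the pickled (args, kwargs), a call with a different header gets its own cache entry. A quick sketch (the C source here is only illustrative):

from tinygrad.runtime.ops_clang import compile_clang

lib_default = compile_clang("float two() { return 2.0f; }")              # compiled with the default header, cached
lib_bare    = compile_clang("float two() { return 2.0f; }", header="")   # different kwargs, separate cache entry
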
20 changes: 10 additions & 10 deletions tinygrad/runtime/ops_cuda.py
@@ -3,7 +3,7 @@
from typing import Optional, List, Any, Tuple
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored, cache_compiled
from tinygrad.helpers import DEBUG, getenv, colored, diskcache
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.kernel import LinearizerOptions
@@ -88,9 +88,12 @@ def update_node(self, instid, jcid, prg, pargs, variables, updated_args=None):

def exec_instance(self, instid): self.graphs[instid][0].launch()

@diskcache
def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])

class CUDAProgram:
def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_override=None):
if not binary: prg = self.compile(prg).decode('utf-8')
def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None):
prg = _prg.decode('utf-8')
if DEBUG >= 5: print(pretty_ptx(prg))
if DEBUG >= 6:
try:
@@ -102,10 +105,6 @@ def __init__(self, name:str, prg:str, binary=False, shared=0, local_size_overrid
# TODO: name is wrong, so we get it from the ptx using hacks
self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override

@cache_compiled
def compile(self, prg) -> bytes:
return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])

def __call__(self, global_size, local_size, *args, wait=False):
if wait:
start, end = cuda.Event(), cuda.Event()
@@ -118,7 +117,8 @@ def __call__(self, global_size, local_size, *args, wait=False):

if getenv("TRITON") == 1:
from tinygrad.renderer.triton import uops_to_triton
TritonRenderer = uops_to_triton
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), TritonRenderer, CUDAProgram, cuda.Context.synchronize)
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False),
uops_to_triton, lambda x: x.encode('utf-8'), CUDAProgram, cuda.Context.synchronize)
else:
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), CUDARenderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]),
CUDARenderer, compile_cuda, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
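
The CUDA wiring shows both ends of the new API: compile_cuda wraps pycuda's PTX compilation in diskcache, while the TRITON path passes lambda x: x.encode('utf-8') as its "compiler" because the Triton renderer already emits PTX text. A sketch of driving the cached compiler directly (assumes pycuda and a working CUDA toolchain):

from tinygrad.runtime.ops_cuda import compile_cuda, CUDAProgram

ptx = compile_cuda('extern "C" __global__ void noop(float *a) { }')  # PTX bytes, disk-cached
prg = CUDAProgram("noop", ptx)  # decodes the PTX and loads it; the entry name is recovered from the PTX itself
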