simple runtime args (tinygrad#2211)
* simple runtime args

* fix some tests

* fix abstractions and triton

* fix search
geohot committed Nov 3, 2023
1 parent 9ea0448 commit f17bc16
Showing 16 changed files with 37 additions and 24 deletions.
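
The gist of the change: runtime program objects now take the buffers as leading positional arguments and receive launch dimensions (plus any renderer-supplied extras such as shared) as keyword arguments, instead of always taking global_size and local_size as the first two positionals. A rough before/after sketch of the calling convention, with hypothetical buffer names and sizes:

# before this commit: sizes were positional, even for runtimes that ignored them
prg(None, None, out, a, b)                                    # CPU-style runtimes (Clang, LLVM)
prg([4096,1,1], [32,1,1], out, a, b, wait=True)               # GPU-style runtimes (CUDA, Metal, OpenCL, ...)

# after this commit: buffers first, sizes as keywords
prg(out, a, b)                                                # CPU-style runtimes
prg(out, a, b, global_size=[4096,1,1], local_size=[32,1,1], wait=True)  # GPU-style runtimes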
docs/abstractions.py (2 changes: 1 addition & 1 deletion)
@@ -230,7 +230,7 @@ def toCPU(self): return self._buf

 # compile the program, run it, and 2+3 does indeed equal 5
 program = ClangProgram("add", compile_clang(f"void add(float *a, float *b, float *c) {{ *a = *b + *c; }}"))
-program(None, None, output, input_a, input_b) # NOTE: the None are for global_size and local_size
+program(output, input_a, input_b)
 print(output.toCPU())
 assert output.toCPU()[0] == 5, "it's still 5"
 np.testing.assert_allclose(output.toCPU(), numpy_a+numpy_b)
extra/thneed.py (2 changes: 1 addition & 1 deletion)
@@ -206,7 +206,7 @@ def save(self, output_fn):
 l.x = get_global_id(0);
 out[l.y*row_pitch + l.x] = read_imagef(in, smp, l);
 }
-"""), argdtypes=(None, None, np.int32))(a.shape, None, a, buf, row_pitch//(4*(2 if FLOAT16 else 4)))
+"""), argdtypes=(None, None, np.int32))(a, buf, row_pitch//(4*(2 if FLOAT16 else 4)), global_size=a.shape)

 # multiple of 32 isn't enough
 jdat['objects'].append({
test/external/external_test_allocator_on_models.py (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,7 @@ def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_
 def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
 class FakeProgram:
 def __init__(self, name:str, prg:str): pass
-def __call__(self, global_size, local_size, *bufs, wait=False): pass
+def __call__(self, *bufs, global_size, local_size, wait=False): pass

 def helper_test_correctness(gen, train):
 from tinygrad.runtime.ops_gpu import CL, CLAllocator
test/external/external_test_speed_llama.py (2 changes: 1 addition & 1 deletion)
@@ -12,7 +12,7 @@

 class FakeProgram:
 def __init__(self, name:str, prg:str): pass
-def __call__(self, global_size, local_size, *bufs, wait=False): pass
+def __call__(self, *bufs, global_size, local_size, wait=False): pass

 class RawFakeBuffer(RawBuffer):
 @classmethod
test/test_search.py (3 changes: 3 additions & 0 deletions)
@@ -14,3 +14,6 @@ def test_reasonable_time(self):
 rawbufs = [Device[Device.DEFAULT].buffer(si.out.st.size(), si.out.dtype)] + [Device[Device.DEFAULT].buffer(x.st.size(), x.dtype) for x in si.inputs]
 tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
 assert tm > 0 and tm != float('inf')
+
+if __name__ == '__main__':
+unittest.main()
test/test_uops.py (4 changes: 3 additions & 1 deletion)
@@ -8,7 +8,9 @@

 def _uops_to_prg(uops):
 src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
-return ASTRunner("test", src, [1], [1], runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
+return ASTRunner("test", src,
+[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
+runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)

 def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
 uops.append(UOp(uop, dtype, tuple(vin), arg, len(uops)))
tinygrad/features/search.py (15 changes: 10 additions & 5 deletions)
@@ -51,13 +51,18 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
 if clear_l2:
 # TODO: this is too small for many L2 caches
 with Context(DEBUG=0): Tensor.rand(1024,1024).realize()
-tms.append(prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor)
+lra = prg.runtime_args.copy()
+if global_size: lra['global_size'] = global_size
+if local_size: lra['local_size'] = local_size
+tms.append(prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)*factor)
 prg.global_size = real_global_size
 except Exception:
-#import traceback; traceback.print_exc()
-#print("FAILED")
-#print(lin.ast)
-#print(lin.applied_opts)
+if DEBUG >= 4:
+import traceback
+traceback.print_exc()
+print("FAILED")
+print(lin.ast)
+print(lin.applied_opts)
 tms = [float('inf')]
 if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
 return min(tms)
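
For orientation (my gloss, not part of the diff): prg.runtime_args is the extra keyword dict a renderer attached to the program, for example the Triton renderer's shared/local_size further down, so the timing loop now rebuilds the launch keywords for each measurement instead of baking them into the program. A minimal sketch with purely illustrative values:

lra = prg.runtime_args.copy()        # e.g. {"shared": 1024, "local_size": [32,1,1]} from a renderer, or just {}
lra['global_size'] = [4096,1,1]      # hypothetical candidate size being timed
tm = prg.clprg(*rawbufs, *var_vals.values(), **lra, wait=True)  # keywords flow straight into the runtime's __call__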
tinygrad/ops.py (7 changes: 5 additions & 2 deletions)
@@ -234,7 +234,7 @@ def try_exec(local_size):

 def build(self, compiler, runtime, batch_exec=BasicBatchExecutor):
 self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg)
-self.clprg, self.batch_exec = runtime(self.name, self.lib, **self.runtime_args), batch_exec
+self.clprg, self.batch_exec = runtime(self.name, self.lib), batch_exec
 return self

 def exec(self, rawbufs, var_vals:Optional[Dict[Variable, int]]=None, force_wait=False, optimizing=False) -> Optional[float]:
@@ -254,7 +254,10 @@ def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int
 # TODO: this is copied from get_program
 local_size = self.local_size = self.optimize_local_size(global_size, rawbufs)
 global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)]
-if et := self.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
+lra = self.runtime_args.copy()
+if global_size: lra['global_size'] = global_size
+if local_size and 'local_size' not in lra: lra['local_size'] = local_size
+if et := self.clprg(*rawbufs, *var_vals.values(), **lra, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
 op_estimate = sym_infer(self.op_estimate, var_vals)
 if DEBUG >= 2:
 print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(37-ansilen(self.display_name))) if self.display_name is not None else self.name:33s} arg {len(rawbufs):3d} sz {str(global_size):18s} {str(local_size):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
tinygrad/renderer/triton.py (2 changes: 1 addition & 1 deletion)
@@ -127,4 +127,4 @@ def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, f
 max_local_size = [int(x) for x in prg.split(".maxntid ")[1].split("\n")[0].split(", ")]
 for i in range(len(local_size)): local_size[i] = min(local_size[i], max_local_size[i])

-return prg, {"shared":compiled.metadata["shared"], "local_size_override":local_size + [1]*(3-len(local_size))}
+return prg, {"shared":compiled.metadata["shared"], "local_size":local_size + [1]*(3-len(local_size))}
tinygrad/runtime/ops_clang.py (2 changes: 1 addition & 1 deletion)
@@ -27,7 +27,7 @@ def __init__(self, name:str, prg:bytes):
 pathlib.Path(cached_file_path.name).write_bytes(prg)
 self.fxn: Any = ctypes.CDLL(str(cached_file_path.name))[name]

-def __call__(self, unused_global_size, unused_local_size, *args, wait=False):
+def __call__(self, *args, wait=False):
 if wait: st = time.perf_counter()
 self.fxn(*[x._buf if isinstance(x, RawMallocBuffer) else x for x in args])
 if wait: return time.perf_counter()-st
tinygrad/runtime/ops_cuda.py (8 changes: 4 additions & 4 deletions)
@@ -92,7 +92,7 @@ def exec_instance(self, instid): self.graphs[instid][0].launch()
 def compile_cuda(prg) -> bytes: return cuda_compile(prg, target="ptx", no_extern_c=True, options=['-Wno-deprecated-gpu-targets'])

 class CUDAProgram:
-def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None):
+def __init__(self, name:str, _prg:bytes):
 prg = _prg.decode('utf-8')
 if DEBUG >= 5: print(pretty_ptx(prg))
 if DEBUG >= 6:
@@ -103,13 +103,13 @@ def __init__(self, name:str, _prg:bytes, shared=0, local_size_override=None):
 print(subprocess.check_output(['nvdisasm', fn]).decode('utf-8'))
 except Exception as e: print("failed to generate SASS", str(e))
 # TODO: name is wrong, so we get it from the ptx using hacks
-self.prg, self.shared, self.local_size_override = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0]), shared, local_size_override
+self.prg = cuda.module_from_buffer(prg.encode('utf-8')).get_function(prg.split(".visible .entry ")[1].split("(")[0])

-def __call__(self, global_size, local_size, *args, wait=False):
+def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], shared:int=0, wait=False):
 if wait:
 start, end = cuda.Event(), cuda.Event()
 start.record()
-self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size if self.local_size_override is None else self.local_size_override), grid=tuple(global_size), shared=self.shared)
+self.prg(*[x._buf if isinstance(x, RawCUDABuffer) else np.int32(x) if (isinstance(x, int) and not getenv("CUDACPU")) else x for x in args], block=tuple(local_size), grid=tuple(global_size), shared=shared)
 if wait:
 end.record()
 end.synchronize()
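
With shared demoted from constructor state to an ordinary keyword argument, a direct call to the new CUDAProgram looks roughly like this (buffer names, PTX bytes and sizes are illustrative only):

prg = CUDAProgram("add", ptx_bytes)                        # ptx_bytes: output of compile_cuda, hypothetical
et = prg(out_buf, a_buf, b_buf,                            # RawCUDABuffers, hypothetical
         global_size=(128, 1, 1), local_size=(32, 1, 1),
         shared=0, wait=True)                              # in practice shared/local_size can arrive via runtime_args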
tinygrad/runtime/ops_gpu.py (4 changes: 2 additions & 2 deletions)
@@ -4,7 +4,7 @@
 import pathlib
 import numpy as np
 import pyopencl as cl # type: ignore
-from typing import Optional, List
+from typing import Optional, List, Tuple
 from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
 from tinygrad.ops import Compiled
 from tinygrad.renderer.opencl import OpenCLRenderer
@@ -90,7 +90,7 @@ def set_argdtypes(self, argdtypes): self.argdtypes, _ = argdtypes, [clprg.set_sc
 @staticmethod
 def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size

-def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
+def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Optional[Tuple[int,int,int]]=None, wait=False) -> Optional[float]:
 if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
 cl_bufs, wait_for = [], []
 for x in bufs:
tinygrad/runtime/ops_hip.py (2 changes: 1 addition & 1 deletion)
@@ -97,7 +97,7 @@ def __init__(self, name:str, prg:bytes):
 self.modules.append(hip.hipModuleLoadData(prg))
 self.prgs.append(hip.hipModuleGetFunction(self.modules[-1], name))

-def __call__(self, global_size, local_size, *args, wait=False):
+def __call__(self, *args, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
 hip.hipSetDevice(args[0]._device)
 if wait:
 start, end = hip.hipEventCreate(), hip.hipEventCreate()
tinygrad/runtime/ops_llvm.py (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ def __init__(self, name:str, lib:bytes):
 LLVM().engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib))
 self.fxn = LLVM.engine.get_function_address(name)

-def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
+def __call__(self, *bufs, wait=False):
 cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
 if wait: st = time.perf_counter()
 cfunc(*[x._buf if not isinstance(x, int) else x for x in bufs])
tinygrad/runtime/ops_metal.py (2 changes: 1 addition & 1 deletion)
@@ -83,7 +83,7 @@ def __init__(self, name:str, lib:bytes):
 os.system(f"cd {pathlib.Path(__file__).parents[2]}/disassemblers/applegpu && python3 compiler_explorer.py {shader.name}")
 self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))

-def __call__(self, global_size, local_size, *bufs, wait=False):
+def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int], wait=False):
 assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}"
 command_buffer = METAL.mtl_queue.commandBuffer()
 encoder = command_buffer.computeCommandEncoder()
tinygrad/runtime/ops_webgpu.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@

 class WebGPUProgram:
 def __init__(self, name: str, prg: str): self.name,self.prg = name,wgpu_device.create_shader_module(code=prg)
-def __call__(self, global_size, local_size, *bufs, wait=False):
+def __call__(self, *bufs, global_size, local_size, wait=False):
 assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
 binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
 bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
