metal indirect command buffers (tinygrad#2285)
* metal indirect command buffers

* sub 1ms gpt

* metal batch exec is good

* remove whitespace

* input_replace

* fix ci

* useResources

* very simple cacheallocator

* update_stats

* fix CI

* minor

* remove that from jit
geohot committed Nov 14, 2023
1 parent: d86ea18 · commit: b1f7f29
Showing 7 changed files with 145 additions and 36 deletions.
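For orientation, here is a minimal sketch of how the JIT this diff modifies is driven from user code. The decorator, the realize() idiom, and the call-count behavior come from tinygrad/jit.py at this commit; the model function and shapes are purely illustrative.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  # returned tensors are realized inside the call so their kernels get captured
  return (x @ x).relu().realize()

a = Tensor.randn(64, 64)
step(a)  # cnt == 0: runs eagerly, nothing captured
step(a)  # cnt == 1: kernels are captured into jit_cache and handed to a BatchExecutor
step(a)  # cnt >= 2: replayed by the executor; on METAL, one indirect command buffer

The per-file changes below build that replay path.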
4 changes: 2 additions & 2 deletions extra/export_model.py
@@ -42,9 +42,9 @@ def run(*x):

# hack to put the inputs back
for (j,i),idx in run.input_replace.items():
realized_input = args[idx[0]].lazydata.realized
realized_input = args[idx].lazydata.realized
run.jit_cache[j].rawbufs[i] = realized_input
special_names[id(realized_input)] = f'input{idx[0]}'
special_names[id(realized_input)] = f'input{idx}'

# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
for i, output in enumerate(the_output):
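This hunk (and the openpilot/compile.py hunk below) changes because input_replace values shrink from an (index, ShapeTracker, dtype) tuple to just the positional index or keyword name; shape and dtype validation now happens once per call via expected_sts_dtype in tinygrad/jit.py. A small self-contained sketch of the new mapping shape, with made-up contents:

# stand-ins for the JIT's state; only the shape of the mapping matters here
input_rawbuffers = {0: "rawbuf_for_positional_arg_0", "start_pos": "rawbuf_for_kwarg"}
jit_cache = [{"rawbufs": [None, None, None]} for _ in range(6)]

# old: {(kernel_j, buffer_i): (idx_or_name, shapetracker, dtype)} -> callers needed idx[0]
# new: {(kernel_j, buffer_i): idx_or_name}                        -> callers use idx directly
input_replace = {(0, 2): 0, (5, 1): "start_pos"}
for (j, i), idx in input_replace.items():
  jit_cache[j]["rawbufs"][i] = input_rawbuffers[idx]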
2 changes: 1 addition & 1 deletion openpilot/compile.py
@@ -63,7 +63,7 @@ def compile(dat, output_fn):

# pull out inputs and put them in the jit cache
input_rawbuffers = {k:inputs[k].lazydata.realized for k in inputs.keys()}
for (j,i),(idx,_,_) in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]
for (j,i),idx in model_exec.input_replace.items(): model_exec.jit_cache[j].rawbufs[i] = input_rawbuffers[idx]

# transform to CL.CACHE
used_ops = 0
1 change: 0 additions & 1 deletion test/external/dist/test_collectives.py
@@ -30,7 +30,6 @@ def run():

# reset jit
allreduce_jit.cnt = 0
allreduce_jit.input_replace = {}

# test uneven chunk sizes
for _ in range(3):
82 changes: 60 additions & 22 deletions tinygrad/jit.py
@@ -1,11 +1,11 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.helpers import DEBUG, DType, merge_dicts, GlobalCounters, getenv, colored
from tinygrad.ops import RawBuffer, Device, ASTRunner
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
from tinygrad.shape.symbolic import Variable, NumNode, sym_infer
from dataclasses import dataclass
from weakref import ref, WeakKeyDictionary

@@ -16,13 +16,54 @@ class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]

class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()

def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
self.clear_jit_inputs()

def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et

def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None

class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
self.jit_fxn: Optional[BatchExecutor] = None
self.cnt: int = 0
self.jit_cache: List[JitItem] = []
self.ret: Any = None
self.input_replace: Dict[Tuple[int, int], Tuple[Union[int, str], ShapeTracker, DType]] = {} # (kernel_number, buffer_number) -> (input_name, expected_shapetracker, expected_type)
self.expected_vals: Optional[Tuple[Variable, ...]] = None
self.expected_sts_dtype: Optional[Tuple[Tuple[ShapeTracker, DType], ...]] = None

@property
def jit_cache(self) -> List[JitItem]: return self.jit_fxn.jit_cache if self.jit_fxn else []
@property
def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.jit_fxn.input_replace if self.jit_fxn else {}

# add support for instance methods
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
@@ -32,39 +73,36 @@ def __call__(self, *args, **kwargs) -> Any:

# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])

# get rawbuffers
input_rawbuffers: Dict[Union[int, str], Tuple[RawBuffer, ShapeTracker]] = {k:(cast(RawBuffer, v.lazydata.realized), v.lazydata.st) for k,v in input_tensors.items()}
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {k:cast(RawBuffer, v.lazydata.realized) for k,v in input_tensors.items()}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
assert len(set(input_rawbuffers.values())) == len(input_rawbuffers), "duplicate inputs to JIT"

# get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
var_vals: Dict[Variable, int] = merge_dicts([arg.lazydata.st.var_vals for arg in input_tensors.values()] + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
expected_vals = tuple(var_vals.keys())

if self.cnt >= 2:
# check validity and assign the inputs
for (j,i),(input_name, expected_st, expected_type) in self.input_replace.items():
assert input_rawbuffers[input_name][0].dtype == expected_type, f"type mismatch in JIT, {input_rawbuffers[input_name][0].dtype} != {expected_type}"
assert input_rawbuffers[input_name][1].unbind() == expected_st, f"ShapeTracker mismatch in JIT, {input_rawbuffers[input_name][1].unbind()} != {expected_st}"
self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name][0]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
assert self.expected_vals == expected_vals, "mismatch of var_vals"
assert self.expected_sts_dtype == expected_sts_dtype, "mismatch of sts"
assert self.jit_fxn, "didn't get jitted?"
self.jit_fxn(input_rawbuffers, var_vals, DEBUG>=2)
elif self.cnt == 1:
self.expected_vals, self.expected_sts_dtype = expected_vals, expected_sts_dtype

CacheCollector.start(var_vals)
self.ret = self.fxn(*args, **kwargs)
self.jit_cache = CacheCollector.finish()
assert len(self.jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
# get the inputs for replacement
for j,ji in enumerate(self.jit_cache):
for i,a in enumerate(ji.rawbufs):
if a in [v[0] for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [(k, v[1].unbind(), v[0].dtype) for k,v in input_rawbuffers.items() if v[0] == a][0]
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
jit_cache = CacheCollector.finish()
assert len(jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")

alt_batch_exec = Device[Device.DEFAULT].batch_executor
self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)

# clear the inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
self.cnt += 1
return self.ret

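BatchExecutor.update_stats above keeps op_estimate and mem_estimate as symbolic nodes, since kernel cost can depend on runtime-bound Variables (e.g. a GPT sequence position), and resolves them per replay with sym_infer. A tiny sketch of that resolution, assuming the symbolic API imported at the top of this file; the variable name and costs are illustrative:

from tinygrad.shape.symbolic import Variable, NumNode, sym_infer

seq = Variable("seq", 1, 1024)               # illustrative runtime-bound dimension
op_estimate = NumNode(0) + seq * 1_000_000   # accumulated like BatchExecutor.op_estimate
print(sym_infer(op_estimate, {seq: 7}))      # -> 7000000 ops for this particular call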
5 changes: 3 additions & 2 deletions tinygrad/ops.py
@@ -112,6 +112,7 @@ class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = None
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}

@@ -232,8 +233,8 @@ def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int
return et

class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_executor = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_executor
self.method_cache: Dict[LazyOp, ASTRunner] = {}

def to_program(self, k):
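With the batch_executor hook added to Compiled (and set to None on Interpreted), any backend can substitute its own replay strategy; TinyJit falls back to the generic BatchExecutor when the hook is None or JIT=2 is set. A hedged sketch of how another backend could opt in — the subclass and the commented-out wiring names are hypothetical:

from tinygrad.jit import BatchExecutor

class MyBatchExecutor(BatchExecutor):  # hypothetical backend-specific replay
  def __call__(self, input_rawbuffers, var_vals, wait=False):
    # e.g. submit every captured kernel in one driver-level call instead of looping
    return super().__call__(input_rawbuffers, var_vals, wait)

# wiring, mirroring the MetalBuffer line at the end of this diff:
# MyDeviceBuffer = Compiled(RawMyBuffer, my_linearizer_opts, MyRenderer, my_compiler,
#                           MyProgram, my_synchronize, batch_executor=MyBatchExecutor)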
6 changes: 3 additions & 3 deletions tinygrad/runtime/lib.py
@@ -38,7 +38,7 @@ def fromCPU(cls, x:np.ndarray, **kwargs):
class RawBufferMapped(RawBufferCopyIn):
def _buffer(self) -> memoryview: raise NotImplementedError("must be implemented")
# NOTE: this metadata prevents the backing buffer from being freed. hack can be removed with PEP688
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size)
def buffer_view(self) -> np.ndarray: return np.frombuffer(self._buffer(), dtype=np.dtype(self.dtype.np, metadata={"backing": self}), count=self.size) # type: ignore
def toCPU(self) -> np.ndarray: return self.buffer_view().copy() # Need a copy, since jit will write to the same buffer.
def _copyin(self, x:np.ndarray) -> None: np.copyto(self.buffer_view(), x.reshape(-1))

@@ -83,8 +83,8 @@ def ensure_has_free_space(self, space_to_free, device):

def _alloc_buffer(self, size, dtype, device, **kwargs):
self.ensure_has_free_space(size*dtype.itemsize, device)
while True:
try:
while True:
try:
newbuf = self._do_alloc(max(1, size), dtype, device, **kwargs)
break
except Exception:
81 changes: 76 additions & 5 deletions tinygrad/runtime/ops_metal.py
@@ -1,12 +1,12 @@
# pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch
import os, subprocess, pathlib, ctypes, tempfile
import Metal, Cocoa, libdispatch
from typing import List, Any, Tuple
from typing import List, Any, Tuple, Dict, Union, Set
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
from tinygrad.ops import Compiled
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, LRUAllocator
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator

class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
@@ -44,7 +44,7 @@ def compile_metal(prg, use_xcode=bool(getenv("METAL_XCODE"))) -> bytes:
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.alloc().init()
options = Metal.MTLCompileOptions.new()
library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
# TODO: avoid file write here?
with tempfile.NamedTemporaryFile(delete=True) as output_file:
@@ -80,4 +80,75 @@ def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,i
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
METAL.mtl_buffers_in_flight.append(command_buffer)

MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize)
from tinygrad.jit import BatchExecutor, JitItem
from tinygrad.shape.symbolic import Variable, Node
class MetalBatchExecutor(BatchExecutor):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)

# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = METAL.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache), Metal.MTLResourceOptions(0))
assert self.icb is not None, "create indirect command buffer failed, does your system support this?"

self.int_buf = RawMetalBuffer(len(var_vals), dtypes.int32)
self.input_has_variable_dims: Set[int] = set()
read_resources, write_resources = [], []
for j,ji in enumerate(self.jit_cache):
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(ji.prg.clprg.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
pipeline_state = unwrap(METAL.device.newComputePipelineStateWithDescriptor_options_reflection_error_(descriptor, Metal.MTLPipelineOption(0), None, None))
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command.setComputePipelineState_(pipeline_state)
for i,b in enumerate(ji.rawbufs):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf, 0, i)
if i == 0: write_resources.append(b._buf)
else: read_resources.append(b._buf)
var_vals_keys = list(var_vals.keys())
for i,v in enumerate(getattr(ji.prg,"vars",[])):
icb_command.setKernelBuffer_offset_atIndex_(self.int_buf._buf, var_vals_keys.index(v)*4, len(ji.rawbufs)+i)
global_size, local_size = ji.prg.launch_dims(var_vals)
assert ji.prg.global_size and ji.prg.local_size, "need global and local size to JIT"
if any(isinstance(x, Node) for x in ji.prg.global_size) or any(isinstance(x, Node) for x in ji.prg.local_size):
self.input_has_variable_dims.add(j)
else:
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
self.read_resources, self.write_resources = dedup(read_resources), dedup(write_resources)
self.command_buffer: Any = None
self.int_buf_view = self.int_buf.buffer_view() # TODO: this is metal syncing when it doesn't need to

def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
# NOTE: you at least can't update the ints if this is running
if self.command_buffer is not None and self.command_buffer in METAL.mtl_buffers_in_flight: self.command_buffer.waitUntilCompleted()
all_read_resources = self.read_resources + [x._buf for x in input_rawbuffers.values()]
for (j,i),input_name in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_name]._buf, 0, i)
for j in self.input_has_variable_dims:
global_size, local_size = self.jit_cache[j].prg.launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
self.int_buf_view[:] = list(var_vals.values())
command_buffer = METAL.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.executeCommandsInBuffer_withRange_(self.icb, Metal.MTLIndirectCommandBufferExecutionRangeMake(0,len(self.jit_cache)))
encoder.useResources_count_usage_(all_read_resources, len(all_read_resources), Metal.MTLResourceUsageRead)
encoder.useResources_count_usage_(self.write_resources, len(self.write_resources), Metal.MTLResourceUsageWrite)
encoder.endEncoding()
command_buffer.commit()
self.command_buffer = command_buffer
if wait:
command_buffer.waitUntilCompleted()
et = command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
else:
METAL.mtl_buffers_in_flight.append(command_buffer)
et = None
super().update_stats(var_vals, et)
return et

MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else None)
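A usage note on the line above: the Metal executor is only registered when CI is not set, and, per the selection logic added in tinygrad/jit.py, JIT=2 forces the generic BatchExecutor even when a backend provides one. A hedged way to flip between the two paths and see the update_stats output — the environment variable names come from this diff, the rest is illustrative:

import os
os.environ["JIT"] = "2"    # fall back to the generic BatchExecutor instead of MetalBatchExecutor
os.environ["DEBUG"] = "2"  # update_stats() prints kernels/inputs/GFLOPS for every replay
# set these before importing tinygrad so getenv() picks them up, then run the
# jitted step() from the sketch at the top of this page on a Metal machine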
