Commit
move BatchExecutor (tinygrad#2297)
* move BatchExecutor

* refactor to get_optimized_program

* that changed
geohot committed Nov 14, 2023
1 parent 0cbf6c1 commit 8916028
Showing 3 changed files with 91 additions and 92 deletions.
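In short: JitItem and BatchExecutor move from tinygrad/jit.py to tinygrad/ops.py, every device now exposes a batch_executor class (the generic BatchExecutor by default), and the nested get_program closure inside Compiled.exec_ast becomes the Compiled.get_optimized_program method. Condensed from the hunks below, the JIT's executor selection collapses from a None-checking conditional to a single call (a paraphrase, not a verbatim excerpt):

  # before: jit.py had to cope with devices that ship no batch executor
  alt_batch_exec = Device[Device.DEFAULT].batch_executor  # could be None
  self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)

  # after: the attribute is always a callable executor class; the JIT=2 fallback to the
  # generic BatchExecutor moves into Compiled.__init__
  self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)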
tinygrad/jit.py: 52 changes (5 additions, 47 deletions)
@@ -1,56 +1,15 @@
 from __future__ import annotations
 from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
 import functools, itertools
-from tinygrad.helpers import DEBUG, DType, merge_dicts, GlobalCounters, getenv, colored
-from tinygrad.ops import RawBuffer, Device, ASTRunner
+from tinygrad.helpers import DEBUG, DType, merge_dicts
+from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
 from tinygrad.tensor import Tensor
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, NumNode, sym_infer
-from dataclasses import dataclass
+from tinygrad.shape.symbolic import Variable
 from weakref import ref, WeakKeyDictionary
 
 JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
 
-@dataclass(frozen=True)
-class JitItem:
-  prg: ASTRunner
-  rawbufs: List[Optional[RawBuffer]]
-
-class BatchExecutor:
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
-    self.jit_cache: List[JitItem] = jit_cache
-    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
-    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
-    for j,ji in enumerate(jit_cache):
-      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
-        self.op_estimate += ji.prg.op_estimate
-        self.mem_estimate += ji.prg.mem_estimate
-      for i,a in enumerate(ji.rawbufs):
-        if a in [v for v in input_rawbuffers.values()]:
-          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
-    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
-    self.clear_jit_inputs()
-
-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
-    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
-    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
-    self.clear_jit_inputs()
-
-  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
-    # TODO: this is mostly copied from ASTRunner
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += len(self.jit_cache)
-    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
-    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
-    if et is not None: GlobalCounters.time_sum_s += et
-
-  def clear_jit_inputs(self):
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
-
 class TinyJit:
   def __init__(self, fxn:Callable):
     self.fxn: Callable = fxn
@@ -98,8 +57,7 @@ def __call__(self, *args, **kwargs) -> Any:
       assert len(jit_cache) != 0, "didn't JIT anything!"
       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
 
-      alt_batch_exec = Device[Device.DEFAULT].batch_executor
-      self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)
+      self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)
     elif self.cnt == 0:
       self.ret = self.fxn(*args, **kwargs)
 
@@ -129,7 +87,7 @@ def start(self, var_vals:Optional[Dict[Variable, int]]=None):
   def add(self, prg, rawbufs, var_vals):
     if self.cache is None: return
     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
-    self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])
+    self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
 
   def finish(self) -> List[JitItem]:
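For orientation before the ops.py hunks: the input_replace map built in BatchExecutor.__init__ records which (kernel index, argument index) slots correspond to JIT inputs, so each replay only patches those slots before running the cached kernels. A standalone toy version of that bookkeeping (plain strings stand in for RawBuffers; this is illustrative, not tinygrad code):

  jit_cache = [("kernel0", ["bufA", "bufX"]), ("kernel1", ["bufX", "bufB"])]  # (prg, rawbufs) pairs
  input_rawbuffers = {"x": "bufX"}                                            # JIT inputs by name

  input_replace = {}
  for j, (_prg, rawbufs) in enumerate(jit_cache):
    for i, a in enumerate(rawbufs):
      if a in input_rawbuffers.values():
        input_replace[(j, i)] = next(k for k, v in input_rawbuffers.items() if v == a)

  assert set(input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
  print(input_replace)  # {(0, 1): 'x', (1, 0): 'x'} are the slots re-patched on every jitted call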
tinygrad/ops.py: 124 changes (83 additions, 41 deletions)
@@ -1,10 +1,10 @@
 from __future__ import annotations
 import importlib, inspect, functools, pathlib, re
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
+from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
 from tinygrad.runtime.lib import RawBuffer
-from tinygrad.shape.symbolic import Variable, sym_infer
+from tinygrad.shape.symbolic import Variable, sym_infer, NumNode
 from dataclasses import dataclass
 
 # these are the llops your accelerator must implement, along with toCpu
@@ -106,13 +106,55 @@ def DEFAULT(self) -> str:
     return "CPU"
 Device = _Device()
 
+# **************** batch executor ****************
+
+@dataclass(frozen=True)
+class JitItem:
+  prg: ASTRunner
+  rawbufs: List[Optional[RawBuffer]]
+
+class BatchExecutor:
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+    self.jit_cache: List[JitItem] = jit_cache
+    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
+    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
+    for j,ji in enumerate(jit_cache):
+      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
+        self.op_estimate += ji.prg.op_estimate
+        self.mem_estimate += ji.prg.mem_estimate
+      for i,a in enumerate(ji.rawbufs):
+        if a in [v for v in input_rawbuffers.values()]:
+          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
+    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
+    self.clear_jit_inputs()
+
+  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
+    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
+    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
+    self.clear_jit_inputs()
+
+  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
+    # TODO: this is mostly copied from ASTRunner
+    op_estimate = sym_infer(self.op_estimate, var_vals)
+    mem_estimate = sym_infer(self.mem_estimate, var_vals)
+    if DEBUG >= 2:
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+    GlobalCounters.kernel_count += len(self.jit_cache)
+    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
+    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
+    if et is not None: GlobalCounters.time_sum_s += et
+
+  def clear_jit_inputs(self):
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+
 # **************** for Interpreted Buffers ****************
 
 class Interpreted:
   def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
     self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
     self.synchronize = lambda: None
-    self.batch_executor = None
+    self.batch_executor = BatchExecutor
     self.codegen = None
     self.method_cache: Dict[LazyOp, Callable] = {}
 
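Now that BatchExecutor lives beside Device, a device-specific executor is simply a subclass; MetalBatchExecutor at the bottom of this diff is the in-tree example. As a minimal hedged sketch (TimedBatchExecutor is hypothetical and not part of this commit), a subclass that keeps the generic replay loop and only adds wall-clock timing could look like:

  import time
  from tinygrad.ops import BatchExecutor

  class TimedBatchExecutor(BatchExecutor):  # hypothetical example, not in the commit
    def __call__(self, input_rawbuffers, var_vals, wait=False):
      st = time.perf_counter()
      super().__call__(input_rawbuffers, var_vals, wait)  # patch inputs, replay cached kernels, clear inputs
      return time.perf_counter() - st

A real backend would typically do more in __init__ (encode the cached kernels into its own command stream) and replace __call__ outright, reporting timing through update_stats the way MetalBatchExecutor does in ops_metal.py below.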
@@ -233,17 +275,50 @@ def __call__(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int
     return et
 
 class Compiled:
-  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=None):
-    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_executor = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_executor
+  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
+    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
+    self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
     self.method_cache: Dict[LazyOp, ASTRunner] = {}
 
-  def to_program(self, k):
+  def to_program(self, k) -> ASTRunner:
     k.linearize()
     src, runtime_args = self.renderer(k.function_name, k.uops)
     return ASTRunner(k.function_name, src, k.global_size, k.local_size,
                      op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
                      display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
 
+  def get_optimized_program(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> ASTRunner:
+    if DEBUG >= 3:
+      from tinygrad.graph import print_tree
+      print_tree(ast)
+    from tinygrad.codegen.linearizer import Linearizer
+    from tinygrad.lazy import vars_from_ast
+    k = Linearizer(ast, self.linearizer_opts)
+    assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
+    if not NOOPT:
+      if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+      if BEAM >= 1 and not vars_from_ast(ast):
+        lins = [(("tc" if used_tensor_cores else "hc"), k)]
+        # allocate a scratch buffer if output buffer is also input
+        test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
+        kb = Linearizer(ast, self.linearizer_opts)
+        kb.required_optimizations()
+        from tinygrad.features.search import beam_search, time_linearizer
+        lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
+        if used_tensor_cores:
+          lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
+          lins[-1][1].hand_coded_optimizations()
+        timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
+        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
+        k = timed[0][1]
+    else:
+      k.required_optimizations()
+    prg = self.to_program(k)
+    # extract real vars used in ast
+    prg.vars = vars_from_ast(ast)
+    assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
+    return prg
+
   def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
@@ -264,44 +339,11 @@ def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
     # all the rawbuffers
     rawbuffers = [output.realized] + [x.realized for x in inputs]
 
-    # compilation time
-    def get_program():
-      if DEBUG >= 3:
-        from tinygrad.graph import print_tree
-        print_tree(ast)
-      from tinygrad.codegen.linearizer import Linearizer
-      from tinygrad.lazy import vars_from_ast
-      k = Linearizer(ast, self.linearizer_opts)
-      assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
-      if not NOOPT:
-        if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
-        if BEAM >= 1 and not vars_from_ast(ast):
-          lins = [(("tc" if used_tensor_cores else "hc"), k)]
-          # allocate a scratch buffer if output buffer is also input
-          test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
-          kb = Linearizer(ast, self.linearizer_opts)
-          kb.required_optimizations()
-          from tinygrad.features.search import beam_search, time_linearizer
-          lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
-          if used_tensor_cores:
-            lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
-            lins[-1][1].hand_coded_optimizations()
-          timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-          if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-          k = timed[0][1]
-      else:
-        k.required_optimizations()
-      prg = self.to_program(k)
-      # extract real vars used in ast
-      prg.vars = vars_from_ast(ast)
-      assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
-      return prg
-
     if getenv("ENABLE_METHOD_CACHE", 1):
-      if ast not in self.method_cache: self.method_cache[ast] = get_program()
+      if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
       prg = self.method_cache[ast]
     else:
-      prg = get_program()
+      prg = self.get_optimized_program(ast, rawbuffers)
 
     if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
 
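The "refactor to get_optimized_program" half of the commit message is these last two hunks read together: the get_program closure above is lifted, essentially unchanged, into the Compiled.get_optimized_program method, and the two call sites change accordingly (condensed):

  # before: a closure only reachable from inside exec_ast
  prg = get_program()
  # after: a method on Compiled, populated through the per-AST method_cache as before
  prg = self.get_optimized_program(ast, rawbuffers)

The one visible behavioral nuance: the dtype assertion now checks rawbuffers[0].dtype instead of output.dtype, which is equivalent here because rawbuffers[0] is output.realized.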
tinygrad/runtime/ops_metal.py: 7 changes (3 additions, 4 deletions)
@@ -4,9 +4,10 @@
 from typing import List, Any, Tuple, Dict, Union, Set
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
-from tinygrad.ops import Compiled
+from tinygrad.ops import Compiled, BatchExecutor, JitItem
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
+from tinygrad.shape.symbolic import Variable, Node
 
 class MetalAllocator(LRUAllocator):
   def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
@@ -80,8 +81,6 @@ def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Tuple[int,i
       return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
     METAL.mtl_buffers_in_flight.append(command_buffer)
 
-from tinygrad.jit import BatchExecutor, JitItem
-from tinygrad.shape.symbolic import Variable, Node
 class MetalBatchExecutor(BatchExecutor):
   def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)
@@ -151,4 +150,4 @@ def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals:
     super().update_stats(var_vals, et)
     return et
 
-MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else None)
+MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else BatchExecutor)
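The Metal line above also doubles as the wiring template for any backend that ships its own executor: pass the class (not an instance) as batch_executor, and fall back to the generic BatchExecutor rather than None when the fast path is unavailable. A hedged sketch for an imaginary backend (every Foo*/FOO name below is a placeholder, not a real tinygrad symbol):

  from tinygrad.codegen.kernel import LinearizerOptions
  from tinygrad.helpers import CI
  from tinygrad.ops import BatchExecutor, Compiled

  # RawFooBuffer, FooRenderer, compile_foo, FooProgram, FOO, FooBatchExecutor: placeholders
  FooBuffer = Compiled(RawFooBuffer, LinearizerOptions(device="FOO"), FooRenderer,
                       compile_foo, FooProgram, FOO.synchronize,
                       batch_executor=FooBatchExecutor if not CI else BatchExecutor)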
