Skip to content

Commit

Permalink
Make the JIT simple (no batch exec, no cache collector) (tinygrad#2215)
Browse files Browse the repository at this point in the history
* remove batch exec

* simple cachecollector

* remove cache collector test

* less lr
  • Loading branch information
geohot committed Nov 6, 2023
1 parent 719a97b commit baeb77a
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 454 deletions.
4 changes: 2 additions & 2 deletions test/models/test_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def compare_tiny_torch(model, model_torch, X, Y):
if not CI: print(f"initting {k} from torch")
model_state_dict[k].assign(Tensor(v.detach().numpy())).realize()

optimizer = optim.SGD(get_parameters(model), lr=0.01)
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.01)
optimizer = optim.SGD(get_parameters(model), lr=0.001)
optimizer_torch = torch.optim.SGD(model_torch.parameters(), lr=0.001)

Xt = torch.Tensor(X.numpy())
np.testing.assert_allclose(X.numpy(), Xt.detach().numpy())
Expand Down
210 changes: 0 additions & 210 deletions test/test_cache_collector.py

This file was deleted.

59 changes: 0 additions & 59 deletions test/test_exec.py

This file was deleted.

55 changes: 9 additions & 46 deletions tinygrad/jit.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional, Set
from weakref import ref
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
from collections import defaultdict
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType, getenv
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor, ASTRunner
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable
Expand Down Expand Up @@ -42,7 +41,7 @@ def __call__(self, *args, **kwargs) -> Any:
for k in self.jit_cache[j][2].keys():
try: self.jit_cache[j][2][k] = var_vals[k]
except KeyError: pass
self.batch_executor.exec(self.jit_cache, self.updatable_entries)
for prg, pargs, variables in self.jit_cache: prg(pargs, variables, jit=True)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 1:
CacheCollector.start()
Expand All @@ -60,57 +59,21 @@ def __call__(self, *args, **kwargs) -> Any:
for i in range(len(cache[2])): self.updatable_entries[j_].append(len(cache[1])+i)
#if prg.local_size is None: prg.local_size = prg.optimize_local_size(args, preserve_output=True) # the JIT can optimize local
assert set([x[0] for x in self.input_replace.values()]) == set(input_rawbuffers.keys()), "some input tensors not found"
self.batch_executor = self.jit_cache[0][0].batch_exec(self.jit_cache) if hasattr(self.jit_cache[0][0], 'batch_exec') else BasicBatchExecutor(self.jit_cache)
for (j,i) in self.input_replace.keys(): self.jit_cache[j][1][i] = None
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)
self.cnt += 1
return self.ret

class _CacheCollector:
class _Placeholder:
def __init__(self, buf): self.size, self.dtype, self._device, self.ref, self.buftype = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf)
def alloc_rawbuf(self): return self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))

def __init__(self):
self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
self.placeholders: Dict[ref[RawBuffer], _CacheCollector._Placeholder] = {} # Output rawbufs are replaced with placeholders to allow freeing of the real buffer while collecting cache.
self.circular_signatures: Set[Any] = set()
def start(self): self.cache, self.placeholders, self.circular_signatures = [], {}, set()
def __init__(self): self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
def start(self): self.cache = []
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
# Substitute output buffers with placeholders to find the most optimal reusage.
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(prg, ASTRunner) and isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
self.cache.append((prg, cached_rawbufs, var_vals))
self.cache.append((prg, rawbufs, var_vals))
def finish(self):
if self.cache is None: return []

rawbuf_pool: List[Tuple[RawBuffer, List[Tuple[int, int]]]] = []
buf_usage_bounds: Dict[_CacheCollector._Placeholder, Tuple[int, int]] = {}
buf_map: Dict[_CacheCollector._Placeholder, RawBuffer] = {}

for j,(_,bufs,_) in enumerate(self.cache):
for buf in bufs:
if buf.__class__ is not _CacheCollector._Placeholder: continue
if buf.ref() is not None: buf_map[buf] = buf.ref() # rawbufs that are referenced are not replaced but are used as is.
else: buf_usage_bounds[buf] = buf_usage_bounds.get(buf, (j, j))[0], j

# The query list contains a query for every placeholder that should be replaced with the actual rawbuffer. Queries are served from the largest to the smallest.
# For each query, find any rawbuffer that is free within the query timeframe or allocate a new one.
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
for _, start, end, buf in query_list:
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
if pool_idx == -1: rawbuf_pool.append((buf.alloc_rawbuf(), []))
buf_map[buf] = rawbuf_pool[pool_idx][0]
rawbuf_pool[pool_idx][1].append((start, end))

cache_result = [(p, [buf_map.get(buf, buf) for buf in cached_bufs], var_vals) for p, cached_bufs, var_vals in self.cache]
ret = self.cache
self.cache = None
return cache_result
def _no_intersect(self, start:int, end:int, usages:List[Tuple[int, int]]): return all(en < start or end < st for st, en in usages)
def _can_substitute(self, buf, with_buf):
if getenv("NO_BUFFER_REUSE"): return False
return buf._device==with_buf._device and (buf.size*buf.dtype.itemsize<=with_buf.size*with_buf.dtype.itemsize if not isinstance(buf.dtype, ImageDType) and not isinstance(with_buf.dtype, ImageDType) else buf.size==with_buf.size and buf.dtype==with_buf.dtype and buf.dtype.shape==with_buf.dtype.shape)
def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(ref(output_buffer))
return ret
CacheCollector = _CacheCollector()
Loading

0 comments on commit baeb77a

Please sign in to comment.