Commit bd11141

init allocator for compiled backends (tinygrad#1467)
* init allocator for compiled backends

* Update ops_webgpu.py

---------

Co-authored-by: George Hotz <[email protected]>
nimlgen and geohot committed Aug 17, 2023
1 parent a293c18 commit bd11141
Showing 9 changed files with 343 additions and 39 deletions.
135 changes: 135 additions & 0 deletions test/external/external_test_allocator_on_models.py
@@ -0,0 +1,135 @@
#!/usr/bin/env python
import unittest, gc
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.state import get_parameters, get_state_dict
from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device

from examples.llama import Transformer

ALLOCATED_DEV_BUFS = 0
class FakeDeviceBuffer():
  def __init__(self, sz, dt, device):
    self.id = 1
    self.size = sz
    self.dtype = dt
    self.device = device

    global ALLOCATED_DEV_BUFS
    ALLOCATED_DEV_BUFS += 1
class FakeAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
  def _do_free(self, buf):
    buf.id -= 1
    assert buf.id == 0, f"Free should be called once, but {buf.id}"

FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
  def __init__(self, size, dtype, device='0'):
    global FAKE_GLOBAL_ALLOCATOR
    super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
    assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
  @classmethod
  def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
  def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)
class FakeProgram:
  def __init__(self, name:str, prg:str): pass
  def __call__(self, global_size, local_size, *bufs, wait=False): pass

def helper_test_correctness(gen, train):
  from tinygrad.runtime.ops_gpu import CL, CLAllocator
  old_alloc = CL.cl_allocator
  CL.cl_allocator = CLAllocator(0)
  no_alloc_result = train(*gen()).numpy()
  Device[Device.DEFAULT].synchronize()
  CL.cl_allocator = CLAllocator(512<<30) # Test cache correctness, so cache as much as possible, 512gb
  for _ in range(4):
    GlobalCounters.reset()
    np.testing.assert_allclose(train(*gen()).numpy(), no_alloc_result, rtol=1e-3, atol=1e-5)
    Device[Device.DEFAULT].synchronize()
  assert len(CL.cl_allocator.cached_buffers) != 0, "Cache must be used"
  CL.cl_allocator = old_alloc

def __helper_test_alloc_count(gen, train):
  was_alloc = ALLOCATED_DEV_BUFS
  for _ in range(2):
    train(*gen())
  return ALLOCATED_DEV_BUFS - was_alloc

def helper_test_alloc_count(mm, gen, train):
  global FAKE_GLOBAL_ALLOCATOR
  backup_program = Device[Device.DEFAULT].runtime
  backup_buffer = Device[Device.DEFAULT].buffer
  Device[Device.DEFAULT].runtime = FakeProgram
  Device[Device.DEFAULT].buffer = FakeBuffer
  Device[Device.DEFAULT].method_cache.clear()
  FAKE_GLOBAL_ALLOCATOR = FakeAllocator(16<<30)
  new_allocs = __helper_test_alloc_count(gen, train)
  Device[Device.DEFAULT].method_cache.clear()
  FAKE_GLOBAL_ALLOCATOR = FakeAllocator(0)
  old_allocs = __helper_test_alloc_count(gen, train)
  print(f"{mm}: llama: old allocs count {old_allocs}, new allocs count {new_allocs}")
  assert new_allocs < old_allocs, "Hmm, doesn't the cache work any more?"
  Device[Device.DEFAULT].runtime = backup_program
  Device[Device.DEFAULT].buffer = backup_buffer
  FAKE_GLOBAL_ALLOCATOR = None

def check_gc():
  if Device.DEFAULT == "GPU":
    gc.collect() # Need to collect Tensors.
    from extra.introspection import print_objects
    assert print_objects() == 0

# for speed
def derandomize(x):
  if isinstance(x, LazyOp):
    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
    x.src = [derandomize(s) for s in x.src]
  else:
    x.op = derandomize(x.op)
  return x

def derandomize_model(model):
  for p in get_parameters(model):
    p.lazydata = derandomize(p.lazydata)
    p.realize()

class TestAllocators(unittest.TestCase):
  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
  def test_lru_allocator_tiny_llama(self):
    old_type = Tensor.default_type
    Tensor.default_type = dtypes.float16

    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    def __test():
      model = Transformer(**args_tiny)
      derandomize_model(model)
      def test(t): return model(t, 0).realize()
      helper_test_correctness(lambda: (Tensor([[1,]]),), test)
    __test()
    Tensor.default_type = old_type
    check_gc()

  @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
  def test_lru_allocator_tiny_llama_alloc_counts(self):
    args_tiny = {"dim": 1024, "multiple_of": 256, "n_heads": 8, "n_layers": 8, "norm_eps": 1e-05, "vocab_size": 1000}
    def test_alloc_count(t):
      model = Transformer(**args_tiny)
      for v in get_state_dict(model).values(): v.assign(Tensor.empty(*v.shape, dtype=v.dtype))
      return model(t, 0).realize()
    helper_test_alloc_count("llama", lambda: (Tensor([[2,]]),), test_alloc_count)
    check_gc()

  @unittest.skip("huge for CI")
  def test_stable_diffusion(self):
    from examples.stable_diffusion import UNetModel
    model = UNetModel()
    derandomize_model(model)
    def test(t, t2): return model(t, 801, t2).realize()
    helper_test_correctness(lambda: (Tensor.randn(1, 4, 16, 16),Tensor.randn(1, 77, 768)), test)

if __name__ == "__main__":
  unittest.main()
106 changes: 106 additions & 0 deletions test/test_allocators.py
@@ -0,0 +1,106 @@
#!/usr/bin/env python
import unittest
import numpy as np
from weakref import ref
from tinygrad.ops import GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator
from tinygrad.helpers import dtypes, prod
from tinygrad.lazy import Device

def check_gc():
  if Device.DEFAULT == "GPU":
    from extra.introspection import print_objects
    assert print_objects() == 0

class FakeDeviceBuffer():
  def __init__(self, sz, dt, device):
    self.id = 1
    self.size = sz
    self.dtype = dt
    self.device = device
  def __del__(self):
    assert self.id == 0, "Should have called _do_free() before"

class FakeAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs): return FakeDeviceBuffer(size, dtype, device)
  def _do_free(self, buf):
    buf.id -= 1
    assert buf.id == 0, f"Free should be called once, but {buf.id}"

FAKE_GLOBAL_ALLOCATOR = None
class FakeBuffer(RawBuffer):
  def __init__(self, size, dtype, device='0'):
    global FAKE_GLOBAL_ALLOCATOR
    super().__init__(size, dtype, allocator=FAKE_GLOBAL_ALLOCATOR, **{'device': device})
    assert self._buf.size == size and self._buf.dtype == dtype and self._buf.device == device, "This allocator requires 100% match of dtype and size."
  @classmethod
  def fromCPU(cls, x:np.ndarray, **kwargs): return cls(prod(x.shape), dtypes.from_np(x.dtype), **kwargs)
  def toCPU(self): return np.empty(self.size, dtype=self.dtype.np)

def alloc(allocator, size, dtype, **kwargs):
  global FAKE_GLOBAL_ALLOCATOR
  FAKE_GLOBAL_ALLOCATOR = allocator
  buf = FakeBuffer(size, dtype, **kwargs)
  assert buf.dtype == dtype and buf.size == size
  FAKE_GLOBAL_ALLOCATOR = None
  return buf

def alloc_free_trace(allocator, size, dtype, **kwargs):
  buf = alloc(allocator, size, dtype, **kwargs)
  return ref(buf._buf)

def cmp_trace_and_buf(buf, trace_ref): return trace_ref and trace_ref() == buf._buf

class TestAllocators(unittest.TestCase):
  def test_lru_allocator_reusage(self):
    def test():
      lru_allocator = FakeAllocator(2048)
      traced_buf = alloc_free_trace(lru_allocator, 16, dtypes.float32)
      assert GlobalCounters.mem_cached == 16*dtypes.float32.itemsize, "Buffer should be cached"
      for _ in range(32):
        def __test():
          buf = alloc(lru_allocator, 16, dtypes.float32)
          assert cmp_trace_and_buf(buf, traced_buf), "Buffer should be reused"
        __test()

      usedbuf = alloc(lru_allocator, 16, dtypes.float32)
      for _ in range(32):
        def __test():
          buf = alloc(lru_allocator, 16, dtypes.float32)
          assert usedbuf != buf, "Nobody should get used buffer"
        __test()
      assert GlobalCounters.mem_used == 16*dtypes.float32.itemsize, "Only usedbuf is still allocated."
    test()
    check_gc()

  def test_lru_allocator_cache_free(self):
    def test():
      lru_allocator = FakeAllocator(128)
      refs = []
      for _ in range(32):
        refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32))
      for sz in range(32):
        alloc_free_trace(lru_allocator, sz, dtypes.float32)
        assert GlobalCounters.mem_used + GlobalCounters.mem_cached <= 128, "Should not allocate on device more than allowed (128)"
      for r in refs: assert r() is None, "All refs should be dead, since buffers were cleared from cache"
    test()
    check_gc()

  def test_lru_allocator_multidevice(self):
    def test():
      lru_allocator = FakeAllocator(256)
      refs=[]
      for i in range(8):
        refs.append(alloc_free_trace(lru_allocator, 16, dtypes.float32, device=str(i)))
      for i in range(64):
        def __test():
          dev = str(i % 8)
          buf = alloc(lru_allocator, 16, dtypes.float32, device=dev)
          assert cmp_trace_and_buf(buf, refs[i%8]), "Buffer should be reused"
        __test()
      for r in refs: assert r() is not None, "All refs should be cached"
    test()
    check_gc()

if __name__ == "__main__":
  unittest.main()
1 change: 1 addition & 0 deletions tinygrad/helpers.py
@@ -130,6 +130,7 @@ class GlobalCounters:
   time_sum_s: ClassVar[float] = 0.0
   kernel_count: ClassVar[int] = 0
   mem_used: ClassVar[int] = 0 # NOTE: this is not reset
+  mem_cached: ClassVar[int] = 0 # NOTE: this is not reset
   cache: ClassVar[Optional[List[Tuple[Callable, Any]]]] = None
   @staticmethod
   def reset(): GlobalCounters.global_ops, GlobalCounters.global_mem, GlobalCounters.time_sum_s, GlobalCounters.kernel_count, GlobalCounters.cache = 0,0,0.0,0,None
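
A minimal sketch (not part of this commit) of how the new mem_cached counter is expected to move opposite to mem_used once a RawBuffer is backed by an LRU allocator. ToyDeviceBuffer and ToyAllocator are hypothetical stand-ins; only RawBuffer, LRUAllocator and GlobalCounters come from this diff.

from tinygrad.helpers import dtypes, GlobalCounters
from tinygrad.runtime.lib import RawBuffer, LRUAllocator

class ToyDeviceBuffer:                                  # hypothetical stand-in for a real device allocation
  def __init__(self, nbytes): self.nbytes = nbytes

class ToyAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs): return ToyDeviceBuffer(size*dtype.itemsize)

lru = ToyAllocator(dev_memsz=1<<20)
buf = RawBuffer(16, dtypes.float32, allocator=lru)
print(GlobalCounters.mem_used, GlobalCounters.mem_cached)  # 64 0 (in an otherwise idle process)
del buf                                                    # __del__ -> allocator.free(): the device buffer is cached, not released
print(GlobalCounters.mem_used, GlobalCounters.mem_cached)  # 0 64
buf = RawBuffer(16, dtypes.float32, allocator=lru)         # same (device, size, dtype) key -> cached buffer is reused
print(GlobalCounters.mem_used, GlobalCounters.mem_cached)  # 64 0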
51 changes: 47 additions & 4 deletions tinygrad/runtime/lib.py
@@ -1,18 +1,21 @@
 import ctypes
 import numpy as np
-from typing import TypeVar, Type, Any
-from tinygrad.helpers import DType, dtypes, prod, GlobalCounters
+from collections import defaultdict, deque
+from typing import TypeVar, Type, Any, Dict, Deque, Tuple
+from tinygrad.helpers import DType, dtypes, prod, GlobalCounters, ImageDType

 _T = TypeVar("_T")
 class RawBuffer: # pylint: disable=abstract-method
-  def __init__(self, size:int, dtype:DType, buf:Any=None):
+  def __init__(self, size:int, dtype:DType, buf:Any=None, allocator:Any=None, **kwargs):
     self.size: int = size
     self.dtype: DType = dtype
-    self._buf = buf
+    self._buf = buf if buf is not None else (allocator.alloc(size, dtype, **kwargs) if allocator else None) # If buf is provided, use it. Otherwise try to allocate from the allocator.
     self._memsz: int = size*dtype.itemsize
+    self._allocator = allocator
     GlobalCounters.mem_used += self._memsz
   def __del__(self): # NOTE: if it fails on init (bad dtype), it won't have a _memsz
     if hasattr(self, '_memsz'): GlobalCounters.mem_used -= self._memsz
+    if hasattr(self, '_allocator') and self._allocator: self._allocator.free(self._buf)
   def __repr__(self): return f"buffer<{self.size}, {self.dtype}>"
   @property
   def key(self): return (self.size, self.dtype.key)
@@ -66,3 +69,43 @@ def key(self): return (str(self._buf), self.dtype.key)

 def buf_is_kernel_arg(x) -> bool:
   return x.realized is not None and x.realized.__class__ is not RawConst
+
+class LRUAllocator:
+  def __init__(self, dev_memsz=(4<<30)):
+    self.epoch = 0
+    self.free_space: Dict[Any, int] = defaultdict(lambda: dev_memsz)
+    self.buffer_info: Dict[Any, Tuple[int, DType, str]] = dict()
+    self.cached_buffers: Dict[Tuple[int, ...], Deque[Tuple[Any, int]]] = defaultdict(deque) # Cached buffer storage, split by type and size, newest first.
+    self.aging_order: Dict[Any, Deque[Tuple[Tuple[int, ...], int]]] = defaultdict(deque) # Keys of cached_buffers, ordered from oldest to newest updates.
+  def __del__(self):
+    for v in self.cached_buffers.values():
+      for buf, _ in v: self._free_buffer(buf)
+  def _cache_reuse_buffer(self, rawbufs: Deque[Tuple[Any, int]]): # The newest cached buffer is reused.
+    GlobalCounters.mem_cached -= self._underlying_buf_memsz(rawbufs[0][0])
+    return rawbufs.popleft()[0]
+  def _alloc_buffer(self, size, dtype, device, **kwargs):
+    self.free_space[device] -= size*dtype.itemsize
+    while len(self.aging_order[device]) and self.free_space[device] < 0: # When OOM, evict the least recently used cached buffers.
+      bucket, epoch = self.aging_order[device].popleft()
+      if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
+    newbuf = self._do_alloc(size, dtype, device, **kwargs)
+    self.buffer_info[newbuf] = (size, dtype, device)
+    return newbuf
+  def _free_buffer(self, buf_to_free):
+    self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
+    GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
+    self.buffer_info.pop(buf_to_free)
+    self._do_free(buf_to_free)
+  def alloc(self, size, dtype, device='0', **kwargs):
+    rawbufs = self.cached_buffers.get(self._cached_bufkey(size, dtype, device), None)
+    return self._cache_reuse_buffer(rawbufs) if rawbufs else self._alloc_buffer(size, dtype, device, **kwargs)
+  def free(self, buf): # free() just caches the buffer. It might be freed later when OOM during allocation.
+    self.epoch += 1
+    size, dtype, device = self.buffer_info[buf]
+    self.cached_buffers[self._cached_bufkey(size, dtype, device)].appendleft((buf, self.epoch))
+    self.aging_order[device].append((self._cached_bufkey(size, dtype, device), self.epoch))
+    GlobalCounters.mem_cached += self._underlying_buf_memsz(buf)
+  def _underlying_buf_memsz(self, buf): return self.buffer_info[buf][0] * self.buffer_info[buf][1].itemsize
+  def _cached_bufkey(self, size, dtype, device) -> Tuple[int, ...]: return (device, size, dtype, dtype.shape) if isinstance(dtype, ImageDType) else (device, size, dtype) # Provides a key for reusing device buffers with identical keys.
+  def _do_alloc(self, size, dtype, device, **kwargs): raise NotImplementedError("must be implemented")
+  def _do_free(self, buf): pass
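
To make the caching and eviction policy concrete, here is a small illustrative sketch (not from the commit) of an LRUAllocator subclass under memory pressure. ToyBuf, ToyAllocator and the tag kwarg are hypothetical; only LRUAllocator itself comes from this diff.

from tinygrad.helpers import dtypes
from tinygrad.runtime.lib import LRUAllocator

class ToyBuf:                                        # hypothetical stand-in for a device buffer
  def __init__(self, tag): self.tag = tag

class ToyAllocator(LRUAllocator):
  def __init__(self, sz): super().__init__(dev_memsz=sz); self.freed = []
  def _do_alloc(self, size, dtype, device, **kwargs): return ToyBuf(kwargs.get("tag"))
  def _do_free(self, buf): self.freed.append(buf.tag)

lru = ToyAllocator(3*16*dtypes.float32.itemsize)     # device "fits" exactly three 16-float buffers
bufs = [lru.alloc(16, dtypes.float32, tag=i) for i in range(3)]
for b in bufs: lru.free(b)                           # free() only caches; nothing is released yet

reused = lru.alloc(16, dtypes.float32)               # the newest cached buffer comes back first
assert reused.tag == 2

lru.alloc(64, dtypes.float32)                        # free_space would go negative -> evict, oldest first
print(lru.freed)                                     # [0, 1]

Note that free_space stays debited while a reused buffer is alive, so the cap applies to live plus cached device memory, not to the cache alone.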
8 changes: 6 additions & 2 deletions tinygrad/runtime/ops_cuda.py
@@ -4,7 +4,7 @@
 from pycuda.compiler import compile as cuda_compile # type: ignore
 from tinygrad.helpers import DEBUG, getenv, colored, fromimport
 from tinygrad.ops import Compiled
-from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer
+from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
 from tinygrad.codegen.linearizer import LinearizerOptions
 from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage

@@ -47,8 +47,12 @@ class device:
 else:
   import pycuda.autoprimaryctx # type: ignore # pylint: disable=unused-import # noqa: F401
   import pycuda.driver as cuda # type: ignore
+  class CUDAAllocator(LRUAllocator):
+    def _do_alloc(self, size, dtype, device, **kwargs): return cuda.mem_alloc(size * dtype.itemsize) # type: ignore
+    def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.
+  CUDAAlloc = CUDAAllocator(pycuda.driver.Context.get_device().total_memory())
 class RawCUDABuffer(RawBufferCopyInOut): # type: ignore
-  def __init__(self, size, dtype): super().__init__(size, dtype, cuda.mem_alloc(size * dtype.itemsize)) # type: ignore
+  def __init__(self, size, dtype): super().__init__(size, dtype, allocator=CUDAAlloc)
   def _copyin(self, x:np.ndarray, stream:Optional[cuda.Stream]=None): cuda.memcpy_htod_async(self._buf, x.ravel(), stream) # type: ignore
   def _copyout(self, x:np.ndarray): cuda.memcpy_dtoh(x, self._buf) # type: ignore
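
One more note on the _cached_bufkey override above: keying the CUDA cache by byte size alone lets buffers of different dtypes but equal byte length share a cached allocation. A minimal sketch of the effect (ByteKeyAllocator is a hypothetical stand-in mirroring the override, with no CUDA dependency):

from tinygrad.helpers import dtypes
from tinygrad.runtime.lib import LRUAllocator

class ByteKeyAllocator(LRUAllocator):
  def _do_alloc(self, size, dtype, device, **kwargs): return object()  # stand-in for cuda.mem_alloc
  def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize)

lru = ByteKeyAllocator(1<<20)
a = lru.alloc(16, dtypes.float32)   # 64 bytes
lru.free(a)
b = lru.alloc(32, dtypes.float16)   # also 64 bytes -> the buffer cached from `a` is reused
assert a is b                       # with the default (device, size, dtype) key this would be a fresh allocation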
(diffs for the remaining 4 changed files are not shown)
