Skip to content

Commit

Permalink
Assert for memory allocation failures (tinygrad#2337)
Browse files Browse the repository at this point in the history
* assert adequate memory has been freed

* cleaned up runtime error message

* improved metal buffer alloc error catching and reporting

* decreased lines and altered messages

* removed unnecessary  _get_cur_free_space() call

* improved assert message

* added allocate massive buffer test

* added test_lru_allocator_metal_max_buffer_length

* split into two asserts and removed walrus assignment from assert expression

* update assert message and use byte data type for clarity
  • Loading branch information
imaolo committed Nov 17, 2023
1 parent aa01a63 commit 0d0c74b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/test_allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,5 +174,15 @@ def sp():
test()
check_gc()

def test_lru_allocator_massive_buffer(self):
with self.assertRaises(AssertionError) as context: alloc(allocator := FakeAllocator(), size := 1e13, dtypes.int8)
self.assertEqual(str(context.exception), f"out of memory - requested: {size/1e9:5.2f} GB, available: {allocator._get_cur_free_space('0')/1e9:5.2f} GB")

@unittest.skipIf(Device.DEFAULT != "METAL", "only applies to Metal")
def test_lru_allocator_metal_max_buffer_length(self):
from tinygrad.runtime.ops_metal import METAL
with self.assertRaises(AssertionError) as context: METAL.allocator._do_alloc(buf_len := (max_buf_len := METAL.device.maxBufferLength()+1), dtypes.int8, '0')
self.assertEqual(str(context.exception), f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB.")

if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tinygrad/runtime/lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def ensure_has_free_space(self, space_to_free, device):
while len(self.aging_order[device]) and self._get_cur_free_space(device) < space_to_free: # When OOM removing lru buffers.
bucket, epoch = self.aging_order[device].popleft()
if self.cached_buffers[bucket] and self.cached_buffers[bucket][-1][1] == epoch: self._free_buffer(self.cached_buffers[bucket].pop()[0]) # Free cached buffer if it is still in cache.
assert (curr_free := self._get_cur_free_space(device)) > space_to_free, f"out of memory - requested: {space_to_free/1e9:5.2f} GB, available: {curr_free/1e9:5.2f} GB"

def _alloc_buffer(self, size, dtype, device, **kwargs):
self.ensure_has_free_space(size*dtype.itemsize, device)
Expand Down
7 changes: 6 additions & 1 deletion tinygrad/runtime/ops_metal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
from tinygrad.shape.symbolic import Variable, Node

class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
def _do_alloc(self, size, dtype, device, **kwargs):
buf_len, max_buf_len = size*dtype.itemsize, METAL.device.maxBufferLength()
assert buf_len < max_buf_len, f"Buffer length of {buf_len/1e9:5.2f} GB exceeds Metal's max buffer length of {max_buf_len/1e9:5.2f} GB."
buf = METAL.device.newBufferWithLength_options_(buf_len, Metal.MTLResourceStorageModeShared)
assert buf, f"Metal buffer allocation failed with {buf}."
return buf
def _do_free(self, buf): buf.release()
def _cached_bufkey(self, size, dtype, device): return (device, size*dtype.itemsize) # Buffers of the same length could be reused, no matter what dtype.

Expand Down

0 comments on commit 0d0c74b

Please sign in to comment.