Skip to content

Commit

Permalink
JIT support in Interpreted (tinygrad#2314)
Browse files Browse the repository at this point in the history
* factor that out

* jit is supported everywhere

* fix some tests

* there's no jit supported device, the jit is everywhere

* fix test uops
  • Loading branch information
geohot committed Nov 15, 2023
1 parent 9a20bc0 commit 70a65c2
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 141 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
repos:
- repo: local
hooks:
- id: mypy
name: mypy
entry: mypy tinygrad/ extra/helpers.py
language: system
always_run: true
pass_filenames: false
- id: ruff
name: ruff
entry: ruff .
Expand All @@ -19,12 +25,6 @@ repos:
language: system
always_run: true
pass_filenames: false
- id: mypy
name: mypy
entry: mypy tinygrad/ extra/helpers.py
language: system
always_run: true
pass_filenames: false
- id: tests
name: subset of TORCH tests
entry: env PYTHONPATH="." TORCH=1 python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
Expand Down
4 changes: 2 additions & 2 deletions examples/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from tinygrad.nn import Embedding, Linear
from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters
from tinygrad.helpers import GlobalCounters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.shape.symbolic import Variable

MAX_CONTEXT = 1024
JIT = getenv("JIT", 0 if CI else int(Device.DEFAULT in JIT_SUPPORTED_DEVICE))
JIT = getenv("JIT", 0 if CI else 1)

# https://github.com/facebookresearch/llama/blob/1076b9c51c77ad06e9d7ba8a4c6df775741732bd/llama/model.py#L47
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
Expand Down
3 changes: 1 addition & 2 deletions test/external/external_test_jit_on_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.helpers import dtypes, CI
from tinygrad.ops import Device
from test.helpers import derandomize_model
Expand All @@ -14,7 +14,6 @@ def helper_test_jitted_correctness(gen, train, train_jit):
for _ in range(5): jit = train_jit(*gen()).numpy()
np.testing.assert_allclose(nojit, jit, rtol=1e-3, atol=1e-5)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE, "needs JIT")
class TestJittedModels(unittest.TestCase):
def test_jitted_tiny_llama(self):
old_type = Tensor.default_type
Expand Down
8 changes: 4 additions & 4 deletions test/models/test_real_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
from tinygrad.ops import Device, GlobalCounters
from tinygrad.helpers import CI, dtypes, getenv, prod
from test.helpers import derandomize_model
Expand Down Expand Up @@ -47,7 +47,7 @@ def test_stable_diffusion(self):
def test(t, t2): return model(t, 801, t2).realize()
helper_test("test_sd", lambda: (Tensor.randn(1, 4, 64, 64),Tensor.randn(1, 77, 768)), test, 18.0, 967)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_llama(self):
Tensor.default_type = dtypes.float16

Expand All @@ -59,7 +59,7 @@ def test(t): return model(t, 0).realize()
# NOTE: only test one pass, not testing the dynamic shape autoregressive part
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM"] or not CI), "needs JIT, too long on CI LLVM")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_gpt2(self):
Tensor.default_type = dtypes.float16

Expand All @@ -70,7 +70,7 @@ def test_gpt2(self):
def test(t): return model(t, 0).realize()
helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 140 if CI else 396, all_jitted=True)

@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and (Device.DEFAULT not in ["LLVM", "CLANG"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CLANG", "CPU"] or not CI), "needs JIT, too long on CI LLVM and CLANG")
def test_train_cifar(self):
# TODO: with default device
#old_default = Device.DEFAULT
Expand Down
5 changes: 3 additions & 2 deletions test/test_custom_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# *** first, we implement the atan2 op at the lowest level ***
# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
from tinygrad.lazy import LazyBuffer, create_lazybuffer
from tinygrad.ops import ASTRunner, Device
from tinygrad.ops import CompiledASTRunner, Device
from tinygrad.shape.shapetracker import ShapeTracker
import pytest

Expand All @@ -20,7 +20,7 @@ def atan2_gpu(ret:LazyBuffer, a:LazyBuffer, b:LazyBuffer):
assert a.device == "GPU" and b.device == "GPU", "gpu function requires GPUBuffers"
assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
ret.realized = Device[ret.device].buffer(prod(ret.shape), ret.dtype)
ASTRunner("atan2_gpu", """
CompiledASTRunner(None, "atan2_gpu", """
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
int idx = get_global_id(0);
c[idx] = atan2(a[idx], b[idx]);
Expand Down Expand Up @@ -89,6 +89,7 @@ def test_atan2_backward(self):
np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5)
np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5)

@unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
def test_atan2_jit(self):
# custom ops even work in the JIT!
from tinygrad.jit import TinyJit
Expand Down
4 changes: 2 additions & 2 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
import unittest
import numpy as np
from tinygrad.tensor import Tensor, Device
from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
from tinygrad.jit import TinyJit
import pytest

pytestmark = pytest.mark.webgpu

# NOTE: METAL fails, might be platform and optimization options dependent.
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
@unittest.skipUnless(Device.DEFAULT not in ["METAL", "WEBGPU"], f"no JIT on {Device.DEFAULT}")
class TestJit(unittest.TestCase):
def test_simple_jit(self):
@TinyJit
Expand Down
1 change: 0 additions & 1 deletion test/test_symbolic_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy as np

@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
@unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported")
class TestSymbolicJit(unittest.TestCase):
def test_plus1(self):
def f(a): return (a+1).realize()
Expand Down
1 change: 0 additions & 1 deletion test/test_symbolic_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import unittest
from tinygrad.jit import JIT_SUPPORTED_DEVICE
from tinygrad.shape.symbolic import Variable
from tinygrad.helpers import getenv
from tinygrad.tensor import Tensor, Device
Expand Down
8 changes: 4 additions & 4 deletions test/test_uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
import numpy as np
from tinygrad.helpers import dtypes, getenv, DType, PtrDType
from tinygrad.tensor import Device
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp

def _uops_to_prg(uops):
src, runtime_args = Device[Device.DEFAULT].renderer("test", uops)
return ASTRunner("test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)
return CompiledASTRunner(None, "test", src,
[1] if Device[Device.DEFAULT].linearizer_opts.has_local else None, [1] if Device[Device.DEFAULT].linearizer_opts.has_local else None,
runtime_args=runtime_args).build(Device[Device.DEFAULT].compiler, Device[Device.DEFAULT].runtime)

def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], vin:Tuple[UOp, ...], arg:Any=None) -> UOp:
uops.append(UOp(uop, dtype, tuple(vin), arg))
Expand Down
4 changes: 0 additions & 4 deletions tinygrad/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from tinygrad.shape.symbolic import Variable
from weakref import ref, WeakKeyDictionary

JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]

class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
Expand All @@ -28,8 +26,6 @@ def input_replace(self) -> Dict[Tuple[int, int], Union[int, str]]: return self.j
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT.split(":")[0] not in JIT_SUPPORTED_DEVICE: return self.fxn(*args, **kwargs) # only jit on supported device

# all inputs are realized
input_tensors: Dict[Union[int, str], Tensor] = {cast(Union[int, str], k):v.realize() for k,v in itertools.chain(enumerate(args), kwargs.items()) if v.__class__ is Tensor}
expected_sts_dtype = tuple([(v.lazydata.st.unbind(), v.dtype) for v in input_tensors.values()])
Expand Down
Loading

0 comments on commit 70a65c2

Please sign in to comment.