Skip to content

Commit

Permalink
fix test_real_world llama (tinygrad#2335)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenyuxyz committed Nov 17, 2023
1 parent 3b9dd33 commit 3971259
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions test/models/test_real_world.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from tinygrad.nn.state import get_parameters
from tinygrad.jit import TinyJit
from tinygrad.ops import Device, GlobalCounters
from tinygrad.helpers import CI, dtypes, getenv, prod
from tinygrad.helpers import CI, dtypes
from test.helpers import derandomize_model

from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
Expand Down Expand Up @@ -56,8 +56,8 @@ def test_llama(self):
derandomize_model(model)
@TinyJit
def test(t): return model(t, 0).realize()
# NOTE: only test one pass, not testing the dynamic shape autoregressive part
helper_test("test_llama", lambda: (Tensor([[1,]]),), test, 0.22 if CI else 13.5, 137 if CI else 521, all_jitted=True)
# TODO: test first token vs rest properly, also memory test is broken with CacheCollector
helper_test("test_llama", lambda: (Tensor([[1,2,3,4]]),), test, 0.22 if CI else 13.5, 181 if CI else 685, all_jitted=True)

@unittest.skipUnless((Device.DEFAULT not in ["LLVM", "CPU"] or not CI), "needs JIT, too long on CI LLVM")
def test_gpt2(self):
Expand Down

0 comments on commit 3971259

Please sign in to comment.