Add gpt-3.5-turbo-16k support to ctx len getter (openai#1388)
**What:** Adds support for `gpt-3.5-turbo-16k` to `n_ctx_from_model_name`.
**Why:** `n_ctx_from_model_name` currently returns 4096 for `gpt-3.5-turbo-16k`, even though that model's context length is 16384.

Co-authored-by: Ian McKenzie <[email protected]>
danesherbs and ianmckenzie-oai committed Jan 3, 2024
1 parent 1dd2ea2 commit bbe26f8
Showing 2 changed files with 15 additions and 8 deletions.
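For context on the change below: `n_ctx_from_model_name` resolves dated snapshot names (e.g. `gpt-3.5-turbo-16k-0613`) by prefix, and the new `"gpt-3.5-turbo-16k-"` entry is placed ahead of the shorter `"gpt-3.5-turbo-"` prefix, presumably so the more specific prefix wins. The lookup loop itself is not part of this diff, so the sketch below only illustrates that first-match behavior; it is not the repository's implementation.

```python
from typing import Optional

# Illustration only: mimics the prefix ordering added in this commit,
# not the actual lookup code in evals/registry.py (which is outside the diff).
PREFIX_AND_N_CTX = [
    ("gpt-3.5-turbo-16k-", 16384),  # more specific prefix listed first
    ("gpt-3.5-turbo-", 4096),
    ("gpt-4-32k-", 32768),
    ("gpt-4-", 8192),
]


def n_ctx_from_prefix(model_name: str) -> Optional[int]:
    # First matching prefix wins, which is why the 16k entry must precede
    # the generic gpt-3.5-turbo- entry.
    for prefix, n_ctx in PREFIX_AND_N_CTX:
        if model_name.startswith(prefix):
            return n_ctx
    return None


assert n_ctx_from_prefix("gpt-3.5-turbo-16k-0613") == 16384
assert n_ctx_from_prefix("gpt-3.5-turbo-0613") == 4096
```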
evals/registry.py (10 changes: 7 additions & 3 deletions)
@@ -35,9 +35,10 @@


 def n_ctx_from_model_name(model_name: str) -> Optional[int]:
-    """Returns n_ctx for a given API model name. Model list last updated 2023-06-16."""
+    """Returns n_ctx for a given API model name. Model list last updated 2023-10-24."""
     # note that for most models, the max tokens is n_ctx + 1
     PREFIX_AND_N_CTX: list[tuple[str, int]] = [
+        ("gpt-3.5-turbo-16k-", 16384),
         ("gpt-3.5-turbo-", 4096),
         ("gpt-4-32k-", 32768),
         ("gpt-4-", 8192),
@@ -55,6 +56,7 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]:
         "text-davinci-002": 4096,
         "text-davinci-003": 4096,
         "gpt-3.5-turbo": 4096,
+        "gpt-3.5-turbo-16k": 16384,
         "gpt-4": 8192,
         "gpt-4-32k": 32768,
         "gpt-4-base": 8192,
@@ -77,13 +79,15 @@ def is_chat_model(model_name: str) -> bool:
     if model_name in {"gpt-4-base"}:
         return False

-    CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-4", "gpt-4-32k"}
+    CHAT_MODEL_NAMES = {"gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4", "gpt-4-32k"}

     if model_name in CHAT_MODEL_NAMES:
         return True

-    for model_prefix in {"gpt-3.5-turbo-", "gpt-4-", "gpt-4-32k-"}:
+    for model_prefix in {"gpt-3.5-turbo-", "gpt-4-"}:
         if model_name.startswith(model_prefix):
             return True

     return False

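Worth noting in the `is_chat_model` hunk above: dropping `"gpt-4-32k-"` from the prefix loop is safe because `"gpt-4-"` already matches every dated 32k snapshot. A standalone check (not part of the commit):

```python
# "gpt-4-" subsumes "gpt-4-32k-", so dated 32k snapshots are still treated
# as chat models after the prefix set is trimmed.
for name in ("gpt-4-32k-0314", "gpt-4-32k-0613"):
    assert name.startswith("gpt-4-")
```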
evals/registry_test.py (13 changes: 8 additions & 5 deletions)
@@ -2,22 +2,25 @@


 def test_n_ctx_from_model_name():
     assert n_ctx_from_model_name("gpt-3.5-turbo") == 4096
     assert n_ctx_from_model_name("gpt-3.5-turbo-0613") == 4096
+    assert n_ctx_from_model_name("gpt-3.5-turbo-16k") == 16384
+    assert n_ctx_from_model_name("gpt-3.5-turbo-16k-0613") == 16384
     assert n_ctx_from_model_name("gpt-4") == 8192
     assert n_ctx_from_model_name("gpt-4-0314") == 8192
     assert n_ctx_from_model_name("gpt-4-0613") == 8192
     assert n_ctx_from_model_name("gpt-4-32k") == 32768
     assert n_ctx_from_model_name("gpt-4-32k-0314") == 32768
     assert n_ctx_from_model_name("gpt-4-32k-0613") == 32768


 def test_is_chat_model():
     assert is_chat_model("gpt-3.5-turbo")
     assert is_chat_model("gpt-3.5-turbo-0314")
     assert is_chat_model("gpt-3.5-turbo-0613")
+    assert is_chat_model("gpt-3.5-turbo-16k")
+    assert is_chat_model("gpt-3.5-turbo-16k-0613")
     assert is_chat_model("gpt-4")
     assert is_chat_model("gpt-4-0314")
     assert is_chat_model("gpt-4-0613")
     assert is_chat_model("gpt-4-32k")
     assert is_chat_model("gpt-4-32k-0314")
     assert is_chat_model("gpt-4-32k-0613")
     assert not is_chat_model("text-davinci-003")
     assert not is_chat_model("gpt4-base")
     assert not is_chat_model("code-davinci-002")

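The new assertions double as a quick smoke test. With the repository installed, something like the following should pass; the import path is assumed from the file layout, since the test module's import line sits outside the displayed hunk:

```python
# Import path assumed from the repository layout; not shown in the diff above.
from evals.registry import is_chat_model, n_ctx_from_model_name

# The behaviors this commit adds:
assert n_ctx_from_model_name("gpt-3.5-turbo-16k") == 16384
assert n_ctx_from_model_name("gpt-3.5-turbo-16k-0613") == 16384
assert is_chat_model("gpt-3.5-turbo-16k")
assert is_chat_model("gpt-3.5-turbo-16k-0613")
```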