# test_llm.py
import openai
import pytest
from langroid.cachedb.redis_cachedb import RedisCacheConfig
from langroid.language_models.base import LLMMessage, Role
from langroid.language_models.openai_gpt import (
OpenAIChatModel,
OpenAICompletionModel,
OpenAIGPT,
OpenAIGPTConfig,
)
from langroid.parsing.parser import Parser, ParsingConfig
from langroid.parsing.utils import generate_random_sentences
from langroid.utils.configuration import Settings, set_global

# allow streaming globally, but can be turned off by individual models
set_global(Settings(stream=True))
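# For example (hypothetical config, for illustration only): a single model can
# opt out of streaming regardless of the global setting above, as the first
# test below does via its `stream=` kwarg:
#     cfg = OpenAIGPTConfig(stream=False)  # model-level override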


@pytest.mark.parametrize(
    "streaming, country, capital",
    [(True, "France", "Paris"), (False, "India", "Delhi")],
)
def test_openai_gpt(test_settings: Settings, streaming, country, capital):
    set_global(test_settings)
    cfg = OpenAIGPTConfig(
        stream=streaming,  # use streaming output if enabled globally
        type="openai",
        max_output_tokens=100,
        min_output_tokens=10,
        chat_model=(
            OpenAIChatModel.GPT3_5_TURBO
            if test_settings.gpt3_5
            else OpenAIChatModel.GPT4  # or GPT4_TURBO
        ),
        completion_model=OpenAICompletionModel.GPT3_5_TURBO_INSTRUCT,
        cache_config=RedisCacheConfig(fake=False),
    )
    mdl = OpenAIGPT(config=cfg)
    question = "What is the capital of " + country + "?"

    set_global(Settings(cache=False))
    # chat mode via `generate`,
    # i.e. use the same call as for completion, but the setting below
    # actually calls `chat` under the hood
    cfg.use_chat_for_completion = True
    # check that `generate` works when `use_chat_for_completion` is True
    response = mdl.generate(prompt=question, max_tokens=10)
    assert capital in response.message
    assert not response.cached

    # actual chat mode
    messages = [
        LLMMessage(
            role=Role.SYSTEM,
            content="You are a serious, helpful assistant. Be very concise, not funny",
        ),
        LLMMessage(role=Role.USER, content=question),
    ]
    response = mdl.chat(messages=messages, max_tokens=10)
    assert capital in response.message
    assert not response.cached

    set_global(Settings(cache=True))
    # should be from cache this time
    response = mdl.chat(messages=messages, max_tokens=10)
    assert capital in response.message
    assert response.cached
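

# A minimal usage sketch distilled from the test above (hypothetical helper,
# not part of the original suite): configure a model, send one chat turn, and
# read back the text. `fake=True` is assumed here so the cache backend does
# not need a live Redis instance.
def _example_chat_sketch() -> str:
    cfg = OpenAIGPTConfig(
        stream=False,
        chat_model=OpenAIChatModel.GPT4,
        cache_config=RedisCacheConfig(fake=True),  # assumed: fake (in-memory) cache
    )
    mdl = OpenAIGPT(config=cfg)
    messages = [
        LLMMessage(role=Role.SYSTEM, content="You are a concise assistant."),
        LLMMessage(role=Role.USER, content="What is the capital of France?"),
    ]
    # same call pattern as the assertions above; `.message` holds the text
    return mdl.chat(messages=messages, max_tokens=10).message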


@pytest.mark.parametrize(
    "mode, max_tokens",
    [("completion", 100), ("chat", 100), ("completion", 1000), ("chat", 1000)],
)
def _test_context_length_error(test_settings: Settings, mode: str, max_tokens: int):
    """
    Test disabled (note the leading underscore, which keeps pytest from
    collecting it), see TODO below. It also takes too long, since it tries to
    verify that the expected error is raised when the context length is
    exceeded.

    Args:
        test_settings: from conftest.py
        mode: "completion" or "chat"
        max_tokens: number of tokens to generate
    """
    set_global(test_settings)
    set_global(Settings(cache=False))
    cfg = OpenAIGPTConfig(
        stream=False,
        max_output_tokens=max_tokens,
        chat_model=OpenAIChatModel.GPT4,  # or GPT4_TURBO
        completion_model=OpenAICompletionModel.TEXT_DA_VINCI_003,
        cache_config=RedisCacheConfig(fake=False),
    )
    parser = Parser(config=ParsingConfig())
    llm = OpenAIGPT(config=cfg)
    context_length = (
        llm.chat_context_length() if mode == "chat" else llm.completion_context_length()
    )

    # estimate tokens per sentence, then build a message ~1.5x the context length
    toks_per_sentence = int(parser.num_tokens(generate_random_sentences(1000)) / 1000)
    max_sentences = int(context_length * 1.5 / toks_per_sentence)
    big_message = generate_random_sentences(max_sentences + 1)
    big_message_tokens = parser.num_tokens(big_message)
    assert big_message_tokens + max_tokens > context_length

    response = None
    # TODO need to figure out what error type to expect here
    with pytest.raises(openai.BadRequestError) as e:
        if mode == "chat":
            response = llm.chat(big_message, max_tokens=max_tokens)
        else:
            response = llm.generate(prompt=big_message, max_tokens=max_tokens)
    assert response is None
    assert "context length" in str(e.value).lower()