forked from openai/evals — cot.py (66 lines, 53 loc, 2.62 KB)
"""
Extending Completion Functions with Chain-of-Thought
"""
from typing import Optional

from evals.api import CompletionFn, CompletionResult
from evals.prompt.base import ChatCompletionPrompt
from evals.record import record_sampling
from evals.registry import Registry
# Appended to the prompt as an assistant turn to elicit step-by-step
# reasoning before the model commits to an answer.
DEFAULT_COT_TEMPLATE = "\nBefore answering, reason in a step-by-step manner as to get the right answer, then conclude with the answer."
# Appended after the reasoning turn to ask the model to restate only the
# final answer, in the format the original question requested.
DEFAULT_EXTRACT_ANSWER_TEMPLATE = (
    "\nGiven the above reasoning, the answer in the format requested by the question is:"
)
class ChainOfThoughtCompletionResult(CompletionResult):
    """Completion result holding the final answer extracted after a
    chain-of-thought run."""

    def __init__(self, response) -> None:
        # Raw text of the extracted answer as produced by the model.
        self.response = response

    def get_completions(self) -> list[str]:
        """Return the single completion, with surrounding whitespace removed."""
        cleaned = self.response.strip()
        return [cleaned]
class ChainOfThoughtCompletionFn(CompletionFn):
    """Completion function that first elicits step-by-step reasoning from one
    model, then asks a (possibly different) model to extract the final answer
    from that reasoning."""

    def __init__(
        self,
        cot_template: str = DEFAULT_COT_TEMPLATE,
        extract_answer_template: str = DEFAULT_EXTRACT_ANSWER_TEMPLATE,
        cot_completion_fn: Optional[str] = None,
        extract_completion_fn: Optional[str] = None,
        registry: Optional[Registry] = None,
        registry_path: Optional[str] = None,
        **kwargs
    ) -> None:
        """Build the two underlying completion functions from the registry.

        Args:
            cot_template: Text appended to the prompt to elicit step-by-step
                reasoning.
            extract_answer_template: Text appended after the reasoning to
                elicit the final answer.
            cot_completion_fn: Registry name of the model that reasons.
            extract_completion_fn: Registry name of the model that extracts
                the answer; defaults to ``cot_completion_fn`` when omitted.
            registry: Registry used to resolve completion-fn names; a fresh
                one is created when not supplied.
            registry_path: Optional extra path added to the registry.
            **kwargs: Accepted for forward/config compatibility; unused here.
        """
        # Explicit None check (rather than truthiness) so a caller-supplied
        # registry is always honored.
        registry = registry if registry is not None else Registry()
        if registry_path:
            registry.add_registry_paths(registry_path)
        # Fall back to the reasoning model for extraction when no dedicated
        # extractor is configured.
        if extract_completion_fn is None:
            extract_completion_fn = cot_completion_fn
        # This model will use chain of thought to answer the question.
        self.cot_template = cot_template
        self.cot_completion_fn_instance = registry.make_completion_fn(cot_completion_fn)
        # This model will extract the answer from the chain of thought.
        self.extract_answer_template = extract_answer_template
        self.extract_completion_fn_instance = registry.make_completion_fn(extract_completion_fn)

    def __call__(self, prompt, **kwargs) -> ChainOfThoughtCompletionResult:
        """Run the two-stage CoT pipeline and return the extracted answer.

        Both stages are recorded via ``record_sampling``; ``**kwargs`` are
        forwarded to both underlying completion functions.
        """
        # Normalize the incoming prompt to the chat-message list format.
        prompt = ChatCompletionPrompt(prompt).to_formatted_prompt()
        # Stage 1: append the CoT nudge and sample the reasoning + answer.
        cot_prompt = prompt + [{"role": "assistant", "content": self.cot_template}]
        answer = self.cot_completion_fn_instance(prompt=cot_prompt, **kwargs).get_completions()[0]
        record_sampling(prompt=cot_prompt, sampled=answer)
        # Stage 2: feed the reasoning back and ask for just the final answer.
        extraction_prompt = cot_prompt + [
            {"role": "assistant", "content": answer},
            {"role": "assistant", "content": self.extract_answer_template},
        ]
        extracted_answer = self.extract_completion_fn_instance(
            prompt=extraction_prompt, **kwargs
        ).get_completions()[0]
        record_sampling(prompt=extraction_prompt, sampled=extracted_answer)
        return ChainOfThoughtCompletionResult(extracted_answer)