Skip to content

Commit

Permalink
add langchain llm math chain (openai#631)
Browse files Browse the repository at this point in the history
Add LLMMath chain as a completion function

Is a bit of misnomer, because it is basically and LLM plus PythonREPL,
so can solve more than math problems

`oaieval langchain/chains/llm_math bigrams --max_samples 20` gives me an
accuracy of 100%

Example on a single input:

<img width="1043" alt="Screenshot 2023-04-10 at 10 48 02 PM"
src="https://user-images.githubusercontent.com/11986836/231067743-f34b7458-0a68-4d84-9489-8a4f61756b95.png">
  • Loading branch information
hwchase17 committed Apr 11, 2023
1 parent b4e5142 commit 2675cbc
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 1 deletion.
30 changes: 30 additions & 0 deletions evals/completion_fns/langchain_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import importlib
from typing import Optional

from langchain import OpenAI, LLMMathChain

from evals.prompt.base import CompletionPrompt
from evals.record import record_sampling


class LangChainCompletionResult:
def __init__(self, response) -> None:
self.response = response

def get_completions(self) -> list[str]:
return [self.response.strip()]


class LangChainMathChainCompletionFn:
def __init__(self, **kwargs) -> None:
llm = OpenAI(temperature=0)
self.llm_math = LLMMathChain(llm=llm)

def __call__(self, prompt, **kwargs) -> LangChainCompletionResult:

prompt = CompletionPrompt(prompt).to_formatted_prompt()
response = self.llm_math.run(prompt)
# The LangChain response comes with `Answer: ` ahead of this, let's strip it out
response = response.strip("Answer:").strip()
record_sampling(prompt=prompt, sampled=response)
return LangChainCompletionResult(response)
4 changes: 3 additions & 1 deletion evals/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def make_completion_fn(self, name: str) -> CompletionFn:
spec = self.get_completion_fn(name)
if spec is None:
raise ValueError(f"Could not find CompletionFn in the registry with ID {name}")

if spec.args is None:
spec.args = {}

spec.args["registry"] = self
instance = make_object(spec.cls)(**spec.args or {})
assert isinstance(instance, CompletionFn), f"{name} must be a CompletionFn"
Expand Down
2 changes: 2 additions & 0 deletions evals/registry/completion_fns/langchain_chains.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
langchain/chains/llm_math:
class: evals.completion_fns.langchain_math:LangChainMathChainCompletionFn

0 comments on commit 2675cbc

Please sign in to comment.