From 2675cbce23a4478592a17942487e6de83dc94478 Mon Sep 17 00:00:00 2001 From: Harrison Chase Date: Mon, 10 Apr 2023 23:03:17 -0700 Subject: [PATCH] add langchain llm math chain (#631) Add the LLMMath chain as a completion function. The name is a bit of a misnomer, because it is basically an LLM plus a PythonREPL, so it can solve more than math problems. `oaieval langchain/chains/llm_math bigrams --max_samples 20` gives me an accuracy of 100%. Example on a single input: Screenshot 2023-04-10 at 10 48 02 PM --- evals/completion_fns/langchain_math.py | 30 +++++++++++++++++++ evals/registry.py | 4 ++- .../completion_fns/langchain_chains.yaml | 2 ++ 3 files changed, 35 insertions(+), 1 deletion(-) create mode 100644 evals/completion_fns/langchain_math.py create mode 100644 evals/registry/completion_fns/langchain_chains.yaml diff --git a/evals/completion_fns/langchain_math.py b/evals/completion_fns/langchain_math.py new file mode 100644 index 0000000000..7b06b3a490 --- /dev/null +++ b/evals/completion_fns/langchain_math.py @@ -0,0 +1,30 @@ +import importlib +from typing import Optional + +from langchain import OpenAI, LLMMathChain + +from evals.prompt.base import CompletionPrompt +from evals.record import record_sampling + + +class LangChainCompletionResult: + def __init__(self, response) -> None: + self.response = response + + def get_completions(self) -> list[str]: + return [self.response.strip()] + + +class LangChainMathChainCompletionFn: + def __init__(self, **kwargs) -> None: + llm = OpenAI(temperature=0) + self.llm_math = LLMMathChain(llm=llm) + + def __call__(self, prompt, **kwargs) -> LangChainCompletionResult: + + prompt = CompletionPrompt(prompt).to_formatted_prompt() + response = self.llm_math.run(prompt) + # The LangChain response comes with `Answer: ` ahead of this, let's strip it out + response = response.strip("Answer:").strip() + record_sampling(prompt=prompt, sampled=response) + return LangChainCompletionResult(response) diff --git a/evals/registry.py b/evals/registry.py index 
9e91ccbc7b..601b518000 100644 --- a/evals/registry.py +++ b/evals/registry.py @@ -104,7 +104,9 @@ def make_completion_fn(self, name: str) -> CompletionFn: spec = self.get_completion_fn(name) if spec is None: raise ValueError(f"Could not find CompletionFn in the registry with ID {name}") - + if spec.args is None: + spec.args = {} + spec.args["registry"] = self instance = make_object(spec.cls)(**spec.args or {}) assert isinstance(instance, CompletionFn), f"{name} must be a CompletionFn" diff --git a/evals/registry/completion_fns/langchain_chains.yaml b/evals/registry/completion_fns/langchain_chains.yaml new file mode 100644 index 0000000000..7a24650074 --- /dev/null +++ b/evals/registry/completion_fns/langchain_chains.yaml @@ -0,0 +1,2 @@ +langchain/chains/llm_math: + class: evals.completion_fns.langchain_math:LangChainMathChainCompletionFn