Updates for Solvers (openai#1461)
This PR updates our Solvers infrastructure:
- Add a new README to onboard users who want to work with solvers (beta)
- Create a separate folder for solver registration: `evals/registry/solvers`
- Refactor previous solver code for reusability: NestedSolvers allow you to
chain multiple solvers
- New solvers: FewShotSolver, SelfConsistencySolver, OpenAIAssistantsSolver
- Add a `defaults.yaml` for commonly reused solvers
- Change the Solver's abstract action method from `__call__` to `_solve`
so that the task state a solver receives stays immutable (a minimal sketch
follows below)
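For orientation, here is a minimal sketch of a custom solver under the new interface. It is illustrative only: the class name and its behavior are made up, and it assumes the base class's `__call__` hands `_solve` a copy of the task state; the imports and the `_solve`/`name` signatures follow the diffs below.

```python
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState


class EchoLastUserSolver(Solver):
    """Toy solver: replies with the last user message verbatim."""

    def _solve(self, task_state: TaskState) -> SolverResult:
        # Mutations here are assumed to stay local, since the base class
        # is expected to copy the task state before calling _solve.
        task_state.messages.append(Message("system", "Answer concisely."))
        last_user = next(
            (m.content for m in reversed(task_state.messages) if m.role == "user"),
            "",
        )
        return SolverResult(last_user)

    @property
    def name(self) -> str:
        return "echo-last-user"
```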

---------

Co-authored-by: johny-b <[email protected]>
Co-authored-by: ojaffe <[email protected]>
Co-authored-by: Andrei Alexandru <[email protected]>
4 people committed Jan 29, 2024
1 parent 3040d6f commit 4a105ae
Showing 32 changed files with 2,596 additions and 695 deletions.
65 changes: 45 additions & 20 deletions evals/elsuite/bluff/strategy_solver.py
@@ -1,9 +1,11 @@
import copy
import re
from importlib import import_module
from typing import Optional

from evals.elsuite.bluff.bluff.cards import get_bluff_move
from evals.solvers.solver import Solver, SolverResult
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState


@@ -24,7 +26,23 @@ def __init__(
self.max_attempts = max_attempts
self.rethink_strategy_after = rethink_strategy_after

def __call__(self, task_state: TaskState):
# interaction_length=1 to store reasoning step in private memory
self.interaction_cache = PersistentMemoryCache(interaction_length=1)

def _generate_response(self, task_state: TaskState):
"""
Calls the base solver. Modifies the task state to remove all non-reasoning
messages from the assistant.
"""
task_state = copy.deepcopy(task_state)
task_state.messages = [
msg
for msg in task_state.messages
if msg.role != "assistant" or msg.content.startswith("{") or len(msg.content) > 20
]
return self.base_solver(task_state).output

def _solve(self, task_state: TaskState):
"""
This solver does three things that should help the model play better:
1. Adds a strategy guide as the first message (just after the task description)
@@ -35,25 +53,12 @@ def __call__(self, task_state: TaskState):
# GENERAL NOTE.
# This function is pretty ugly. I'm not sure how to implement this better. We decided this is good enough.

# Remove assistant messages added by the main solver (i.e. non-JSON).
# We need len(msg.content) > 20 because we don't want to remove "rethink strategy".
task_state.messages = [
msg
for msg in task_state.messages
if msg.role != "assistant" or msg.content.startswith("{") or len(msg.content) > 20
]
# Before the first move in a game - strategy guide goes first
strategy_msg = Message("system", strategy)
task_state.messages.insert(0, strategy_msg)
task_state.messages = self.interaction_cache.load_private_interaction(task_state)

game = task_state.current_state

if len(game.rounds) == 1 and len(game.rounds[0].moves) < 2:
# Before the first move in a game - strategy guide goes first
strategy_msg = Message("system", strategy)

# This if is important - we might have already tried
# to bid, but gave an invalid bid, so still we have no moves
if strategy_msg not in task_state.messages:
task_state.messages.insert(0, strategy_msg)

if (
self.rethink_strategy_after is not None
and len(game.rounds) == 1 + self.rethink_strategy_after
@@ -67,15 +72,32 @@ def __call__(self, task_state: TaskState):
if strategy_update_msg not in task_state.messages:
last_system_message = task_state.messages.pop()
task_state.messages.append(strategy_update_msg)
response = self.base_solver(task_state).output
response = self._generate_response(task_state)
task_state.messages.append(Message("assistant", response))
task_state.messages.append(last_system_message)

# Manually update interaction cache, since we re-order messages
last_interaction = self.interaction_cache.last_interaction
last_interaction_messages = last_interaction.messages[:-1] + [
Message("system", strategy_update_msg),
Message("assistant", response),
Message("system", last_system_message),
]
last_interaction_private_ids = last_interaction.private_messages_ids + [
len(task_state.messages) - 3,
len(task_state.messages) - 2,
]

self.interaction_cache.last_interaction.messages = last_interaction_messages
self.interaction_cache.last_interaction.private_messages_ids = (
last_interaction_private_ids
)

# If this move_str is preserved, the game engine will have to deal with that
# (and it has some way of solving this problem)
move_str = "[INVALID MOVE]"
for _ in range(self.max_attempts):
response = self.base_solver(task_state).output
response = self._generate_response(task_state)
try:
move_str = self._parse_response(response)
# This will raise ValueError if this is not a valid move
@@ -85,6 +107,9 @@ def __call__(self, task_state: TaskState):
pass

task_state.messages.append(Message("assistant", response))
task_state.messages.append(Message("assistant", move_str))
self.interaction_cache.save_private_interaction(task_state)

return SolverResult(move_str)

@property
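As a condensed restatement of the caching pattern introduced above, here is a rough sketch of a nested solver that keeps private turns in a `PersistentMemoryCache`. The class itself is hypothetical, it skips the base-class constructor, and the cache semantics are inferred from the diff above (re-load the privately stored messages into the transcript, then persist the new turn), so treat it as an assumption rather than the actual API contract.

```python
import copy

from evals.solvers.solver import Solver, SolverResult
from evals.solvers.utils import PersistentMemoryCache
from evals.task_state import Message, TaskState


class CachedReasoningSolver(Solver):
    """Illustrative sketch of the private-memory pattern used above."""

    def __init__(self, base_solver: Solver):
        self.base_solver = base_solver
        # interaction_length=1: one private (hidden) message is kept per turn.
        self.interaction_cache = PersistentMemoryCache(interaction_length=1)

    def _solve(self, task_state: TaskState) -> SolverResult:
        task_state = copy.deepcopy(task_state)
        # Re-attach privately cached messages stripped from the public transcript.
        task_state.messages = self.interaction_cache.load_private_interaction(task_state)
        response = self.base_solver(task_state).output
        # Record the new turn, private messages included, for the next call.
        task_state.messages.append(Message("assistant", response))
        self.interaction_cache.save_private_interaction(task_state)
        return SolverResult(response)

    @property
    def name(self) -> str:
        return "cached-reasoning-solver"
```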
5 changes: 3 additions & 2 deletions evals/elsuite/make_me_pay/solvers/lm_con_artist_solver.py
@@ -6,6 +6,7 @@
EXTRACT_MESSAGE,
TIME_REMINDER_MSG,
)
from evals.solvers.nested.cot_solver import CoTSolver
from evals.solvers.solver import Solver, SolverResult
from evals.task_state import Message, TaskState

@@ -28,7 +29,7 @@ def __init__(

# Add more args to base solver if we're using CoT
base_solver_args_extra = {}
if class_name == "OpenAIChatCompletionCoTSolver":
if issubclass(cls, CoTSolver):
base_solver_args_extra = {
"cot_template": COT_MESSAGE,
"extract_template": EXTRACT_MESSAGE,
@@ -49,7 +50,7 @@ def __init__(
def name(self):
return "Scaffolded-LM-Solver"

def __call__(self, task_state: TaskState, **kwargs) -> SolverResult:
def _solve(self, task_state: TaskState, **kwargs) -> SolverResult:
# Optional additional message for better LM capabilities. Only append if
# this is the start of the conversation; otherwise it is already included in memory
if self.lm_system_prompt:
2 changes: 1 addition & 1 deletion evals/elsuite/sandbagging/solvers.py
@@ -126,7 +126,7 @@ def _construct_prompt(self, task_state: TaskState) -> Sequence[Dict]:

return prompt

def __call__(self, task_state: TaskState, **kwargs) -> (Sequence[Dict], SolverResult):
def _solve(self, task_state: TaskState, **kwargs) -> (Sequence[Dict], SolverResult):
prompt = self._construct_prompt(task_state)
result = self._predict_answer(prompt, **kwargs)

6 changes: 3 additions & 3 deletions evals/elsuite/self_prompting/solvers/baselines.py
@@ -11,7 +11,7 @@ def __init__(
This solver simply returns an empty string as the prompt.
"""

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
@@ -32,7 +32,7 @@ def __init__(
This solver simply returns the original instruction as the prompt.
"""

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
@@ -54,7 +54,7 @@ def __init__(
This solver concatenates the given input-output examples as few-shot demonstrations.
"""

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
2 changes: 1 addition & 1 deletion evals/elsuite/self_prompting/solvers/custom_cot_solver.py
@@ -36,7 +36,7 @@ def __init__(
self.extract_completion_fn = OpenAIChatCompletionFn(**self.completion_fn_options)
self.extract_template = extract_template

def __call__(
def _solve(
self,
task_state: TaskState,
**kwargs,
9 changes: 7 additions & 2 deletions evals/registry.py
@@ -60,6 +60,7 @@ def n_ctx_from_model_name(model_name: str) -> Optional[int]:
"gpt-4": 8192,
"gpt-4-32k": 32768,
"gpt-4-base": 8192,
"gpt-4-1106-preview": 128_000,
}

# first, look for an exact match
@@ -135,7 +136,7 @@ def make_completion_fn(
# No match, so try to find a completion-fn-id in the registry
spec = self.get_completion_fn(name)
if spec is None:
raise ValueError(f"Could not find CompletionFn in the registry with ID {name}")
raise ValueError(f"Could not find CompletionFn/Solver in the registry with ID {name}")
if spec.args is None:
spec.args = {}
spec.args.update(kwargs)
@@ -195,7 +196,7 @@ def get_modelgraded_spec(self, name: str, **kwargs: dict) -> Optional[ModelGrade
)

def get_completion_fn(self, name: str) -> Optional[CompletionFnSpec]:
return self._dereference(name, self._completion_fns, "completion_fn", CompletionFnSpec)
return self._dereference(name, self._completion_fns | self._solvers, "completion_fn", CompletionFnSpec)

def get_eval(self, name: str) -> Optional[EvalSpec]:
return self._dereference(name, self._evals, "eval", EvalSpec)
@@ -303,6 +304,10 @@ def _load_registry(self, registry_paths: Sequence[Path], resource_type: str) ->
def _completion_fns(self) -> RawRegistry:
return self._load_registry(self._registry_paths, "completion_fns")

@functools.cached_property
def _solvers(self) -> RawRegistry:
return self._load_registry(self._registry_paths, "solvers")

@functools.cached_property
def _eval_sets(self) -> RawRegistry:
return self._load_registry(self._registry_paths, "eval_sets")
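With the new `solvers` registry directory wired into `Registry` above, a solver ID registered under `evals/registry/solvers/` resolves through the same `get_completion_fn`/`make_completion_fn` path as a completion function. A rough usage sketch, assuming `Registry` can be constructed with its defaults; the solver ID shown is hypothetical:

```python
from evals.registry import Registry

registry = Registry()

# make_completion_fn now also resolves solver IDs, since get_completion_fn
# searches both the completion_fns and solvers registries (see diff above).
solver = registry.make_completion_fn("bluff/strategy/gpt-4")  # hypothetical ID
```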
77 changes: 0 additions & 77 deletions evals/registry/completion_fns/bluff.yaml

This file was deleted.

92 changes: 0 additions & 92 deletions evals/registry/completion_fns/make-me-pay.yaml

This file was deleted.
