Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Tests for extend generation logic #1184

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
208 changes: 200 additions & 8 deletions tests/predict/test_predict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import dspy
from dspy import Predict, Signature, TypedPredictor
from dspy.utils.dummies import DummyLM
import copy
import textwrap

import pydantic
import pytest
import ujson

import dspy
from dspy import Predict, Signature, TypedPredictor
from dspy.utils.dummies import DummyLM


def test_initialization_with_string_signature():
signature_string = "input1, input2 -> output"
Expand Down Expand Up @@ -209,14 +212,203 @@ class OutputOnlySignature(dspy.Signature):
assert lm.get_convo(-1) == textwrap.dedent(
"""\
Given the fields , produce the fields `output`.

---

Follow the following format.

Output: ${output}

---

Output: short answer"""
)


@pytest.fixture(name="SandwichIdea")
def sandwich_idea_signature():
class SandwichIdea(dspy.Signature):
"""Based on the meal and dietary requirements, suggest a sandwich idea."""

meal: str = dspy.InputField()
dietary_requiements: str = dspy.InputField()
bread: str = dspy.OutputField()
protein: str = dspy.OutputField()
fat: str = dspy.OutputField()
garnish: str = dspy.OutputField()
sauce: str = dspy.OutputField()

return SandwichIdea


def test_extend_generation(SandwichIdea):
lm = DummyLM(
[
" whole wheat\n\nProtein: turkey\n\nFat: avocado",
# Incomplete generation leads to tomato field being assigned as an
# empty string ("") in dsp.primitives.predict l98 the generation
# therefores continues with the next field.
" tomato \n\nSauce: mustard",
]
)
dspy.settings.configure(lm=lm)

prediction = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A")
# The logged conversation (additional newlines removed, [..] indicates the generation):
# === DummyLM ===
# Based on the meal and dietary requirements, suggest a sandwich idea.
# ---
# Follow the following format.
# Meal: ${meal}
# Dietary Requiements: ${dietary_requiements}
# Bread: ${bread}
# Protein: ${protein}
# Fat: ${fat}
# Garnish: ${garnish}
# Sauce: ${sauce}
# ---
# Meal: lunch
# Dietary Requiements: N/A
# Bread: [whole wheat
# Protein: turkey
# Fat: avocado]
# ===
# === DummyLM ===
# Based on the meal and dietary requirements, suggest a sandwich idea.
# ---
# Follow the following format.
# Meal: ${meal}
# Dietary Requiements: ${dietary_requiements}
# Bread: ${bread}
# Protein: ${protein}
# Fat: ${fat}
# Garnish: ${garnish}
# Sauce: ${sauce}
# ---
# Meal: lunch
# Dietary Requiements: N/A
# Bread: whole wheat
# Protein: turkey
# Fat: avocado
# Garnish: [tomato
# Sauce: mustard]
# ===

assert prediction.bread == "whole wheat"
assert prediction.protein == "turkey"
assert prediction.fat == "avocado"
assert prediction.garnish == "" # This field is assigned as "" when the generation is extended
assert prediction.sauce == "tomato \n\nSauce: mustard"


def test_extend_generation_rolled_back_when_field_is_skipped(SandwichIdea):
lm = DummyLM(
[
" white\n\nFat: butter\n\nGarnish: lettuce\n\nSauce: mayo",
" ham\n\nFat: butter\n\nGarnish: lettuce\n\nSauce: mayo",
]
)
dspy.settings.configure(lm=lm)
# The logged conversation (additional newlines removed, [..] indicates the generation):
# === DummyLM ===
# Based on the meal and dietary requirements, suggest a sandwich idea.
# ---
# Follow the following format.
# Meal: ${meal}
# Dietary Requiements: ${dietary_requiements}
# Bread: ${bread}
# Protein: ${protein}
# Fat: ${fat}
# Garnish: ${garnish}
# Sauce: ${sauce}
# ---
# Meal: lunch
# Dietary Requiements: N/A
# Bread:[ white
# Fat: butter
# Garnish: lettuce
# Sauce: mayo]
# ===
# === DummyLM ===
# Based on the meal and dietary requirements, suggest a sandwich idea.
# ---
# Follow the following format.
# Meal: ${meal}
# Dietary Requiements: ${dietary_requiements}
# Bread: ${bread}
# Protein: ${protein}
# Fat: ${fat}
# Garnish: ${garnish}
# Sauce: ${sauce}
# ---
# Meal: lunch
# Dietary Requiements: N/A
# Bread: white Fat: butter Garnish: lettuce Sauce: mayo
# Protein:[ ham
# Fat: butter
# Garnish: lettuce
# Sauce: mayo]
# ===

predictor = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A")
assert predictor.bread == "white\n\nFat: butter\n\nGarnish: lettuce\n\nSauce: mayo"
assert predictor.protein == ""
assert predictor.fat == "ham\n\nFat: butter"
assert predictor.garnish == "lettuce"
assert predictor.sauce == "mayo"


def test_extend_generation_with_empty_field(SandwichIdea):
lm = DummyLM(
[
" white\n\nProtein: \n\nFat: butter\n\nGarnish: lettuce",
" lettuce \n\nSauce: mayo",
]
)
dspy.settings.configure(lm=lm)
# The logged conversation (additional newlines removed, [..] indicates the generation):
# === DummyLM ===
# Based on the meal and dietary requirements, suggest a sandwich idea.
# ---
# Follow the following format.
# Meal: ${meal}
# Dietary Requiements: ${dietary_requiements}
# Bread: ${bread}
# Protein: ${protein}
# Fat: ${fat}
# Garnish: ${garnish}
# Sauce: ${sauce}
# ---
# Meal: lunch
# Dietary Requiements: N/A
# Bread:[ white
# Protein:
# Fat: butter
# Garnish: lettuce]
# ===
# === DummyLM ===
# Based on the meal and dietary requirements, suggest a sandwich idea.
# ---
# Follow the following format.
# Meal: ${meal}
# Dietary Requiements: ${dietary_requiements}
# Bread: ${bread}
# Protein: ${protein}
# Fat: ${fat}
# Garnish: ${garnish}
# Sauce: ${sauce}
# ---
# Meal: lunch
# Dietary Requiements: N/A
# Bread: white
# Protein: Fat: butter Garnish: lettuce
# Fat:[ lettuce
# Sauce: mayo]
# ===

predictor = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A")
assert predictor.bread == "white"
assert predictor.protein == "Fat: butter\n\nGarnish: lettuce"
assert predictor.fat == ""
assert predictor.garnish == "lettuce"
assert predictor.sauce == "mayo"
Loading