diff --git a/tests/predict/test_predict.py b/tests/predict/test_predict.py index c701b380a..bd07e776e 100644 --- a/tests/predict/test_predict.py +++ b/tests/predict/test_predict.py @@ -1,11 +1,14 @@ -import dspy -from dspy import Predict, Signature, TypedPredictor -from dspy.utils.dummies import DummyLM import copy import textwrap + import pydantic +import pytest import ujson +import dspy +from dspy import Predict, Signature, TypedPredictor +from dspy.utils.dummies import DummyLM + def test_initialization_with_string_signature(): signature_string = "input1, input2 -> output" @@ -209,14 +212,203 @@ class OutputOnlySignature(dspy.Signature): assert lm.get_convo(-1) == textwrap.dedent( """\ Given the fields , produce the fields `output`. - + --- - + Follow the following format. - + Output: ${output} - + --- - + Output: short answer""" ) + + +@pytest.fixture(name="SandwichIdea") +def sandwich_idea_signature(): + class SandwichIdea(dspy.Signature): + """Based on the meal and dietary requirements, suggest a sandwich idea.""" + + meal: str = dspy.InputField() + dietary_requiements: str = dspy.InputField() + bread: str = dspy.OutputField() + protein: str = dspy.OutputField() + fat: str = dspy.OutputField() + garnish: str = dspy.OutputField() + sauce: str = dspy.OutputField() + + return SandwichIdea + + +def test_extend_generation(SandwichIdea): + lm = DummyLM( + [ + " whole wheat\n\nProtein: turkey\n\nFat: avocado", + # Incomplete generation leads to tomato field being assigned as an + # empty string ("") in dsp.primitives.predict l98 the generation + # therefores continues with the next field. + " tomato \n\nSauce: mustard", + ] + ) + dspy.settings.configure(lm=lm) + + prediction = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A") + # The logged conversation (additional newlines removed, [..] indicates the generation): + # === DummyLM === + # Based on the meal and dietary requirements, suggest a sandwich idea. + # --- + # Follow the following format. + # Meal: ${meal} + # Dietary Requiements: ${dietary_requiements} + # Bread: ${bread} + # Protein: ${protein} + # Fat: ${fat} + # Garnish: ${garnish} + # Sauce: ${sauce} + # --- + # Meal: lunch + # Dietary Requiements: N/A + # Bread: [whole wheat + # Protein: turkey + # Fat: avocado] + # === + # === DummyLM === + # Based on the meal and dietary requirements, suggest a sandwich idea. + # --- + # Follow the following format. + # Meal: ${meal} + # Dietary Requiements: ${dietary_requiements} + # Bread: ${bread} + # Protein: ${protein} + # Fat: ${fat} + # Garnish: ${garnish} + # Sauce: ${sauce} + # --- + # Meal: lunch + # Dietary Requiements: N/A + # Bread: whole wheat + # Protein: turkey + # Fat: avocado + # Garnish: [tomato + # Sauce: mustard] + # === + + assert prediction.bread == "whole wheat" + assert prediction.protein == "turkey" + assert prediction.fat == "avocado" + assert prediction.garnish == "" # This field is assigned as "" when the generation is extended + assert prediction.sauce == "tomato \n\nSauce: mustard" + + +def test_extend_generation_rolled_back_when_field_is_skipped(SandwichIdea): + lm = DummyLM( + [ + " white\n\nFat: butter\n\nGarnish: lettuce\n\nSauce: mayo", + " ham\n\nFat: butter\n\nGarnish: lettuce\n\nSauce: mayo", + ] + ) + dspy.settings.configure(lm=lm) + # The logged conversation (additional newlines removed, [..] indicates the generation): + # === DummyLM === + # Based on the meal and dietary requirements, suggest a sandwich idea. + # --- + # Follow the following format. + # Meal: ${meal} + # Dietary Requiements: ${dietary_requiements} + # Bread: ${bread} + # Protein: ${protein} + # Fat: ${fat} + # Garnish: ${garnish} + # Sauce: ${sauce} + # --- + # Meal: lunch + # Dietary Requiements: N/A + # Bread:[ white + # Fat: butter + # Garnish: lettuce + # Sauce: mayo] + # === + # === DummyLM === + # Based on the meal and dietary requirements, suggest a sandwich idea. + # --- + # Follow the following format. + # Meal: ${meal} + # Dietary Requiements: ${dietary_requiements} + # Bread: ${bread} + # Protein: ${protein} + # Fat: ${fat} + # Garnish: ${garnish} + # Sauce: ${sauce} + # --- + # Meal: lunch + # Dietary Requiements: N/A + # Bread: white Fat: butter Garnish: lettuce Sauce: mayo + # Protein:[ ham + # Fat: butter + # Garnish: lettuce + # Sauce: mayo] + # === + + predictor = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A") + assert predictor.bread == "white\n\nFat: butter\n\nGarnish: lettuce\n\nSauce: mayo" + assert predictor.protein == "" + assert predictor.fat == "ham\n\nFat: butter" + assert predictor.garnish == "lettuce" + assert predictor.sauce == "mayo" + + +def test_extend_generation_with_empty_field(SandwichIdea): + lm = DummyLM( + [ + " white\n\nProtein: \n\nFat: butter\n\nGarnish: lettuce", + " lettuce \n\nSauce: mayo", + ] + ) + dspy.settings.configure(lm=lm) + # The logged conversation (additional newlines removed, [..] indicates the generation): + # === DummyLM === + # Based on the meal and dietary requirements, suggest a sandwich idea. + # --- + # Follow the following format. + # Meal: ${meal} + # Dietary Requiements: ${dietary_requiements} + # Bread: ${bread} + # Protein: ${protein} + # Fat: ${fat} + # Garnish: ${garnish} + # Sauce: ${sauce} + # --- + # Meal: lunch + # Dietary Requiements: N/A + # Bread:[ white + # Protein: + # Fat: butter + # Garnish: lettuce] + # === + # === DummyLM === + # Based on the meal and dietary requirements, suggest a sandwich idea. + # --- + # Follow the following format. + # Meal: ${meal} + # Dietary Requiements: ${dietary_requiements} + # Bread: ${bread} + # Protein: ${protein} + # Fat: ${fat} + # Garnish: ${garnish} + # Sauce: ${sauce} + # --- + # Meal: lunch + # Dietary Requiements: N/A + # Bread: white + # Protein: Fat: butter Garnish: lettuce + # Fat:[ lettuce + # Sauce: mayo] + # === + + predictor = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A") + assert predictor.bread == "white" + assert predictor.protein == "Fat: butter\n\nGarnish: lettuce" + assert predictor.fat == "" + assert predictor.garnish == "lettuce" + assert predictor.sauce == "mayo"