extend generation logic tests #1172
Conversation
tests/predict/test_predict.py (outdated)

```python
def test_extend_generation(SandwichIdea):
    lm = DummyLM(
        [
            " whole wheat\n\nProtein: turkey\n\nFat: avocado",
            " tomato\n\nSauce: mustard",
        ]
    )
    dspy.settings.configure(lm=lm)

    prediction = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A")
    assert prediction.bread == "whole wheat"
    assert prediction.protein == "turkey"
    assert prediction.fat == "avocado"
    assert prediction.garnish == "tomato"
    assert prediction.sauce == "mustard"
```
@okhat - I'm not sure about the other tests I've added, but I was definitely expecting this test to pass. However:
```
SandwichIdea = SandwichIdea(meal, dietary_requiements -> bread, protein, fat, garnish, sauce
    instructions='Based on the meal and ...notation=str required=True json_schema_extra={'__dspy_field_type': 'output', 'prefix': 'Sauce:', 'desc': '${sauce}'})
)
```
```
def test_extend_generation(SandwichIdea):
    lm = DummyLM(
        [
            " whole wheat\n\nProtein: turkey\n\nFat: avocado",
            " tomato\n\nSauce: mustard",
        ]
    )
    dspy.settings.configure(lm=lm)
    prediction = Predict(SandwichIdea)(meal="lunch", dietary_requiements="N/A")
    assert prediction.bread == "whole wheat"
    assert prediction.protein == "turkey"
    assert prediction.fat == "avocado"
>   assert prediction.garnish == "tomato"
E   AssertionError: assert '' == 'tomato'
E     - tomato
```
I think the problem is that when `""` is added as the result of the field by `dsp/primitives/predict.py`, line 98 in `34725d0`:

```python
completion[field_names[last_field_idx]] = ""
```

the check in `dsp/templates/template_v2.py`, line 156 in `34725d0`, no longer treats the field as missing, because the value is present and not `None`:

```python
if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None:
```
So the field is being filled with an empty string, the model-generated value is not being included in the "correct place" in the output, and the program continues without raising the recursion error. The upshot is that the prediction ends up offset by one field and the following fields are not parsed correctly, i.e.:

```python
prediction.garnish  # ""
prediction.sauce    # " tomato\n\nSauce: mustard"
```
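The offset can be sketched with a toy parser (the function and names here are hypothetical, not DSPy's actual implementation): chunks are assigned to fields strictly in order, so an injected `""` for the skipped field pushes every later value down by one position.

```python
# Toy parser (hypothetical, not DSPy's real code) showing the one-field
# offset: each generated chunk is paired with the next expected field.

def parse_fields(field_names, chunks):
    """Pair each generated chunk with the next expected field, in order."""
    return dict(zip(field_names, chunks))

fields = ["garnish", "sauce"]
# The model skipped "garnish"; the quiet-fail path inserts "" first, so the
# whole unparsed continuation lands on the *next* field instead.
chunks = ["", " tomato\n\nSauce: mustard"]
completion = parse_fields(fields, chunks)
assert completion["garnish"] == ""
assert completion["sauce"] == " tomato\n\nSauce: mustard"
```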
The most atomised fix would be to delete `completion[field_names[last_field_idx]] = ""` (`dsp/primitives/predict.py`, line 98 in `34725d0`).
But I think that might cause quite a few issues with existing examples which use the extend generation logic - the "quiet fail" which currently occurs would be replaced by a recursion depth exception as the model repeatedly fails to generate the field.
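A minimal sketch of that trade-off (illustrative names only, not DSPy's real code): without the empty-string fill, a field the model never emits keeps the retry recursion alive until a depth limit is hit.

```python
# Hypothetical sketch (not DSPy's actual code) of why removing the
# empty-string fill can turn a quiet fail into a recursion error.

def extend_generation(completion, field_names, generate, depth=0, max_depth=5):
    """Re-prompt until every field is filled, or give up at max_depth."""
    missing = [f for f in field_names if f not in completion]
    if not missing:
        return completion
    if depth >= max_depth:
        # Without the "" fill, a field the model never emits ends up here.
        raise RecursionError("model never generated field: " + missing[0])
    completion.update(generate(missing))  # ask the LM again for the gaps
    return extend_generation(completion, field_names, generate, depth + 1, max_depth)

def stubborn_generate(missing):
    return {}  # simulates a model that keeps skipping the field

try:
    extend_generation({"bread": "rye"}, ["bread", "garnish"], stubborn_generate)
except RecursionError as exc:
    print(exc)  # prints: model never generated field: garnish
```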
Depending on the program, it might be the case that further down the line the deserialisation of a prompt + completion will unpick the offset caused by `completion[field_names[last_field_idx]] = ""` (`dsp/primitives/predict.py`, line 98 in `34725d0`).
Hey @mikeedjones, thanks so much for the deep dive! The changes do need to be reverted for now, because they break more fundamental things than the regressions you mentioned, although those are very important too.

In the longer run, I like the direction of this PR overall. Let's think about what the right long-term behavior is for parsing. The original DSPy behavior, which strikes a very good compromise IMO but needs to be better documented, is: when you ask for n=1 completion, you'll always get it back. If you request n>1 completions, you get at least one; no guarantees beyond that. If you need guaranteed n>1 behavior, create multiple modules with …
Hi @okhat - I think what I show above is that the reverted logic isn't currently working as I expected? The original logic fills the missed field with an empty string and the completion continues from there - so the model would be prompted to start from … Given that #920 was merged, I think it makes sense to revert and add some tests so someone else can't come along and inadvertently break something further down the line!
tests/predict/test_predict.py (outdated)

```python
lm = DummyLM(
    [
        " whole wheat\n\nProtein: turkey\n\nFat: avocado",
        " tomato\n\nSauce: mustard",
    ]
)
```
@okhat So this generation would actually look like `" mustard"`, because the last `None` field in the signature would be `sauce`?
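One way to picture that continuation behaviour (a hypothetical sketch, not DSPy's actual template code): the re-prompt replays the already-filled fields and ends with the prefix of the first missing field, so the LM's next generation is expected to begin with just that field's value.

```python
# Hypothetical sketch of how a continuation prompt might be assembled: the
# re-prompt ends with the open prefix of the first missing field, so the
# LM continues with only that field's value (e.g. " mustard").

def continuation_prompt(filled, ordered_fields, prefixes):
    """Replay filled fields, then stop at the first missing one."""
    lines = []
    for f in ordered_fields:
        if f in filled and filled[f] != "":
            lines.append(f"{prefixes[f]}{filled[f]}")
        else:
            lines.append(prefixes[f])  # open prefix: the LM continues here
            break
    return "\n\n".join(lines)

prefixes = {"garnish": "Garnish:", "sauce": "Sauce:"}
p = continuation_prompt({"garnish": " tomato"}, ["garnish", "sauce"], prefixes)
print(p)  # prints: Garnish: tomato / (blank line) / Sauce:
```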
Updated the tests and added a comment to each, including the logged LM calls. This branch has been pointing at https://github.com/stanfordnlp/dspy/tree/mipro_v2 and the tests are passing for the reverted logic. I think the dummy LM generations are representative - but I'm not sure the behaviour demonstrated by the tests is desired?
Added tests to better define the behaviour of the extend generation logic in `dsp/primitives/predict.py`. They currently don't pass with either version of the extend generation logic! I'm not sure what the intended behaviour should be - can @XenonMolecule, @okhat, @arnavsinghvi11 please explain what these tests should look like?
Cheers!