
feat: add top_k to PromptNode #4159

Merged

8 commits merged into main from feat/promptnode_topk on Feb 20, 2023

Conversation

@tstadel (Member) commented Feb 14, 2023

Related Issues

Proposed Changes:

  • add top_k parameter to PromptNode
  • push top_k down to the invocation layers and map it to each layer's invocation interface

How did you test it?

  • added tests

Notes for the reviewer

  • I don't know why test/nodes/test_audio.py::TestTextToSpeech::test_text_to_speech_compress_audio fails; this PR doesn't touch TextToSpeech at all. The CI of other PRs is failing because of this as well. Maybe a dependency update?

Checklist

  • I have read the contributors guidelines and the code of conduct
  • I have updated the related issue with new insights and changes
  • I added tests that demonstrate the correct behavior of the change
  • I've used one of the conventional commit types for my PR title: fix:, feat:, build:, chore:, ci:, docs:, style:, refactor:, perf:, test:.
  • I documented my code
  • I ran pre-commit hooks and fixed any issue

@@ -979,4 +988,4 @@ def run_batch(
     def _prepare_model_kwargs(self):
         # these are the parameters from PromptNode level
         # that are passed to the prompt model invocation layer
-        return {"stop_words": self.stop_words}
+        return {"stop_words": self.stop_words, "top_k": self.top_k}
@tstadel (Member Author):

Will be pushed down to PromptModel.invoke and InvocationLayer.invoke
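
For orientation, a minimal sketch of that flow, assuming simplified signatures (the real PromptModel and InvocationLayer interfaces carry more arguments):

# Hedged sketch of the kwargs flow; method names follow the comment above,
# but the actual Haystack signatures differ in detail.
class PromptNode:
    def __init__(self, stop_words=None, top_k=1):
        self.stop_words = stop_words
        self.top_k = top_k

    def _prepare_model_kwargs(self):
        # PromptNode-level parameters forwarded on every invocation
        return {"stop_words": self.stop_words, "top_k": self.top_k}

    def prompt(self, prompt_model, prompt):
        # PromptModel.invoke passes these through to its invocation layer
        return prompt_model.invoke(prompt, **self._prepare_model_kwargs())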

Comment on lines 466 to 467
if "top_k" in kwargs:
kwargs["n"] = kwargs.pop("top_k")
@tstadel (Member Author):

Handling in OpenAIInvocationLayer
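
That is, the Haystack-level top_k is renamed to OpenAI's n (the "number of completions" parameter) before the request payload is built. A hedged sketch with a hypothetical helper name:

# Sketch only; _build_openai_payload is a hypothetical name, but the
# top_k -> n renaming mirrors the diff above.
def _build_openai_payload(prompt, **kwargs):
    if "top_k" in kwargs:
        # OpenAI's completion API calls this parameter "n"
        kwargs["n"] = kwargs.pop("top_k")
    return {"prompt": prompt, **kwargs}

# _build_openai_payload("Hello", top_k=3) -> {"prompt": "Hello", "n": 3}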

Comment on lines +351 to +353
if top_k:
    model_input_kwargs["num_return_sequences"] = top_k
    model_input_kwargs["num_beams"] = top_k
@tstadel (Member Author):

Handling in HFInvocationLayer
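
Both kwargs are needed because beam search in Hugging Face's generate() can only return as many sequences as it keeps beams. A hedged sketch (the helper name is hypothetical):

# Sketch only; _apply_top_k is a hypothetical name. generate() requires
# num_beams >= num_return_sequences, so both are set to top_k.
def _apply_top_k(model_input_kwargs, top_k):
    if top_k:
        model_input_kwargs["num_return_sequences"] = top_k
        model_input_kwargs["num_beams"] = top_k
    return model_input_kwargs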

@tstadel (Member Author):

In AnswerGenerator we warn if num_beams < num_return_sequences. If the user wants to control them separately, she can still do so via model_kwargs, though I can't imagine why one would.
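
A sketch of that kind of consistency check, assuming a plain logging warning (the exact message and location in AnswerGenerator may differ):

import logging

logger = logging.getLogger(__name__)

# Hedged sketch of the warning described above; the wording is assumed.
def _warn_on_beam_mismatch(num_beams, num_return_sequences):
    if num_beams < num_return_sequences:
        logger.warning(
            "num_beams (%d) is smaller than num_return_sequences (%d); "
            "generate() can return at most num_beams sequences.",
            num_beams,
            num_return_sequences,
        )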

@tstadel tstadel marked this pull request as ready for review February 15, 2023 09:22
@tstadel tstadel requested a review from a team as a code owner February 15, 2023 09:22
@tstadel tstadel requested review from silvanocerza and removed request for a team February 15, 2023 09:22
@vblagoje (Member):

Yeah, this approach is correct @tstadel: handle top_k on a conceptual level, and let each invocation layer handle the specifics of the top_k concept appropriately.

@tstadel (Member Author) commented Feb 15, 2023

#4151 will be merged into master first. I will resolve conflicts if they arise.

@vblagoje (Member):

Repeating comment here for visibility @tstadel @sjrl - could we call this parameter n? top_k is more related to search/retrieval, no?

@tstadel (Member Author) commented Feb 16, 2023

> Repeating comment here for visibility @tstadel @sjrl - could we call this parameter n? top_k is more related to search/retrieval, no?

I would go with top_k, as it's already used in generators and readers. For consistency within Haystack, it should be top_k.

@vblagoje (Member):

Ok, if @sjrl doesn't object let's go with top_k

@sjrl (Contributor) commented Feb 16, 2023

Sounds good to me!

@tstadel tstadel requested a review from sjrl February 20, 2023 09:44

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(query="not relevant", documents=[Document("Berlin is the capital of Germany")])
@sjrl (Contributor):

Thanks for adding the test! I wanted to ask if it would make more sense to set query=None instead of "not relevant" for the general use case of a node that ignores the query? It doesn't matter for the tests, just wanted to ask.

@tstadel (Member Author):

@sjrl I think there is still something that requires the query to be set. Ideally, we shouldn't have to pass query at all here. As long as query is required, it doesn't make a difference, I'd say.
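
For context, a hedged end-to-end sketch of the new parameter based on the test above; the model name and prompt template are assumptions, top_k is the parameter this PR adds:

# Usage sketch (Haystack v1-style API); model and template are assumptions.
from haystack import Document, Pipeline
from haystack.nodes import PromptNode

node = PromptNode(
    model_name_or_path="google/flan-t5-base",
    default_prompt_template="question-answering",
    top_k=3,  # request three generated answers per input
)

pipe = Pipeline()
pipe.add_node(component=node, name="prompt_node", inputs=["Query"])
result = pipe.run(
    query="What is the capital of Germany?",
    documents=[Document("Berlin is the capital of Germany")],
)
print(result["results"])  # up to three generated answers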

@sjrl (Contributor) left a comment:

Looks good! Thanks for adding this feature!

@sjrl (Contributor) commented Feb 20, 2023

@tstadel Please update your branch with main and rerun CI, then it should be good to go!

@tstadel tstadel merged commit 14578aa into main Feb 20, 2023
@tstadel tstadel deleted the feat/promptnode_topk branch February 20, 2023 13:51
@tstadel tstadel added this to the 1.14.0 milestone Feb 21, 2023
vblagoje pushed a commit that referenced this pull request Feb 22, 2023
* add top_k to PromptNode

* fix OpenAI

* fix openai test
Successfully merging this pull request may close these issues:

  • Add top_k to PromptNode