Added max_seq_length and batch_size params to EmbeddingRetriever #1817
Conversation
```diff
@@ -55,7 +55,7 @@ def __init__(
     retriever.embedding_model, revision=retriever.model_version, task_type="embeddings",
     extraction_strategy=retriever.pooling_strategy,
     extraction_layer=retriever.emb_extraction_layer, gpu=retriever.use_gpu,
-    batch_size=4, max_seq_len=512, num_processes=0, use_auth_token=retriever.use_auth_token
+    batch_size=retriever.batch_size, max_seq_len=retriever.max_seq_len,, num_processes=0, use_auth_token=retriever.use_auth_token
```
There is a superfluous `,` here.
Hi @AhmedIdr, thanks for raising this pull request. I had a first brief look at your code and started our tests. That way, I identified a syntax error caused by a superfluous comma. Could you please fix it so that we can run the tests? Thank you! I will review the code more thoroughly once the tests are running. 👍
Hey @julian-risch, …
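The fix the reviewer asks for is purely syntactic: a doubled comma in a call never reaches runtime, because the parser rejects it. A minimal standalone illustration (not haystack code; the call name `f` and its arguments are made up):

```python
import ast

# A doubled comma like the one in the diff above is a SyntaxError,
# caught at parse time before any code runs:
try:
    ast.parse("f(batch_size=4,, num_processes=0)")
    has_syntax_error = False
except SyntaxError:
    has_syntax_error = True

print(has_syntax_error)  # True
```

This is why CI fails immediately on import rather than on a specific test case.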
docs/_src/api/api/retriever.md
Outdated
````diff
@@ -611,7 +611,7 @@ class EmbeddingRetriever(BaseRetriever)
 #### \_\_init\_\_
 
 ```python
-__init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
+__init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 16, max_seq_len: int = 128, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[int, str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None)
```
````
I can see that the previous default `max_seq_len` was 512. Is there any particular reason why you would prefer 128 as the default? If not, I'd say we keep 512. (The change in the default value is probably also why some test cases fail right now, but I could fix that if we decide to use 128 instead of 512.)
No, there was no reason; we can set it back to 512 and the batch size back to 32.
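For illustration, here is a toy sketch of how the agreed defaults would surface in the constructor signature. The class and model name below are stand-ins, not the actual haystack implementation; only the two parameters and their defaults come from this thread:

```python
class EmbeddingRetrieverSketch:
    """Stand-in class illustrating only the two parameters discussed here."""

    def __init__(self, embedding_model: str,
                 batch_size: int = 32,      # agreed default: back to 32
                 max_seq_len: int = 512):   # agreed default: back to 512
        self.embedding_model = embedding_model
        self.batch_size = batch_size        # forwarded to the inference batch size
        self.max_seq_len = max_seq_len      # maximum tokens per input sequence

retriever = EmbeddingRetrieverSketch("some-sentence-embedding-model")
print(retriever.batch_size, retriever.max_seq_len)  # 32 512
```

Keeping the old defaults means existing tests and user code see no behavior change; callers who need different values can now pass them explicitly.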
Changed default batch_size and max_seq_len in EmbeddingRetriever
Hi @AhmedIdr, the code looks good to me so far. Thanks for changing the default parameters. I tried your changes in this colab notebook: https://colab.research.google.com/github/AhmedIdr/haystack/blob/add-params/tutorials/Tutorial6_Better_Retrieval_via_DPR.ipynb but I couldn't see the progress bar when running …
Hey @julian-risch, no, I didn't; I only tried it using a script. But it is basically implemented similarly to the other methods and like it was suggested in #1110. I'll try to run it in a notebook and see if I can figure out what the problem is.
Change import tqdm.auto to tqdm
Changing …
My mistake, I didn't change the …
Ahh I see, so using …
Yes, please change it back to …
Changing tqdm back to tqdm.auto
Okay, done. I hope nothing went wrong when fetching the new commits from the main branch 😅
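The distinction the commits above settle on: `from tqdm import tqdm` always renders the plain terminal bar, while `from tqdm.auto import tqdm` picks the notebook widget when running inside Jupyter, which is why the bar disappeared in colab after the first change. A rough, stdlib-only sketch of that dispatch (illustrative; tqdm's real detection logic is more involved than this):

```python
import sys

def tqdm_flavor() -> str:
    """Roughly what tqdm.auto decides at import time (illustrative sketch):
    use the notebook widget bar when an IPython kernel is loaded,
    otherwise fall back to the plain terminal bar."""
    in_notebook = "ipykernel" in sys.modules
    return "tqdm.notebook" if in_notebook else "tqdm.std"
```

Importing from `tqdm.auto` therefore works in both environments, which is why the thread converges on changing it back.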
Looks good to me and ready to merge. 👍
Thank you so much for your contribution to haystack! Having your pull request merged is a great achievement.
Thank you, and sorry it took a bit of time for everything to work properly 😅
Added the option to choose `max_seq_len` and `batch_size` when using `EmbeddingRetriever`.
Related to #1793
For RetriBERT, I am not totally sure whether I added the params in the right places.
Additionally, added a `progress_bar` to the FAISS `write_documents` method, since it is hard to keep track of progress when working with large datasets.
Related to #1110
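A minimal sketch of the idea behind the progress-bar change: report progress once per written batch rather than staying silent for the whole run. The function name, batching, and bar below are stand-ins (a hand-rolled counter instead of tqdm, so the example stays dependency-free), not haystack's actual implementation:

```python
def write_documents_sketch(docs, batch_size=10_000, progress_bar=True):
    """Write docs in batches, optionally reporting progress after each batch."""
    total = len(docs)
    written = 0
    for start in range(0, total, batch_size):
        batch = docs[start:start + batch_size]
        # ... index/persist the batch here ...
        written += len(batch)
        if progress_bar:
            # Overwrite the same line, like a progress bar would.
            print(f"\rWriting documents: {written}/{total}", end="", flush=True)
    if progress_bar:
        print()
    return written

n = write_documents_sketch(list(range(25)), batch_size=10, progress_bar=False)
# n == 25
```

Exposing `progress_bar` as a flag keeps the output clean in scripts and logs while giving interactive users feedback on long-running writes.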
I don't usually make pull requests, so I hope I didn't mess anything up.
Status (please check what you already did):