diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index dbe5666d54..535057c868 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -238,7 +238,7 @@ Karpukhin, Vladimir, et al. (2020): "Dense Passage Retrieval for Open-Domain Que
 #### \_\_init\_\_
 
 ```python
- | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", progress_bar: bool = True)
+ | __init__(document_store: BaseDocumentStore, query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base", passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base", single_model_path: Optional[Union[Path, str]] = None, model_version: Optional[str] = None, max_seq_len_query: int = 64, max_seq_len_passage: int = 256, top_k: int = 10, use_gpu: bool = True, batch_size: int = 16, embed_title: bool = True, use_fast_tokenizers: bool = True, infer_tokenizer_classes: bool = False, similarity_function: str = "dot_product", progress_bar: bool = True)
 ```
 
 Init the Retriever incl. the two encoder models from a local or remote model checkpoint.
@@ -266,6 +266,9 @@ Currently available remote names: ``"facebook/dpr-question_encoder-single-nq-bas
 - `passage_embedding_model`: Local path or remote name of passage encoder checkpoint.
 The format equals the one used by hugging-face transformers' modelhub models
 Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
+- `single_model_path`: Local path or remote name of a query and passage embedder in one single model. Those
+models are typically trained within FARM.
+Currently available remote names: TODO add FARM DPR model to HF modelhub
 - `model_version`: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
 - `max_seq_len_query`: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.
 - `max_seq_len_passage`: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.
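
A minimal sketch of the two initialization styles after this change, for readers skimming the diff. The directory name `my_farm_dpr_model` is hypothetical, and the module paths assume the Haystack revision this diff targets:

```python
from haystack.document_store.memory import InMemoryDocumentStore
from haystack.retriever.dense import DensePassageRetriever

document_store = InMemoryDocumentStore()

# Existing behaviour: two separate encoder checkpoints from the model hub.
retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
    passage_embedding_model="facebook/dpr-ctx_encoder-single-nq-base",
)

# New option: one FARM-trained model containing both the query and passage embedder.
retriever = DensePassageRetriever(
    document_store=document_store,
    single_model_path="my_farm_dpr_model",  # hypothetical local model directory
)
```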
diff --git a/haystack/reader/farm.py b/haystack/reader/farm.py
index 32389539e1..b685427758 100644
--- a/haystack/reader/farm.py
+++ b/haystack/reader/farm.py
@@ -559,7 +559,7 @@ def _extract_answers_of_predictions(self, predictions: List[QAPred], top_k: Opti
                     "answer": ans.answer,
                     "score": ans.score,
                     # just a pseudo prob for now
-                    "probability": self._get_pseudo_prob(ans.score),
+                    "probability": ans.confidence,
                     "context": ans.context_window,
                     "offset_start": ans.offset_answer_start - ans.offset_context_window_start,
                     "offset_end": ans.offset_answer_end - ans.offset_context_window_start,
diff --git a/haystack/retriever/dense.py b/haystack/retriever/dense.py
index 5f86e05b56..101a86ebfe 100644
--- a/haystack/retriever/dense.py
+++ b/haystack/retriever/dense.py
@@ -37,6 +37,7 @@ def __init__(self,
                  document_store: BaseDocumentStore,
                  query_embedding_model: Union[Path, str] = "facebook/dpr-question_encoder-single-nq-base",
                  passage_embedding_model: Union[Path, str] = "facebook/dpr-ctx_encoder-single-nq-base",
+                 single_model_path: Optional[Union[Path, str]] = None,
                  model_version: Optional[str] = None,
                  max_seq_len_query: int = 64,
                  max_seq_len_passage: int = 256,
@@ -73,6 +74,9 @@ def __init__(self,
         :param passage_embedding_model: Local path or remote name of passage encoder checkpoint.
                                         The format equals the one used by hugging-face transformers' modelhub models
                                         Currently available remote names: ``"facebook/dpr-ctx_encoder-single-nq-base"``
+        :param single_model_path: Local path or remote name of a query and passage embedder in one single model. Those
+                                  models are typically trained within FARM.
+                                  Currently available remote names: TODO add FARM DPR model to HF modelhub
         :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
         :param max_seq_len_query: Longest length of each query sequence. Maximum number of tokens for the query text. Longer ones will be cut down.
         :param max_seq_len_passage: Longest length of each passage/context sequence. Maximum number of tokens for the passage text. Longer ones will be cut down.
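
The `farm.py` change above swaps the hand-rolled pseudo-probability for the calibrated `confidence` that FARM attaches to each answer as of farm 0.7.1 (pinned further down in this diff). A hedged sketch of what that means for consumers of reader predictions; the model name is a real Hugging Face checkpoint, but the `Document` import path and the exact `predict()` signature are assumptions for this revision:

```python
from haystack import Document
from haystack.reader.farm import FARMReader

docs = [Document(text="Python was created by Guido van Rossum.")]
reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")

prediction = reader.predict(question="Who created Python?", documents=docs, top_k=1)
for answer in prediction["answers"]:
    # "probability" now carries FARM's calibrated answer confidence
    # rather than a sigmoid-based pseudo probability of the raw score.
    print(answer["answer"], answer["score"], answer["probability"])
```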
@@ -96,8 +100,6 @@ def __init__(self,
 
         self.document_store = document_store
         self.batch_size = batch_size
-        self.max_seq_len_passage = max_seq_len_passage
-        self.max_seq_len_query = max_seq_len_query
         self.progress_bar = progress_bar
         self.top_k = top_k
 
@@ -115,7 +117,6 @@ def __init__(self,
         else:
             self.device = torch.device("cpu")
 
-        self.embed_title = embed_title
         self.infer_tokenizer_classes = infer_tokenizer_classes
         tokenizers_default_classes = {
             "query": "DPRQuestionEncoderTokenizer",
@@ -126,43 +127,52 @@ def __init__(self,
             tokenizers_default_classes["passage"] = None  # type: ignore
 
         # Init & Load Encoders
-        self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
-                                              revision=model_version,
-                                              do_lower_case=True,
-                                              use_fast=use_fast_tokenizers,
-                                              tokenizer_class=tokenizers_default_classes["query"])
-        self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
-                                                revision=model_version,
-                                                language_model_class="DPRQuestionEncoder")
-        self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                revision=model_version,
-                                                do_lower_case=True,
-                                                use_fast=use_fast_tokenizers,
-                                                tokenizer_class=tokenizers_default_classes["passage"])
-        self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
-                                                  revision=model_version,
-                                                  language_model_class="DPRContextEncoder")
-
-        self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
-                                                 passage_tokenizer=self.passage_tokenizer,
-                                                 max_seq_len_passage=self.max_seq_len_passage,
-                                                 max_seq_len_query=self.max_seq_len_query,
-                                                 label_list=["hard_negative", "positive"],
-                                                 metric="text_similarity_metric",
-                                                 embed_title=self.embed_title,
-                                                 num_hard_negatives=0,
-                                                 num_positives=1)
-
-        prediction_head = TextSimilarityHead(similarity_function=similarity_function)
-        self.model = BiAdaptiveModel(
-            language_model1=self.query_encoder,
-            language_model2=self.passage_encoder,
-            prediction_heads=[prediction_head],
-            embeds_dropout_prob=0.1,
-            lm1_output_types=["per_sequence"],
-            lm2_output_types=["per_sequence"],
-            device=self.device,
-        )
+        if single_model_path is None:
+            self.query_tokenizer = Tokenizer.load(pretrained_model_name_or_path=query_embedding_model,
+                                                  revision=model_version,
+                                                  do_lower_case=True,
+                                                  use_fast=use_fast_tokenizers,
+                                                  tokenizer_class=tokenizers_default_classes["query"])
+            self.query_encoder = LanguageModel.load(pretrained_model_name_or_path=query_embedding_model,
+                                                    revision=model_version,
+                                                    language_model_class="DPRQuestionEncoder")
+            self.passage_tokenizer = Tokenizer.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                    revision=model_version,
+                                                    do_lower_case=True,
+                                                    use_fast=use_fast_tokenizers,
+                                                    tokenizer_class=tokenizers_default_classes["passage"])
+            self.passage_encoder = LanguageModel.load(pretrained_model_name_or_path=passage_embedding_model,
+                                                      revision=model_version,
+                                                      language_model_class="DPRContextEncoder")
+
+            self.processor = TextSimilarityProcessor(query_tokenizer=self.query_tokenizer,
+                                                     passage_tokenizer=self.passage_tokenizer,
+                                                     max_seq_len_passage=max_seq_len_passage,
+                                                     max_seq_len_query=max_seq_len_query,
+                                                     label_list=["hard_negative", "positive"],
+                                                     metric="text_similarity_metric",
+                                                     embed_title=embed_title,
+                                                     num_hard_negatives=0,
+                                                     num_positives=1)
+            prediction_head = TextSimilarityHead(similarity_function=similarity_function)
+            self.model = BiAdaptiveModel(
+                language_model1=self.query_encoder,
+                language_model2=self.passage_encoder,
+                prediction_heads=[prediction_head],
+                embeds_dropout_prob=0.1,
+                lm1_output_types=["per_sequence"],
+                lm2_output_types=["per_sequence"],
+                device=self.device,
+            )
+        else:
+            self.processor = TextSimilarityProcessor.load_from_dir(single_model_path)
+            self.processor.max_seq_len_passage = max_seq_len_passage
+            self.processor.max_seq_len_query = max_seq_len_query
+            self.processor.embed_title = embed_title
+            self.processor.num_hard_negatives = 0
+            self.processor.num_positives = 1  # during indexing of documents only one embedding is created
+            self.model = BiAdaptiveModel.load(single_model_path, device=self.device)
+            self.model.connect_heads_with_processor(self.processor.tasks, require_labels=False)
 
     def retrieve(self, query: str, filters: dict = None, top_k: Optional[int] = None, index: str = None) -> List[Document]:
@@ -318,21 +328,14 @@ def train(self,
         :param passage_encoder_save_dir: directory inside save_dir where passage_encoder model files are saved
         """
-        self.embed_title = embed_title
-        self.processor = TextSimilarityProcessor(tokenizer=self.query_tokenizer,
-                                                 passage_tokenizer=self.passage_tokenizer,
-                                                 max_seq_len_passage=self.max_seq_len_passage,
-                                                 max_seq_len_query=self.max_seq_len_query,
-                                                 label_list=["hard_negative", "positive"],
-                                                 metric="text_similarity_metric",
-                                                 data_dir=data_dir,
-                                                 train_filename=train_filename,
-                                                 dev_filename=dev_filename,
-                                                 test_filename=test_filename,
-                                                 dev_split=dev_split,
-                                                 embed_title=self.embed_title,
-                                                 num_hard_negatives=num_hard_negatives,
-                                                 num_positives=num_positives)
+        self.processor.embed_title = embed_title
+        self.processor.data_dir = data_dir
+        self.processor.train_filename = train_filename
+        self.processor.dev_filename = dev_filename
+        self.processor.test_filename = test_filename
+        self.processor.dev_split = dev_split
+        self.processor.num_hard_negatives = num_hard_negatives
+        self.processor.num_positives = num_positives
 
         self.model.connect_heads_with_processor(self.processor.tasks, require_labels=True)
diff --git a/requirements.txt b/requirements.txt
index fa84c0e028..8417e48ec4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-farm==0.6.2
+farm==0.7.1
--find-links=https://download.pytorch.org/whl/torch_stable.html
 fastapi
 uvicorn
diff --git a/test/test_retriever.py b/test/test_retriever.py
index 79a8b77667..b6c8499be8 100644
--- a/test/test_retriever.py
+++ b/test/test_retriever.py
@@ -194,10 +194,10 @@ def sum_params(model):
     # assert (p1.data.ne(p2.data).sum() == 0)
 
     # attributes
-    assert loaded_retriever.embed_title == True
+    assert loaded_retriever.processor.embed_title == True
     assert loaded_retriever.batch_size == 16
-    assert loaded_retriever.max_seq_len_passage == 256
-    assert loaded_retriever.max_seq_len_query == 64
+    assert loaded_retriever.processor.max_seq_len_passage == 256
+    assert loaded_retriever.processor.max_seq_len_query == 64
 
     # Tokenizer
     assert isinstance(loaded_retriever.passage_tokenizer, DPRContextEncoderTokenizerFast)
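
Finally, a short sketch of the attribute relocation that the updated test asserts, and of how the refactored `train()` reuses the processor built (or loaded) in `__init__`. It assumes a `retriever` constructed with the defaults shown in the first example above; the training directory and file name are hypothetical:

```python
# These settings now live on the retriever's FARM TextSimilarityProcessor
# rather than on the retriever itself (retriever.max_seq_len_query etc.
# would now raise AttributeError).
assert retriever.processor.embed_title == True
assert retriever.processor.max_seq_len_passage == 256
assert retriever.processor.max_seq_len_query == 64

# train() now mutates this one shared processor in place instead of building
# a fresh TextSimilarityProcessor, so a processor loaded from a single-model
# directory via load_from_dir() is reused for training as well.
retriever.train(data_dir="data/dpr_training", train_filename="train.json")
```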