Reranker implementation using Cross Encoders (HuggingFace / SentenceT…
bsbodden authored May 15, 2024
1 parent c7e90ea commit 5e845f2
Showing 5 changed files with 353 additions and 43 deletions.
144 changes: 116 additions & 28 deletions docs/user_guide/rerankers_06.ipynb
@@ -9,7 +9,10 @@
"\n",
"In this notebook, we will show how to use RedisVL to rerank search results\n",
"(documents or chunks or records) based on the input query. Today RedisVL\n",
"supports reranking through the [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
"supports reranking through: \n",
"\n",
"- A re-ranker that uses pre-trained [Cross-Encoders](https://sbert.net/examples/applications/cross-encoder/README.html), which can load [Hugging Face cross-encoder models](https://huggingface.co/cross-encoder) or any Hugging Face model that implements a cross-encoder function (for example, [BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)).\n",
"- The [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
"\n",
"Before running this notebook, be sure to:\n",
"1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
@@ -26,8 +29,10 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"execution_count": 27,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"# import necessary modules\n",
@@ -48,8 +53,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"execution_count": 28,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"query = \"What is the capital of the United States?\"\n",
@@ -75,24 +82,93 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Init the Reranker\n",
"### Using the Cross-Encoder Reranker\n",
"\n",
"Initialize the reranker. Install the cohere library and provide the right Cohere API Key."
"To use the cross-encoder reranker, initialize an instance of `HFCrossEncoderReranker`, passing a suitable model name (if no model is provided, the `cross-encoder/ms-marco-MiniLM-L-6-v2` model is used): "
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"from redisvl.utils.rerank import HFCrossEncoderReranker\n",
"\n",
"cross_encoder_reranker = HFCrossEncoderReranker(\"BAAI/bge-reranker-base\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Rerank documents with HFCrossEncoderReranker\n",
"\n",
"With this reranker instance we can rerank and truncate the list of\n",
"documents based on relevance to the initial query."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"#!pip install cohere"
"results, scores = cross_encoder_reranker.rank(query=query, docs=docs)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 31,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.07461125403642654 -- {'content': 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.'}\n",
"0.05220315232872963 -- {'content': 'Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.'}\n",
"0.3802368640899658 -- {'content': 'Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.'}\n"
]
}
],
"source": [
"for result, score in zip(results, scores):\n",
" print(score, \" -- \", result)"
]
},
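The rerank step itself reduces to scoring, sorting, and truncating; a minimal pure-Python sketch of that logic (the scores below are made up for illustration, standing in for cross-encoder output):

```python
def sort_and_truncate(docs, scores, limit=3):
    """Pair each doc with its score, sort by score (descending),
    and keep the top `limit` docs with matching scores."""
    pairs = sorted(zip(docs, scores), key=lambda p: p[1], reverse=True)
    top = pairs[:limit]
    return [d for d, _ in top], [s for _, s in top]

docs = [
    {"content": "Washington, D.C."},
    {"content": "Charlotte Amalie"},
    {"content": "Carson City"},
]
scores = [0.0746, 0.0522, 0.3802]

# The highest-scoring document comes first after reranking
ranked_docs, ranked_scores = sort_and_truncate(docs, scores, limit=2)
```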
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using the Cohere Reranker\n",
"\n",
"To initialize the Cohere reranker, you'll need to install the cohere library and provide a valid Cohere API key."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"#!pip install cohere"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"import getpass\n",
@@ -103,38 +179,44 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"execution_count": 34,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"from redisvl.utils.rerank import CohereReranker\n",
"\n",
"reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
"cohere_reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Rerank documents\n",
"### Rerank documents with CohereReranker\n",
"\n",
"Below we will use the `CohereReranker` to rerank and also truncate the list of\n",
"Below we will use the `CohereReranker` to rerank and truncate the list of\n",
"documents above based on relevance to the initial query."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"execution_count": 35,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"results, scores = reranker.rank(query=query, docs=docs)"
"results, scores = cohere_reranker.rank(query=query, docs=docs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"execution_count": 36,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
@@ -162,8 +244,10 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"execution_count": 37,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"docs = [\n",
@@ -192,17 +276,21 @@
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"execution_count": 38,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"results, scores = reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
"results, scores = cohere_reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
]
},
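Conceptually, ranking by multiple fields means scoring against text drawn from each named field. A rough sketch of that idea (simple concatenation here is an assumption for illustration, not necessarily what the Cohere API does internally):

```python
def text_for_ranking(doc, rank_by):
    """Join the selected fields of a doc into a single string for scoring.
    Fields missing from the doc are skipped (illustrative behavior)."""
    return " ".join(str(doc[field]) for field in rank_by if field in doc)

doc = {
    "source": "wiki",
    "passage": "Carson City is the capital city of the American state of Nevada.",
}

# Combines the passage text first, then the source field
text = text_for_ranking(doc, ["passage", "source"])
```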
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"execution_count": 39,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
@@ -236,7 +324,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.9"
},
"orig_nbformat": 4,
"vscode": {
6 changes: 2 additions & 4 deletions redisvl/utils/rerank/__init__.py
@@ -1,7 +1,5 @@
from redisvl.utils.rerank.base import BaseReranker
from redisvl.utils.rerank.cohere import CohereReranker
from redisvl.utils.rerank.hf_cross_encoder import HFCrossEncoderReranker

__all__ = [
"BaseReranker",
"CohereReranker",
]
__all__ = ["BaseReranker", "CohereReranker", "HFCrossEncoderReranker"]
129 changes: 129 additions & 0 deletions redisvl/utils/rerank/hf_cross_encoder.py
@@ -0,0 +1,129 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from sentence_transformers import CrossEncoder

from redisvl.utils.rerank.base import BaseReranker


class HFCrossEncoderReranker(BaseReranker):
"""
The HFCrossEncoderReranker class uses cross-encoder models from Hugging Face
to rerank documents based on an input query.
This reranker loads a cross-encoder model using the `CrossEncoder` class
from the `sentence_transformers` library. It requires the
`sentence_transformers` library to be installed.
.. code-block:: python
from redisvl.utils.rerank import HFCrossEncoderReranker
# set up the HFCrossEncoderReranker with a specific model
reranker = HFCrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", limit=3)
# rerank raw search results based on user input/query
results = reranker.rank(
query="your input query text here",
docs=[
{"content": "document 1"},
{"content": "document 2"},
{"content": "document 3"}
]
)
"""

def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
limit: int = 3,
return_score: bool = True,
) -> None:
"""
Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.
Parameters:
model_name (str): The name or path of the cross-encoder model to use for reranking.
Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
limit (int): The maximum number of results to return after reranking. Must be a positive integer.
return_score (bool): Whether to return scores alongside the reranked results.
"""
super().__init__(
model=model_name, rank_by=None, limit=limit, return_score=return_score
)
self.model: CrossEncoder = CrossEncoder(model_name)

def rank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Rerank documents based on the provided query using the loaded cross-encoder model.
This method processes the user's query and the provided documents to rerank them
in a manner that is potentially more relevant to the query's context.
Parameters:
query (str): The user's search query.
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
either as dictionaries or strings.
Returns:
Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
The reranked list of documents and optionally associated scores.
"""
limit = kwargs.get("limit", self.limit)
return_score = kwargs.get("return_score", self.return_score)

if not query:
raise ValueError("query cannot be empty")

if not isinstance(query, str):
raise TypeError("query must be a string")

if not isinstance(docs, list):
raise TypeError("docs must be a list")

if not docs:
return [] if not return_score else ([], [])

if all(isinstance(doc, dict) for doc in docs):
texts = [
str(doc["content"])
for doc in docs
if isinstance(doc, dict) and "content" in doc
]
doc_subset = [
doc for doc in docs if isinstance(doc, dict) and "content" in doc
]
else:
texts = [str(doc) for doc in docs]
doc_subset = [{"content": doc} for doc in docs]

scores = self.model.predict([(query, text) for text in texts])
scores = [float(score) for score in scores]
docs_with_scores = list(zip(doc_subset, scores))
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
reranked_docs = [doc for doc, _ in docs_with_scores[:limit]]
scores = [score for _, score in docs_with_scores[:limit]]  # keep scores in the same sorted order as the docs

if return_score:
return reranked_docs, scores
return reranked_docs

async def arank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Asynchronously rerank documents based on the provided query using the loaded cross-encoder model.
This method processes the user's query and the provided documents to rerank them
in a manner that is potentially more relevant to the query's context.
Parameters:
query (str): The user's search query.
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
either as dictionaries or strings.
Returns:
Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
The reranked list of documents and optionally associated scores.
"""
return self.rank(query, docs, **kwargs)
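The dict-vs-string input handling in `rank` above can be exercised on its own; a standalone sketch mirroring that normalization (no model required):

```python
from typing import Any, Dict, List, Tuple, Union

def normalize_docs(
    docs: Union[List[Dict[str, Any]], List[str]]
) -> Tuple[List[str], List[Dict[str, Any]]]:
    """Mirror HFCrossEncoderReranker.rank's input handling:
    dicts keep only entries with a 'content' key; plain strings are wrapped."""
    if all(isinstance(doc, dict) for doc in docs):
        doc_subset = [doc for doc in docs if "content" in doc]
        texts = [str(doc["content"]) for doc in doc_subset]
    else:
        doc_subset = [{"content": doc} for doc in docs]
        texts = [str(doc) for doc in docs]
    return texts, doc_subset

# Plain strings are wrapped into {"content": ...} dicts
texts, subset = normalize_docs(["first doc", "second doc"])
```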