Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reranker implementation using Cross Encoders (HuggingFace / SentenceT… #150

Merged
merged 1 commit into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 116 additions & 28 deletions docs/user_guide/rerankers_06.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
"\n",
"In this notebook, we will show how to use RedisVL to rerank search results\n",
"(documents or chunks or records) based on the input query. Today RedisVL\n",
"supports reranking through the [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
"supports reranking through: \n",
"\n",
"- A re-ranker that uses pre-trained [Cross-Encoders](https://sbert.net/examples/applications/cross-encoder/README.html) which can use models from [Hugging Face cross encoder models](https://huggingface.co/cross-encoder) or Hugging Face models that implement a cross encoder function ([example: BAAI/bge-reranker-base](https://huggingface.co/BAAI/bge-reranker-base)).\n",
"- The [Cohere /rerank API](https://docs.cohere.com/docs/rerank-2).\n",
"\n",
"Before running this notebook, be sure to:\n",
"1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
Expand All @@ -26,8 +29,10 @@
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"execution_count": 27,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"# import necessary modules\n",
Expand All @@ -48,8 +53,10 @@
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"execution_count": 28,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"query = \"What is the capital of the United States?\"\n",
Expand All @@ -75,24 +82,93 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Init the Reranker\n",
"### Using the Cross-Encoder Reranker\n",
"\n",
"Initialize the reranker. Install the cohere library and provide the right Cohere API Key."
"To use the cross-encoder reranker we initialize an instance of `HFCrossEncoderReranker` passing a suitable model (if no model is provided, the `cross-encoder/ms-marco-MiniLM-L-6-v2` model is used): "
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 29,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"from redisvl.utils.rerank import HFCrossEncoderReranker\n",
"\n",
"cross_encoder_reranker = HFCrossEncoderReranker(\"BAAI/bge-reranker-base\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Rerank documents with HFCrossEncoderReranker\n",
"\n",
"With the obtained reranker instance we can rerank and truncate the list of\n",
"documents based on relevance to the initial query."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"#!pip install cohere"
"results, scores = cross_encoder_reranker.rank(query=query, docs=docs)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 31,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.07461125403642654 -- {'content': 'Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.'}\n",
"0.05220315232872963 -- {'content': 'Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.'}\n",
"0.3802368640899658 -- {'content': 'Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.'}\n"
]
}
],
"source": [
"for result, score in zip(results, scores):\n",
" print(score, \" -- \", result)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Using the Cohere Reranker\n",
"\n",
"To initialize the Cohere reranker you'll need to install the cohere library and provide the right Cohere API Key."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"#!pip install cohere"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"import getpass\n",
Expand All @@ -103,38 +179,44 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"execution_count": 34,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"from redisvl.utils.rerank import CohereReranker\n",
"\n",
"reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
"cohere_reranker = CohereReranker(limit=3, api_config={\"api_key\": api_key})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Rerank documents\n",
"### Rerank documents with CohereReranker\n",
"\n",
"Below we will use the `CohereReranker` to rerank and also truncate the list of\n",
"Below we will use the `CohereReranker` to rerank and truncate the list of\n",
"documents above based on relevance to the initial query."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"execution_count": 35,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"results, scores = reranker.rank(query=query, docs=docs)"
"results, scores = cohere_reranker.rank(query=query, docs=docs)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"execution_count": 36,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -162,8 +244,10 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"execution_count": 37,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"docs = [\n",
Expand Down Expand Up @@ -192,17 +276,21 @@
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"execution_count": 38,
"metadata": {
"metadata": {}
},
"outputs": [],
"source": [
"results, scores = reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
"results, scores = cohere_reranker.rank(query=query, docs=docs, rank_by=[\"passage\", \"source\"])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"execution_count": 39,
"metadata": {
"metadata": {}
},
"outputs": [
{
"name": "stdout",
Expand Down Expand Up @@ -236,7 +324,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.11.9"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
6 changes: 2 additions & 4 deletions redisvl/utils/rerank/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from redisvl.utils.rerank.base import BaseReranker
from redisvl.utils.rerank.cohere import CohereReranker
from redisvl.utils.rerank.hf_cross_encoder import HFCrossEncoderReranker

__all__ = [
"BaseReranker",
"CohereReranker",
]
__all__ = ["BaseReranker", "CohereReranker", "HFCrossEncoderReranker"]
129 changes: 129 additions & 0 deletions redisvl/utils/rerank/hf_cross_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from typing import Any, Dict, List, Optional, Tuple, Union

from sentence_transformers import CrossEncoder

from redisvl.utils.rerank.base import BaseReranker


class HFCrossEncoderReranker(BaseReranker):
"""
The HFCrossEncoderReranker class uses a cross-encoder models from Hugging Face
to rerank documents based on an input query.

This reranker loads a cross-encoder model using the `CrossEncoder` class
from the `sentence_transformers` library. It requires the
`sentence_transformers` library to be installed.

.. code-block:: python

from redisvl.utils.rerank import HFCrossEncoderReranker

# set up the HFCrossEncoderReranker with a specific model
reranker = HFCrossEncoderReranker(model_name="cross-encoder/ms-marco-MiniLM-L-6-v2", limit=3)
# rerank raw search results based on user input/query
results = reranker.rank(
query="your input query text here",
docs=[
{"content": "document 1"},
{"content": "document 2"},
{"content": "document 3"}
]
)
"""

def __init__(
self,
model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
limit: int = 3,
return_score: bool = True,
) -> None:
"""
Initialize the HFCrossEncoderReranker with a specified model and ranking criteria.

Parameters:
model_name (str): The name or path of the cross-encoder model to use for reranking.
Defaults to 'cross-encoder/ms-marco-MiniLM-L-6-v2'.
limit (int): The maximum number of results to return after reranking. Must be a positive integer.
return_score (bool): Whether to return scores alongside the reranked results.
"""
super().__init__(
model=model_name, rank_by=None, limit=limit, return_score=return_score
)
self.model: CrossEncoder = CrossEncoder(model_name)

def rank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Rerank documents based on the provided query using the loaded cross-encoder model.

This method processes the user's query and the provided documents to rerank them
in a manner that is potentially more relevant to the query's context.

Parameters:
query (str): The user's search query.
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
either as dictionaries or strings.

Returns:
Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
The reranked list of documents and optionally associated scores.
"""
limit = kwargs.get("limit", self.limit)
return_score = kwargs.get("return_score", self.return_score)

if not query:
raise ValueError("query cannot be empty")

if not isinstance(query, str):
raise TypeError("query must be a string")

if not isinstance(docs, list):
raise TypeError("docs must be a list")

if not docs:
return [] if not return_score else ([], [])

if all(isinstance(doc, dict) for doc in docs):
texts = [
str(doc["content"])
for doc in docs
if isinstance(doc, dict) and "content" in doc
]
doc_subset = [
doc for doc in docs if isinstance(doc, dict) and "content" in doc
]
else:
texts = [str(doc) for doc in docs]
doc_subset = [{"content": doc} for doc in docs]

scores = self.model.predict([(query, text) for text in texts])
scores = [float(score) for score in scores]
docs_with_scores = list(zip(doc_subset, scores))
docs_with_scores.sort(key=lambda x: x[1], reverse=True)
reranked_docs = [doc for doc, _ in docs_with_scores[:limit]]
scores = scores[:limit]

if return_score:
return reranked_docs, scores
return reranked_docs

async def arank(
self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs
) -> Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
"""
Asynchronously rerank documents based on the provided query using the loaded cross-encoder model.

This method processes the user's query and the provided documents to rerank them
in a manner that is potentially more relevant to the query's context.

Parameters:
query (str): The user's search query.
docs (Union[List[Dict[str, Any]], List[str]]): The list of documents to be ranked,
either as dictionaries or strings.

Returns:
Union[Tuple[List[Dict[str, Any]], List[float]], List[Dict[str, Any]]]:
The reranked list of documents and optionally associated scores.
"""
return self.rank(query, docs, **kwargs)
Loading
Loading