Colbert local mode support both as retriever and reranker. (#797)
* return metadata changes

* add metadata changes

* add support for returning metadata and reranking

* colbert integration

* colbert local modifications

* kwargs filtered ids

* colbert return

* colbert retriever and reranker

* colbert retriever error fixes

* colbert config changes in __init__

* colbert notebook

* import errors for colbert

* import dspy fixes and linting fixes

* PR fixes for colbert

* making the linting gods happy

* remove unnecessary outputs

* colbertv2 docs

* Colbert PR fixes

* linting fixes

* more linting fixes

* fixing previous cache breaks with separate funcs

---------

Co-authored-by: arnavsinghvi11 <[email protected]>
Athe-kunal and arnavsinghvi11 committed Jun 15, 2024
1 parent 5e26fad commit 37b3759
Showing 8 changed files with 803 additions and 11 deletions.
78 changes: 78 additions & 0 deletions docs/api/retrieval_model_clients/ColBERTv2.md
@@ -49,3 +49,81 @@ retrieval_response = colbertv2_wiki17_abstracts('When was the first FIFA World C
for result in retrieval_response:
    print("Text:", result['text'], "\n")
```

# dspy.ColBERTv2RetrieverLocal

The material in this section is adapted from the official [ColBERTv2](https://github.com/stanford-futuredata/ColBERT/tree/main) documentation and follows the [paper](https://arxiv.org/abs/2112.01488).

You can install ColBERTv2 by following the [installation instructions](https://github.com/stanford-futuredata/ColBERT?tab=readme-ov-file#installation).

### Constructor
The constructor initializes ColBERTv2 as a local retriever object. You can also start a server instance from your local ColBERTv2 instance using [this code snippet](https://github.com/stanford-futuredata/ColBERT/blob/main/server.py).

```python
class ColBERTv2RetrieverLocal:
    def __init__(
        self,
        passages: List[str],
        colbert_config=None,
        load_only: bool = False):
```

**Parameters**
- `passages` (_List[str]_): List of passages to be indexed.
- `colbert_config` (_ColBERTConfig_, _Optional_): ColBERT config used for building and searching the index. Defaults to None.
- `load_only` (_Boolean_): Whether to only load an existing index (`True`) or to build the index and then load it (`False`). Defaults to False.

Although `colbert_config` defaults to None, it is required for ColBERTv2: the constructor asserts that the config is not None and that its `checkpoint` and `index_name` attributes are set. You can import the config class with `from colbert.infra.config import ColBERTConfig`. Descriptions of the config attributes are available [here](https://github.com/stanford-futuredata/ColBERT/blob/main/colbert/infra/config/settings.py).
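
The constructor's preconditions can be sketched without installing ColBERT. The snippet below is a minimal illustration, not the library's code: `StubConfig` is a hypothetical stand-in for `ColBERTConfig` that carries only two of the attributes the retriever checks, and `check_config` is a hypothetical helper that mirrors the constructor's assertions.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class StubConfig:
    """Hypothetical stand-in for ColBERTConfig (illustration only)."""
    checkpoint: Optional[str] = None
    index_name: Optional[str] = None


def check_config(config: Optional[StubConfig]) -> List[str]:
    """Return the problems that would stop the retriever from starting."""
    if config is None:
        return ["colbert_config is required"]
    problems = []
    if config.checkpoint is None:
        problems.append("checkpoint is not set")
    if config.index_name is None:
        problems.append("index_name is not set")
    return problems


print(check_config(None))
print(check_config(StubConfig(checkpoint="colbert-ir/colbertv2.0")))
print(check_config(StubConfig(checkpoint="colbert-ir/colbertv2.0", index_name="demo_index")))
```

An empty list means the config would pass all three checks. The index name `demo_index` is made up for the example.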

### Methods

#### `forward(self, query: str, k: int = 7, **kwargs) -> Union[list[str], list[dotdict]]`

Retrieves the passages in the index that are most relevant to the query. If you already have a local index, pass `load_only=True` to the constructor and set the `index` attribute of ColBERTConfig to the local path. Also make sure that the `checkpoint` attribute of ColBERTConfig points to the embedding model that you used to build the index.

**Parameters:**
- `query` (_str_): Query string used for retrieval.
- `k` (_int_, _optional_): Number of passages to retrieve. Defaults to 7.
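
In the implementation added by this commit, `forward` also reads an optional `filtered_pids` keyword argument (a list of integer passage ids) and hands the searcher a filter function so that only those ids survive. The pure-Python sketch below illustrates that filtering idea only. `make_filter` is a hypothetical helper, and the real code builds a `torch` tensor rather than a list.

```python
from typing import Callable, List


def make_filter(filtered_pids: List[int]) -> Callable[[List[int]], List[int]]:
    """Build a function that keeps only the candidate pids in filtered_pids."""
    # Validate element by element: a list's runtime type carries no element type.
    if not (isinstance(filtered_pids, list) and all(isinstance(p, int) for p in filtered_pids)):
        raise TypeError("filtered_pids should be a list of integers")
    allowed = set(filtered_pids)  # set membership is cheap per candidate
    return lambda pids: [pid for pid in pids if pid in allowed]


keep = make_filter([2, 5, 9])
print(keep([1, 2, 3, 5, 8]))  # [2, 5]
```
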

It returns a `Prediction` object for each query:

```python
Prediction(
    pid=[33, 6, 47, 74, 48],
    passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']
)
```
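
In the code added by this commit, the searcher's output is a triple of parallel sequences (passage ids, ranks, scores), and `forward` zips them into one record per passage with the keys `long_text`, `score`, and `pid`. A minimal sketch of that step, using plain dicts instead of `dotdict` and a made-up collection and search result:

```python
collection = ["No pain, no gain.", "Patience is a virtue.", "Out of sight, out of mind."]

# Made-up search output: parallel sequences of (pids, ranks, scores).
searcher_results = ([2, 0], [1, 2], [21.5, 17.25])

# One record per retrieved passage; the rank is implied by the list order.
results = [
    {"long_text": collection[pid], "score": score, "pid": pid}
    for pid, rank, score in zip(*searcher_results)
]

for r in results:
    print(r["pid"], r["score"], r["long_text"])
```
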
# dspy.ColBERTv2RerankerLocal

You can also use ColBERTv2 as a reranker in DSPy.

### Constructor

```python
class ColBERTv2RerankerLocal:

    def __init__(
        self,
        colbert_config=None,
        checkpoint: str = 'bert-base-uncased'):
```

**Parameters**
- `colbert_config` (_ColBERTConfig_, _Optional_): ColBERT config used for building and searching. Defaults to None.
- `checkpoint` (_str_): Embedding model used to embed the documents and the query. Defaults to `'bert-base-uncased'`.

### Methods
#### `forward(self, query: str, passages: List[str])`

Given a query and a list of passages, it computes a similarity score for each passage. The scores are returned in the same order as the input passages, so sorting them in descending order gives the reranked passages.

**Parameters:**
- `query` (_str_): Query string used for reranking.
- `passages` (_List[str]_): List of passages to be reranked.

It returns the array of similarity scores, which you can map back to the passages as follows:

```python
import numpy as np
# scores_arr is the array returned by forward(query, passages)
for idx in np.argsort(scores_arr)[::-1]:
    print(f"Passage = {passages[idx]} --> Score = {scores_arr[idx]}")
```
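
The same ordering can be used to keep only the best `k` passages after reranking, which is what the retrieve-then-rerank helpers in this commit do: they retrieve `k * 3` candidates and keep the top `k` by reranker score. A minimal sketch in plain Python follows. `top_k_by_score` is a hypothetical helper, and the scores are made up.

```python
from typing import List, Tuple


def top_k_by_score(passages: List[str], scores: List[float], k: int) -> List[Tuple[str, float]]:
    """Pair each passage with its score and return the k highest-scoring pairs."""
    if len(passages) != len(scores):
        raise ValueError("passages and scores must have the same length")
    # Indices sorted by descending score, the pure-Python analogue of argsort()[::-1].
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [(passages[i], scores[i]) for i in order[:k]]


passages = ["Patience is a virtue.", "No pain, no gain.", "The best things in life are free."]
scores = [12.0, 30.5, 18.25]
for text, score in top_k_by_score(passages, scores, k=2):
    print(f"Passage = {text} --> Score = {score}")
```
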
2 changes: 1 addition & 1 deletion dsp/modules/__init__.py
@@ -9,7 +9,7 @@
from .clarifai import *
from .cloudflare import *
from .cohere import *
from .colbertv2 import ColBERTv2
from .colbertv2 import ColBERTv2, ColBERTv2RerankerLocal, ColBERTv2RetrieverLocal
from .databricks import *
from .dummy_lm import *
from .google import *
119 changes: 118 additions & 1 deletion dsp/modules/colbertv2.py
@@ -1,5 +1,5 @@
import functools
from typing import Any, Optional, Union
from typing import Any, List, Optional, Union

import requests

Expand Down Expand Up @@ -74,3 +74,120 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs):


colbertv2_post_request = colbertv2_post_request_v2_wrapped

class ColBERTv2RetrieverLocal:
    def __init__(self, passages: List[str], colbert_config=None, load_only: bool = False):
        """Colbertv2 retriever module

        Args:
            passages (List[str]): list of passages
            colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to None.
            load_only (bool, optional): whether to load the index or build and then load. Defaults to False.
        """
        assert colbert_config is not None, "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it"
        self.colbert_config = colbert_config

        assert self.colbert_config.checkpoint is not None, "Please pass a valid checkpoint like colbert-ir/colbertv2.0, which you can modify in the ColBERTConfig with attribute name checkpoint"
        assert self.colbert_config.index_name is not None, "Please pass a valid index_name, which you can modify in the ColBERTConfig with attribute name index_name"
        self.passages = passages

        if not load_only:
            print(f"Building the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
            self.build_index()

        print(f"Loading the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
        self.searcher = self.get_index()

    def build_index(self):
        try:
            import colbert
        except ImportError:
            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")

        from colbert import Indexer
        from colbert.infra import Run, RunConfig

        with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)):
            indexer = Indexer(checkpoint=self.colbert_config.checkpoint, config=self.colbert_config)
            indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True)

    def get_index(self):
        try:
            import colbert
        except ImportError:
            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")

        from colbert import Searcher
        from colbert.infra import Run, RunConfig

        with Run().context(RunConfig(experiment=self.colbert_config.experiment)):
            searcher = Searcher(index=self.colbert_config.index_name, collection=self.passages)
        return searcher

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def forward(self, query: str, k: int = 7, **kwargs):
        import torch

        if kwargs.get("filtered_pids"):
            filtered_pids = kwargs.get("filtered_pids")
            # type(x) == List[int] is never true at runtime, so validate the elements instead
            assert isinstance(filtered_pids, list) and all(isinstance(pid, int) for pid in filtered_pids), "The filtered pids should be a list of integers"
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Assign to searcher_results so that both branches feed the loop below
            searcher_results = self.searcher.search(
                query,
                # Number of passages to receive
                k=k,
                # Keep only the candidate pids listed in filtered_pids
                filter_fn=lambda pids: torch.tensor(
                    [pid for pid in pids if pid in filtered_pids], dtype=torch.int32).to(device))
        else:
            searcher_results = self.searcher.search(query, k=k)
        results = []
        for pid, rank, score in zip(*searcher_results):
            results.append(dotdict({'long_text': self.searcher.collection[pid], 'score': score, 'pid': pid}))
        return results

class ColBERTv2RerankerLocal:

    def __init__(self, colbert_config=None, checkpoint: str = 'bert-base-uncased'):
        """Colbertv2 reranker module

        Args:
            colbert_config (ColBERTConfig, optional): Colbert config. Defaults to None.
            checkpoint (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'.
        """
        try:
            import colbert
        except ImportError:
            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
        # Fail early with a clear message; setting .checkpoint on None would raise AttributeError
        assert colbert_config is not None, "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it"
        self.colbert_config = colbert_config
        self.checkpoint = checkpoint
        self.colbert_config.checkpoint = checkpoint

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    def forward(self, query: str, passages: List[str] = []):
        assert len(passages) > 0, "Passages should not be empty"

        import numpy as np
        from colbert.modeling.colbert import ColBERT
        from colbert.modeling.tokenization.doc_tokenization import DocTokenizer
        from colbert.modeling.tokenization.query_tokenization import QueryTokenizer

        self.colbert_config.nway = len(passages)
        query_tokenizer = QueryTokenizer(self.colbert_config, verbose=1)
        doc_tokenizer = DocTokenizer(self.colbert_config)
        query_ids, query_masks = query_tokenizer.tensorize([query])
        doc_ids, doc_masks = doc_tokenizer.tensorize(passages)

        col = ColBERT(self.checkpoint, self.colbert_config)
        Q = col.query(query_ids, query_masks)
        DOC_IDS, DOC_MASKS = col.doc(doc_ids, doc_masks, keep_dims='return_mask')
        # Repeat the query embedding once per passage so each passage is scored against it
        Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous()
        tensor_scores = col.score(Q_duplicated, DOC_IDS, DOC_MASKS)
        passage_score_arr = np.array([score.cpu().detach().numpy().tolist() for score in tensor_scores])
        return passage_score_arr
79 changes: 77 additions & 2 deletions dsp/primitives/search.py
@@ -1,9 +1,11 @@
import logging
from collections.abc import Iterable

import numpy as np

import dsp

logger = logging.getLogger(__name__)

def retrieve(query: str, k: int, **kwargs) -> list[str]:
    """Retrieves passages from the RM for the query and returns the top k passages."""
@@ -15,12 +17,25 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]:
        # TODO: we should unify the type signatures of dspy.Retriever
        passages = [passages]
    passages = [psg.long_text for psg in passages]

    if dsp.settings.reranker:
        passages_cs_scores = dsp.settings.reranker(query, passages)
        passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1]
        passages = [passages[idx] for idx in passages_cs_scores_sorted]


    return passages


def retrievewithMetadata(query: str, k: int, **kwargs) -> list[str]:
    """Retrieves passages from the RM for the query and returns the top k passages."""

    if not dsp.settings.rm:
        raise AssertionError("No RM is loaded.")
    passages = dsp.settings.rm(query, k=k, **kwargs)
    if not isinstance(passages, Iterable):
        # it's not an iterable yet; make it one.
        # TODO: we should unify the type signatures of dspy.Retriever
        passages = [passages]

    return passages


@@ -38,9 +53,31 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]:
                passages_cs_scores[idx],
            ]


    passages = [(np.average(score), text) for text, score in passages.items()]
    return [text for _, text in sorted(passages, reverse=True)[:k]]

def retrieveRerankEnsemblewithMetadata(queries: list[str], k: int, **kwargs) -> list[str]:
    if not (dsp.settings.rm and dsp.settings.reranker):
        raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.")
    queries = [q for q in queries if q]
    all_queries_passages = []
    for query in queries:
        passages = []
        # Over-retrieve (k * 3 candidates), then keep the k best by reranker score
        retrieved_passages = dsp.settings.rm(query, k=k * 3, **kwargs)
        passages_cs_scores = dsp.settings.reranker(
            query, passages=[psg["long_text"] for psg in retrieved_passages],
        )
        for idx in np.argsort(passages_cs_scores)[::-1][:k]:
            curr_passage = retrieved_passages[idx]
            curr_passage["rerank_score"] = passages_cs_scores[idx]
            passages.append(curr_passage)
        all_queries_passages.append(passages)
    if len(queries) == 1:
        return all_queries_passages[0]
    else:
        return all_queries_passages


def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]:
    """Retrieves passages from the RM for each query in queries and returns the top k passages
@@ -50,7 +87,6 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs)
        raise AssertionError("No RM is loaded.")
    if dsp.settings.reranker:
        return retrieveRerankEnsemble(queries, k, **kwargs)

    queries = [q for q in queries if q]

    if len(queries) == 1:
@@ -68,4 +104,43 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs)
    passages = sorted(passages, reverse=True)[:k]
    passages = [text for _, text in passages]


    return passages

def retrieveEnsemblewithMetadata(
    queries: list[str], k: int, by_prob: bool = True, **kwargs,
) -> list[str]:
    """Retrieves passages from the RM for each query in queries and returns the top k passages
    based on the probability or score.
    """

    if not dsp.settings.rm:
        raise AssertionError("No RM is loaded.")
    # Delegate to the reranking variant only when a reranker is configured;
    # that function raises an AssertionError when no reranker is set.
    if dsp.settings.reranker:
        return retrieveRerankEnsemblewithMetadata(queries=queries, k=k, **kwargs)

    queries = [q for q in queries if q]

    if len(queries) == 1:
        return retrieve(queries[0], k)
    all_queries_passages = []
    for q in queries:
        passages = {}
        retrieved_passages = dsp.settings.rm(q, k=k * 3, **kwargs)
        for idx, psg in enumerate(retrieved_passages):
            if by_prob:
                passages[(idx, psg.long_text)] = (
                    passages.get(psg.long_text, 0.0) + psg.prob
                )
            else:
                passages[(idx, psg.long_text)] = (
                    passages.get(psg.long_text, 0.0) + psg.score
                )
            retrieved_passages[idx]["tracking_idx"] = idx
        # Sort by descending score so that the slice keeps the top k passages
        passages = sorted(passages.items(), key=lambda item: item[1], reverse=True)[:k]
        req_indices = [psg[0][0] for psg in passages]
        passages = [
            rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices
        ]
        all_queries_passages.append(passages)
    return all_queries_passages
2 changes: 2 additions & 0 deletions dspy/__init__.py
@@ -20,6 +20,8 @@
Databricks = dsp.Databricks
Cohere = dsp.Cohere
ColBERTv2 = dsp.ColBERTv2
ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal
ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal
Pyserini = dsp.PyseriniRetriever
Clarifai = dsp.ClarifaiLLM
CloudflareAI = dsp.CloudflareAI
2 changes: 1 addition & 1 deletion dspy/retrieve/__init__.py
@@ -1 +1 @@
from .retrieve import Retrieve
from .retrieve import Retrieve, RetrieveThenRerank