Colbert local mode support both as retriever and reranker. #797

Merged Jun 15, 2024 · 32 commits (showing changes from 17 commits)

Commits
9632e5e
return metadata changes
Athe-kunal Apr 4, 2024
e415f39
Merge branch 'main' of https://github.com/Athe-kunal/dspy
Athe-kunal Apr 4, 2024
a4b3844
add metadata changes
Athe-kunal Apr 4, 2024
321a768
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 5, 2024
6cd1d56
add support for returning metadata and reranking
Athe-kunal Apr 6, 2024
eeafacb
colbert integration
Athe-kunal Apr 8, 2024
1639bd2
colbert local modifications
Athe-kunal Apr 8, 2024
ec062b6
kwargs filtered ids
Athe-kunal Apr 8, 2024
987d923
colbert return
Athe-kunal Apr 8, 2024
9ff5b28
colbert retriever and reranker
Athe-kunal Apr 9, 2024
825a272
colbert retriever error fixes
Athe-kunal Apr 9, 2024
c25e9c4
colbert config changes in __init__
Athe-kunal Apr 10, 2024
ab5b12e
colbert notebook
Athe-kunal Apr 10, 2024
63dd534
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 10, 2024
f6a9293
import errors for colbert
Athe-kunal Apr 10, 2024
197a2c2
import dspy fixes and linting fixes
Athe-kunal Apr 10, 2024
4698b00
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 13, 2024
81d142f
PR fixes for colbert
Athe-kunal Apr 13, 2024
b73753c
making the linting gods happy
Athe-kunal Apr 13, 2024
0ec1ded
remove unnecessary outputs
Athe-kunal Apr 14, 2024
567d5c4
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 17, 2024
685df2a
colbertv2 docs
Athe-kunal Apr 17, 2024
fa2bc20
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 19, 2024
509b36c
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 20, 2024
34328fd
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 22, 2024
146ec7b
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 26, 2024
f0437e3
Merge branch 'stanfordnlp:main' into main
Athe-kunal Apr 29, 2024
9cb522b
Colbert PR fixes
Athe-kunal Apr 29, 2024
ec4b9b3
linting fixes
Athe-kunal Apr 29, 2024
326ce01
more linting fixes
Athe-kunal Apr 29, 2024
b5913fc
fixing previous cache breaks with separate funcs
Athe-kunal Jun 8, 2024
c60fadc
Merge branch 'main' into main
arnavsinghvi11 Jun 15, 2024
2 changes: 1 addition & 1 deletion dsp/modules/__init__.py
@@ -8,7 +8,7 @@
from .cache_utils import *
from .clarifai import *
from .cohere import *
from .colbertv2 import ColBERTv2
from .colbertv2 import ColBERTv2, ColBERTv2RerankerLocal, ColBERTv2RetrieverLocal
from .databricks import *
from .google import *
from .googlevertexai import *
118 changes: 117 additions & 1 deletion dsp/modules/colbertv2.py
@@ -1,5 +1,6 @@
import functools
from typing import Any, Optional, Union
import os
from typing import Any, List, Optional, Union

import requests

@@ -74,3 +75,118 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs):


colbertv2_post_request = colbertv2_post_request_v2_wrapped
os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True"

class ColBERTv2RetrieverLocal:
def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False,index_name:str="colbert_rm",checkpoint:str='colbert-ir/colbertv2.0'):
"""Colbertv2 retriever module

Args:
passages (List[str]): list of passages
load_only (bool, optional): whether to load an existing index instead of building one. Defaults to False.
index_name (str, optional): name of the index. Defaults to "colbert_rm".
checkpoint (str, optional): checkpoint for generating embeddings. Defaults to 'colbert-ir/colbertv2.0'.
colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to ColBERTConfig().
"""
self.checkpoint = checkpoint
self.colbert_config = colbert_config
self.colbert_config.index_name = index_name
self.checkpoint = checkpoint
self.colbert_config.checkpoint = checkpoint
self.passages = passages

if not load_only:
print(f"Building the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
self.build_index()

print(f"Loading the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
self.searcher = self.get_index()

def build_index(self):

try:
import colbert
except ImportError:
print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")

from colbert import Indexer
from colbert.infra import Run, RunConfig
with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)):
indexer = Indexer(checkpoint=self.checkpoint, config=self.colbert_config)
indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True)

def get_index(self):
try:
import colbert
except ImportError:
print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")

from colbert import Searcher
from colbert.infra import Run, RunConfig

with Run().context(RunConfig(experiment=self.colbert_config.experiment)):
searcher = Searcher(index=self.colbert_config.index_name, collection=self.passages)
return searcher

def __call__(self,query:str,k:int=7,**kwargs):
import torch

if kwargs.get("filtered_pids"):
filtered_pids = kwargs.get("filtered_pids")
assert isinstance(filtered_pids, list) and all(isinstance(pid, int) for pid in filtered_pids), "The filtered pids should be a list of integers"
device = "cuda" if torch.cuda.is_available() else "cpu"
results = self.searcher.search(
query,
# Number of passages to retrieve
k=k,
# Filter function that keeps only the allowed pids
filter_fn=lambda pids: torch.tensor(
[pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device))
else:
searcher_results = self.searcher.search(query, k=k)
results = []
for pid,rank,score in zip(*searcher_results):
results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid}))
return results
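The `filtered_pids` path above restricts search to an allow-list of passage ids. A minimal pure-Python sketch of that filter with made-up pids (the real `filter_fn` additionally wraps the result in a `torch.tensor` on the active device):

```python
# Hypothetical candidate pids produced by the searcher.
candidate_pids = [0, 1, 2, 3, 4, 5]
# Allow-list that would arrive via kwargs["filtered_pids"].
filtered_pids = [0, 2, 5]

# Same comprehension the filter_fn uses, minus the tensor conversion.
kept = [pid for pid in candidate_pids if pid in filtered_pids]
print(kept)  # [0, 2, 5]
```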

class ColBERTv2RerankerLocal:

def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'):
try:
import colbert
except ImportError:
print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
"""ColBERTv2 local reranker module

Args:
checkpoint (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'.
colbert_config (ColBERTConfig, optional): Colbert config. Defaults to ColBERTConfig().
"""
self.colbert_config = colbert_config
self.checkpoint_name = checkpoint
self.colbert_config.checkpoint = checkpoint

# def __call__(self, *args: Any, **kwargs: Any) -> Any:
# return self.forward(*args, **kwargs)

def __call__(self,query:str,passages:List[str]=[]):
import numpy as np
from colbert.modeling.colbert import ColBERT
from colbert.modeling.tokenization.doc_tokenization import DocTokenizer
from colbert.modeling.tokenization.query_tokenization import QueryTokenizer
assert len(passages) > 0, "Passages should not be empty"
self.colbert_config.nway = len(passages)
query_tokenizer = QueryTokenizer(self.colbert_config,verbose=1)
doc_tokenizer = DocTokenizer(self.colbert_config)
query_ids,query_masks = query_tokenizer.tensorize([query])
doc_ids,doc_masks = doc_tokenizer.tensorize(passages)

col = ColBERT(self.checkpoint_name,self.colbert_config)
# col.colbert_config.nway = len(passages)
# tensor_scores = col([query_ids,query_masks],[doc_ids,doc_masks])
Q = col.query(query_ids,query_masks)
DOC_IDS,DOC_MASKS = col.doc(doc_ids,doc_masks,keep_dims='return_mask')
Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous()
tensor_scores = col.score(Q_duplicated,DOC_IDS,DOC_MASKS)
passage_score_arr = np.array([score.cpu().detach().numpy().tolist() for score in tensor_scores])
return passage_score_arr
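The score array returned by the reranker is consumed elsewhere in this PR via `np.argsort(scores)[::-1]`. A pure-Python sketch of that descending reorder, using hypothetical passages and scores:

```python
passages = ["doc a", "doc b", "doc c"]
scores = [0.2, 0.9, 0.5]  # hypothetical reranker scores

# Descending argsort, equivalent to np.argsort(scores)[::-1].
order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
reranked = [passages[i] for i in order]
print(reranked)  # ['doc b', 'doc c', 'doc a']
```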
84 changes: 56 additions & 28 deletions dsp/primitives/search.py
@@ -1,3 +1,4 @@
import warnings
from collections.abc import Iterable

import numpy as np
@@ -9,17 +10,21 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]:
"""Retrieves passages from the RM for the query and returns the top k passages."""
if not dsp.settings.rm:
raise AssertionError("No RM is loaded.")
if not dsp.settings.reranker:
warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank")
[Collaborator] can we instead handle this through logging and through a deprecation message?

[Contributor Author, Athe-kunal, Apr 13, 2024] @arnavsinghvi11, can you explain what you mean by logging? Do I have to create a Python logging object file and then log these? Sorry if this is a trivial question. I have added the deprecation warning for now.

[Collaborator] "DeprecationWarning: 'display' has been deprecated. To see all information for debugging, use 'dspy.set_log_level('debug')'. In the future this will raise an error." - feel free to reference this

[Contributor Author] The dspy logger object is not available in the dsp folder, hence I followed logging as done here for the Anthropic LM. Is there a better way to log this?

passages = dsp.settings.rm(query, k=k, **kwargs)
if not isinstance(passages, Iterable):
# it's not an iterable yet; make it one.
# TODO: we should unify the type signatures of dspy.Retriever
passages = [passages]
passages = [psg.long_text for psg in passages]
# passages = [psg.long_text for psg in passages]

if dsp.settings.reranker:
passages_cs_scores = dsp.settings.reranker(query, passages)
passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1]
passages = [passages[idx] for idx in passages_cs_scores_sorted]
# if dsp.settings.reranker:
# passages_tracking_idx = {str(idx):psg for idx, psg in enumerate(passages)}
# passages_long_text = [psg.long_text for psg in passages]
# passages_cs_scores = dsp.settings.reranker(query, passages_long_text)
# passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1]
# passages = [passages_long_text[idx] for idx in passages_cs_scores_sorted]

return passages

@@ -28,44 +33,67 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]:
if not (dsp.settings.rm and dsp.settings.reranker):
raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.")
queries = [q for q in queries if q]
passages = {}
all_queries_passages = []
for query in queries:
passages = []
retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs)
passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages])
for idx in np.argsort(passages_cs_scores)[::-1]:
psg = retrieved_passages[idx]
passages[psg.long_text] = passages.get(psg.long_text, []) + [
passages_cs_scores[idx],
]
passages_cs_scores = dsp.settings.reranker(query,passages=[psg["long_text"] for psg in retrieved_passages])
for idx in np.argsort(passages_cs_scores)[::-1][:k]:
curr_passage = retrieved_passages[idx]
curr_passage['rerank_score'] = passages_cs_scores[idx]
passages.append(curr_passage)
all_queries_passages.append(passages)
if len(queries) == 1:
return all_queries_passages[0]
else:
return all_queries_passages
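The loop above retrieves `k*3` candidates per query, reranks them, and keeps the top `k` with a `rerank_score` attached. A self-contained sketch of that selection step, with hypothetical passage dicts and scores:

```python
# Hypothetical retrieved passages and their reranker scores.
retrieved = [{"long_text": "a"}, {"long_text": "b"}, {"long_text": "c"}]
scores = [0.1, 0.8, 0.4]
k = 2

# Indices of the k best scores, descending (like np.argsort(scores)[::-1][:k]).
order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
top = []
for i in order:
    passage = dict(retrieved[i])
    passage["rerank_score"] = scores[i]  # attach the score, as the PR does
    top.append(passage)
print(top)  # [{'long_text': 'b', 'rerank_score': 0.8}, {'long_text': 'c', 'rerank_score': 0.4}]
```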

passages = [(np.average(score), text) for text, score in passages.items()]
return [text for _, text in sorted(passages, reverse=True)[:k]]
# def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]:
# if not (dsp.settings.rm and dsp.settings.reranker):
# raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.")
# queries = [q for q in queries if q]
# passages = {}
# for query in queries:
# retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs)
# passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages])
# for idx in np.argsort(passages_cs_scores)[::-1]:
# psg = retrieved_passages[idx]
# passages[psg.long_text] = passages.get(psg.long_text, []) + [
# passages_cs_scores[idx],
# ]

# passages = [(np.average(score), text) for text, score in passages.items()]
# return [text for _, text in sorted(passages, reverse=True)[:k]]

def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]:
"""Retrieves passages from the RM for each query in queries and returns the top k passages
based on the probability or score.
"""
if not dsp.settings.rm:
raise AssertionError("No RM is loaded.")
if dsp.settings.reranker:
return retrieveRerankEnsemble(queries, k)
if not dsp.settings.reranker:
[Collaborator] logging here too

[Collaborator] as with above: "DeprecationWarning: 'display' has been deprecated. To see all information for debugging, use 'dspy.set_log_level('debug')'. In the future this will raise an error." - feel free to reference this

[Contributor Author] as with above: the dspy logger object is not available in the dsp folder, hence I followed logging as done here for the Anthropic LM. Is there a better way to log this?

warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.")

queries = [q for q in queries if q]

if len(queries) == 1:
return retrieve(queries[0], k, **kwargs)

passages = {}
return retrieve(queries[0], k)
all_queries_passages = []
for q in queries:
for psg in dsp.settings.rm(q, k=k * 3,**kwargs):
passages = {}
retrieved_passages = dsp.settings.rm(q, k=k * 3,**kwargs)
# for idx,psg in enumerate(retrieved_passages):
# retrieved_passages[idx]["tracking_idx"] = idx
for idx,psg in enumerate(retrieved_passages):
if by_prob:
passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.prob
passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.prob
else:
passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.score

passages = [(score, text) for text, score in passages.items()]
passages = sorted(passages, reverse=True)[:k]
passages = [text for _, text in passages]

return passages
passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.score
retrieved_passages[idx]["tracking_idx"] = idx
# passages = [(score, text) for text, score in passages.items()]
passages = sorted(passages.items(), key=lambda item: item[1])[:k]
# passages = sorted(passages, reverse=True)[:k]
req_indices = [psg[0][0] for psg in passages]
passages = [rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices]
all_queries_passages.append(passages)
return all_queries_passages
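Across queries the same passage can be retrieved more than once, and `retrieveEnsemble` sums its probabilities (or scores) before keeping the top k. A sketch of that aggregation with made-up hits (values chosen to be exact in binary floating point):

```python
# Hypothetical (passage_text, score) hits accumulated across queries.
hits = [("passage one", 0.25), ("passage two", 0.5), ("passage one", 0.5)]

# Sum scores for duplicate passages.
agg = {}
for text, score in hits:
    agg[text] = agg.get(text, 0.0) + score

# Rank by aggregated score, descending.
top = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
print(top)  # [('passage one', 0.75), ('passage two', 0.5)]
```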
2 changes: 2 additions & 0 deletions dspy/__init__.py
@@ -20,6 +20,8 @@
Databricks = dsp.Databricks
Cohere = dsp.Cohere
ColBERTv2 = dsp.ColBERTv2
ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal
ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal
Pyserini = dsp.PyseriniRetriever
Clarifai = dsp.ClarifaiLLM
Google = dsp.Google
2 changes: 1 addition & 1 deletion dspy/retrieve/__init__.py
@@ -1 +1 @@
from .retrieve import Retrieve
from .retrieve import Retrieve, RetrieveThenRerank
83 changes: 79 additions & 4 deletions dspy/retrieve/retrieve.py
@@ -1,5 +1,5 @@
import random
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import dsp
from dspy.predict.parameter import Parameter
@@ -29,14 +29,89 @@ def load_state(self, state):
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)

def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Prediction:
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]:
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
queries = [query.strip().split('\n')[0].strip() for query in queries]

# print(queries)
# TODO: Consider removing any quote-like markers that surround the query too.
k = k if k is not None else self.k
passages = dsp.retrieveEnsemble(queries, k=k,**kwargs)
return Prediction(passages=passages)

if isinstance(passages[0],List):
pred_returns = []
for query_passages in passages:
passages_dict = {key:[] for key in list(query_passages[0].keys()) if key!="tracking_idx"}
for psg in query_passages:
for key,value in psg.items():
if key == "tracking_idx": continue
passages_dict[key].append(value)
if "long_text" in passages_dict:
passages_dict["passages"] = passages_dict.pop("long_text")
pred_returns.append(Prediction(**passages_dict))
return pred_returns
elif isinstance(passages[0], Dict):
# passages dict will contain {"long_text": long_text_list, "metadatas": metadatas_list, ...}
passages_dict = {key:[] for key in list(passages[0].keys())}

for psg in passages:
for key,value in psg.items():
passages_dict[key].append(value)
if "long_text" in passages_dict:
passages_dict["passages"] = passages_dict.pop("long_text")
return Prediction(**passages_dict)
# elif isinstance(passages,List):
# return Prediction(passages=passages)
# TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too.
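`forward` transposes each list of passage dicts into a dict of lists and renames `long_text` to `passages` before building the `Prediction`. A standalone sketch of that reshaping, with hypothetical field values:

```python
query_passages = [
    {"long_text": "first passage", "score": 0.9},
    {"long_text": "second passage", "score": 0.4},
]

# Dict-of-lists transposition (the real forward also drops "tracking_idx").
passages_dict = {key: [] for key in query_passages[0]}
for psg in query_passages:
    for key, value in psg.items():
        passages_dict[key].append(value)
if "long_text" in passages_dict:
    passages_dict["passages"] = passages_dict.pop("long_text")

print(passages_dict)  # {'score': [0.9, 0.4], 'passages': ['first passage', 'second passage']}
```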

class RetrieveThenRerank(Parameter):
name = "Search"
input_variable = "query"
desc = "takes a search query and returns one or more potentially relevant passages followed by reranking from a corpus"

def __init__(self, k=3):
self.stage = random.randbytes(8).hex()
self.k = k

def reset(self):
pass

def dump_state(self):
state_keys = ["k"]
return {k: getattr(self, k) for k in state_keys}

def load_state(self, state):
for name, value in state.items():
setattr(self, name, value)

# def __call__(self, *args, **kwargs):
# return self.forward(*args, **kwargs)

def __call__(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]:
queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries
queries = [query.strip().split('\n')[0].strip() for query in queries]

# print(queries)
# TODO: Consider removing any quote-like markers that surround the query too.
k = k if k is not None else self.k
passages = dsp.retrieveRerankEnsemble(queries, k=k,**kwargs)
[Collaborator] could we maintain the forward pass call from before and abstract the repetitive code from below within the forward pass?

[Contributor Author] Sorry, I was not able to understand this, @arnavsinghvi11. Do you want a common utility function for both Retrieve and RetrieveThenRerank to process the returned documents?

[Collaborator] Hi @Athe-kunal, yeah, it seemed like there is some repetitive code in both forward passes that can be abstracted out for the different retriever types. Let me know if this change makes sense.

[Contributor Author] Hi @arnavsinghvi11, I have abstracted the repetitive code. However, there are some nuances in the multi-query retriever, hence I didn't make a helper function for it. But for a single query, I have added a helper function single_query_passage. Please let me know if I need to make other changes.

if isinstance(passages[0],List):
pred_returns = []
for query_passages in passages:
passages_dict = {key:[] for key in list(query_passages[0].keys())}
for docs in query_passages:
for key,value in docs.items():
passages_dict[key].append(value)
if "long_text" in passages_dict:
passages_dict["passages"] = passages_dict.pop("long_text")

pred_returns.append(Prediction(**passages_dict))
return pred_returns
elif isinstance(passages[0], Dict):
passages_dict = {key:[] for key in list(passages[0].keys())}
for docs in passages:
for key,value in docs.items():
passages_dict[key].append(value)
if "long_text" in passages_dict:
passages_dict["passages"] = passages_dict.pop("long_text")
return Prediction(**passages_dict)
