From 9632e5ef08e53b1b41aa844e36d4a16de1417432 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Wed, 3 Apr 2024 22:41:32 -0400 Subject: [PATCH 01/21] return metadata changes --- dsp/primitives/search.py | 29 +++++++++++++++++------------ dspy/retrieve/retrieve.py | 25 +++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index aef9cd8eb..ef0f2b9d2 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -14,7 +14,7 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: # it's not an iterable yet; make it one. # TODO: we should unify the type signatures of dspy.Retriever passages = [passages] - passages = [psg.long_text for psg in passages] + # passages = [psg.long_text for psg in passages] if dsp.settings.reranker: passages_cs_scores = dsp.settings.reranker(query, passages) @@ -55,17 +55,22 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) if len(queries) == 1: return retrieve(queries[0], k) - - passages = {} + all_queries_passages = [] for q in queries: - for psg in dsp.settings.rm(q, k=k * 3,**kwargs): + passages = {} + retrieved_passages = dsp.settings.rm(q, k=k * 3,**kwargs) + # for idx,psg in enumerate(retrieved_passages): + # retrieved_passages[idx]["tracking_idx"] = idx + for idx,psg in enumerate(retrieved_passages): if by_prob: - passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.prob + passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.prob else: - passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.score - - passages = [(score, text) for text, score in passages.items()] - passages = sorted(passages, reverse=True)[:k] - passages = [text for _, text in passages] - - return passages + passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.score + retrieved_passages[idx]["tracking_idx"] = idx + # passages = [(score, text) for text, score in passages.items()] + 
passages = sorted(passages.items(), key=lambda item: item[1])[:k] + # passages = sorted(passages, reverse=True)[:k] + req_indices = [psg[0][0] for psg in passages] + passages = [rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices] + all_queries_passages.append(passages) + return all_queries_passages diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 7f026e2aa..0a078581c 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,5 +1,5 @@ import random -from typing import List, Optional, Union +from typing import List, Optional, Union, Dict, Any import dsp from dspy.predict.parameter import Parameter @@ -37,6 +37,27 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No # TODO: Consider removing any quote-like markers that surround the query too. k = k if k is not None else self.k passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) - return Prediction(passages=passages) + if isinstance(passages[0],List): + pred_returns = [] + for query_passages in passages: + passages_dict = {key:[] for key in list(query_passages[0].keys()) if key!="tracking_idx"} + for psg in query_passages: + for key,value in psg.items(): + if key == "tracking_idx": continue + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + pred_returns.append(Prediction(**passages_dict)) + return pred_returns + elif isinstance(passages[0], Dict): + #passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} + passages_dict = {key:[] for key in list(passages[0].keys())} + + for psg in passages: + for key,value in psg.items(): + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + return Prediction(**passages_dict) # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. 
From a4b3844811fa5f57b3dc42547fce4d477cbe8824 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Thu, 4 Apr 2024 00:59:04 -0400 Subject: [PATCH 02/21] add metadata changes --- dsp/primitives/search.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 78a83a579..ef0f2b9d2 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -54,14 +54,8 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) queries = [q for q in queries if q] if len(queries) == 1: -<<<<<<< HEAD return retrieve(queries[0], k) all_queries_passages = [] -======= - return retrieve(queries[0], k, **kwargs) - - passages = {} ->>>>>>> fd63306642553ecb7d0916ea4156a374ae53c255 for q in queries: passages = {} retrieved_passages = dsp.settings.rm(q, k=k * 3,**kwargs) From 6cd1d56f70328bb42f14c7fee9ea6ba7f36972d3 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sat, 6 Apr 2024 02:35:02 -0400 Subject: [PATCH 03/21] add support for returning metadata and reranking --- dsp/primitives/search.py | 58 ++++++++++++++++++++++++++++--------- dspy/retrieve/__init__.py | 2 +- dspy/retrieve/retrieve.py | 60 +++++++++++++++++++++++++++++++++++++-- rm_test.py | 42 +++++++++++++++++++++++++++ 4 files changed, 145 insertions(+), 17 deletions(-) create mode 100644 rm_test.py diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index ef0f2b9d2..42338cc04 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -1,5 +1,5 @@ from collections.abc import Iterable - +import warnings import numpy as np import dsp @@ -9,6 +9,8 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: """Retrieves passages from the RM for the query and returns the top k passages.""" if not dsp.settings.rm: raise AssertionError("No RM is loaded.") + if not dsp.settings.reranker: + warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank") passages = dsp.settings.rm(query, k=k, **kwargs) if not 
isinstance(passages, Iterable): # it's not an iterable yet; make it one. @@ -16,10 +18,12 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: passages = [passages] # passages = [psg.long_text for psg in passages] - if dsp.settings.reranker: - passages_cs_scores = dsp.settings.reranker(query, passages) - passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] - passages = [passages[idx] for idx in passages_cs_scores_sorted] + # if dsp.settings.reranker: + # passages_tracking_idx = {str(idx):psg for idx, psg in enumerate(passages)} + # passages_long_text = [psg.long_text for psg in passages] + # passages_cs_scores = dsp.settings.reranker(query, passages_long_text) + # passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] + # passages = [passages_long_text[idx] for idx in passages_cs_scores_sorted] return passages @@ -28,19 +32,45 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: if not (dsp.settings.rm and dsp.settings.reranker): raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") queries = [q for q in queries if q] - passages = {} + all_queries_passages = [] for query in queries: + passages = [] retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) - passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages]) + # passages_cs_scores = dsp.settings.reranker(query,k=k,passages=[psg["long_text"] for psg in retrieved_passages]) + passages_cs_scores = dsp.settings.reranker(query,k=k) + # passages_cs_scores = dsp.settings.reranker(query, retrieved_passages) for idx in np.argsort(passages_cs_scores)[::-1]: psg = retrieved_passages[idx] - passages[psg.long_text] = passages.get(psg.long_text, []) + [ - passages_cs_scores[idx], - ] + # passages[psg.long_text] = passages.get(psg.long_text, []) + [ + # passages_cs_scores[idx], + # ] + # passages.append((psg,passages_cs_scores[idx])) + passages.append([passages_cs_scores[idx],psg]) + # 
all_queries_passages.append(passages) + + # passages = [(np.average(score), psg) for score,psg in passages] + all_queries_passages.append(passages[:k]) + if len(queries) == 1: + return all_queries_passages[0] + else: + return all_queries_passages - passages = [(np.average(score), text) for text, score in passages.items()] - return [text for _, text in sorted(passages, reverse=True)[:k]] +# def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: +# if not (dsp.settings.rm and dsp.settings.reranker): +# raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") +# queries = [q for q in queries if q] +# passages = {} +# for query in queries: +# retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) +# passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages]) +# for idx in np.argsort(passages_cs_scores)[::-1]: +# psg = retrieved_passages[idx] +# passages[psg.long_text] = passages.get(psg.long_text, []) + [ +# passages_cs_scores[idx], +# ] +# passages = [(np.average(score), text) for text, score in passages.items()] +# return [text for _, text in sorted(passages, reverse=True)[:k]] def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: """Retrieves passages from the RM for each query in queries and returns the top k passages @@ -48,8 +78,8 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) """ if not dsp.settings.rm: raise AssertionError("No RM is loaded.") - if dsp.settings.reranker: - return retrieveRerankEnsemble(queries, k) + if not dsp.settings.reranker: + warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank. 
The reranking is ignored here.") queries = [q for q in queries if q] diff --git a/dspy/retrieve/__init__.py b/dspy/retrieve/__init__.py index 1d1f9e8b7..2f699c23a 100644 --- a/dspy/retrieve/__init__.py +++ b/dspy/retrieve/__init__.py @@ -1 +1 @@ -from .retrieve import Retrieve \ No newline at end of file +from .retrieve import Retrieve, RetrieveThenRerank \ No newline at end of file diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 0a078581c..e719a8357 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -29,7 +29,7 @@ def load_state(self, state): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Prediction: + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [query.strip().split('\n')[0].strip() for query in queries] @@ -59,5 +59,61 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No if "long_text" in passages_dict: passages_dict["passages"] = passages_dict.pop("long_text") return Prediction(**passages_dict) - + # elif isinstance(passages,List): + # return Prediction(passages=passages) # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. 
+ +class RetrieveThenRerank(Parameter): + name = "Search" + input_variable = "query" + desc = "takes a search query and returns one or more potentially relevant passages followed by reranking from a corpus" + + def __init__(self, k=3): + self.stage = random.randbytes(8).hex() + self.k = k + + def reset(self): + pass + + def dump_state(self): + state_keys = ["k"] + return {k: getattr(self, k) for k in state_keys} + + def load_state(self, state): + for name, value in state.items(): + setattr(self, name, value) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: + queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries + queries = [query.strip().split('\n')[0].strip() for query in queries] + + # print(queries) + # TODO: Consider removing any quote-like markers that surround the query too. + k = k if k is not None else self.k + passages = dsp.retrieveRerankEnsemble(queries, k=k,**kwargs) + if isinstance(passages[0],List): + pred_returns = [] + for query_passages in passages: + passages_dict = {key:[] for key in list(query_passages[0].keys()) if key!="tracking_idx"} + for psg in query_passages: + for key,value in psg.items(): + if key == "tracking_idx": continue + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + pred_returns.append(Prediction(**passages_dict)) + return pred_returns + elif isinstance(passages[0], Dict): + #passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} + passages_dict = {key:[] for key in list(passages[0].keys())} + + for psg in passages: + for key,value in psg.items(): + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + return Prediction(**passages_dict) + \ No 
newline at end of file diff --git a/rm_test.py b/rm_test.py new file mode 100644 index 000000000..df9dd25db --- /dev/null +++ b/rm_test.py @@ -0,0 +1,42 @@ +# import dspy.retrieve +# from dspy.retrieve.chromadb_rm import ChromadbRM +# from dotenv import load_dotenv +# import chromadb.utils.embedding_functions as embedding_functions +# from chromadb.utils.batch_utils import create_batches +# import os +# import dspy + + +# load_dotenv(override=True) +# emb_fn = embedding_functions.OpenAIEmbeddingFunction( +# api_key=os.environ['OPENAI_API_KEY'], +# model_name="text-embedding-3-small") + +# crm = ChromadbRM( +# collection_name="rows", +# persist_directory="/home/athekunal/DSPy-contributions/Text-to-SQL/India_TABLE", +# embedding_function=emb_fn +# ) +# # reranker = dspy. +# dspy.settings.configure(rm=crm) + +# retriever = dspy.Retrieve(k=2) + +# print(retriever(["Software Internet"],by_prob=False,where={"table_name":"capexIndia"})) +# print("-"*100) +# print(retriever(["Software Internet","Packaging"],by_prob=False,where={"table_name":"capexIndia"})) +import dspy + +colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') +dspy.settings.configure(rm=colbertv2_wiki17_abstracts,reranker=colbertv2_wiki17_abstracts) + +#Define Retrieve Module +retriever = dspy.RetrieveThenRerank(k=3) + +query='When was the first FIFA World Cup held?' + +# Call the retriever on a particular query. 
+topK_passages = retriever([query]) + +for idx, passage in enumerate(topK_passages): + print(f'{idx+1}]', passage, '\n') \ No newline at end of file From eeafacb27ecf7d9f052b96bd3105c0fcd42f3292 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sun, 7 Apr 2024 21:25:54 -0400 Subject: [PATCH 04/21] colbert integration --- dsp/modules/__init__.py | 2 +- dsp/modules/colbertv2.py | 41 ++++++++++- dspy/__init__.py | 1 + rm_test.py | 143 ++++++++++++++++++++++++++++++++++++--- 4 files changed, 174 insertions(+), 13 deletions(-) diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index b52d663cd..235dbfcc7 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -4,7 +4,7 @@ from .cache_utils import * from .clarifai import * from .cohere import * -from .colbertv2 import ColBERTv2 +from .colbertv2 import ColBERTv2, ColBERTv2Local from .databricks import * from .google import * from .gpt3 import * diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 8ff3c1622..e024b6c53 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -1,10 +1,14 @@ import functools -from typing import Any, Optional, Union +from typing import Any, Optional, Union, List import requests - +import colbert +from colbert import Indexer, Searcher +from colbert.infra import Run, RunConfig, ColBERTConfig +from colbert.data import Queries, Collection from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory from dsp.utils import dotdict +import os # TODO: Ideally, this takes the name of the index and looks up its port. 
@@ -74,3 +78,36 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): colbertv2_post_request = colbertv2_post_request_v2_wrapped +os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" + +class ColBERTv2Local: + def __init__(self,checkpoint:str='colbert-ir/colbertv2.0'): + + self.checkpoint = checkpoint + + + def build_index(self,passages:List[str],nranks:int=1,index_name_or_path:str = "Colbert-RM-",nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8,experiment_name:str="Colbert-Experiment"): + + with Run().context(RunConfig(nranks=nranks, experiment=experiment_name)): + config = ColBERTConfig(doc_maxlen=DOC_MAXLEN, nbits=nbits, kmeans_niters=KMEANS_ITER,index_bsize=INDEX_BSIZE) + + + indexer = Indexer(checkpoint=self.checkpoint, config=config) + indexer.index(name=index_name_or_path, collection=passages, overwrite=True) + + def get_index(self,index_name_or_path:str = "Colbert-RM-",experiment_name:str="Colbert-Experiment",passages:List[str] = []): + with Run().context(RunConfig(experiment=experiment_name)): + searcher = Searcher(index=index_name_or_path, collection=passages) + self.searcher = searcher + return searcher + + def get_docs(self,searcher:Searcher,query:str,k:int=7): + + results = searcher.search( + query, + #Number of passages to receive + k=k) + #Passing the filter function of relevant + # filter_fn=lambda pids: torch.tensor( + # [pid for pid in pids if pid in relevant_ids],dtype=torch.int32).to(device)) + return results \ No newline at end of file diff --git a/dspy/__init__.py b/dspy/__init__.py index 75d233281..e63368294 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -18,6 +18,7 @@ Databricks = dsp.Databricks Cohere = dsp.Cohere ColBERTv2 = dsp.ColBERTv2 +ColBERTv2Local = dsp.ColBERTv2Local Pyserini = dsp.PyseriniRetriever Clarifai = dsp.ClarifaiLLM Google = dsp.Google diff --git a/rm_test.py b/rm_test.py index df9dd25db..03dbbd0c7 100644 --- a/rm_test.py +++ b/rm_test.py @@ -25,18 +25,141 @@ # 
print(retriever(["Software Internet"],by_prob=False,where={"table_name":"capexIndia"})) # print("-"*100) # print(retriever(["Software Internet","Packaging"],by_prob=False,where={"table_name":"capexIndia"})) -import dspy +# import dspy + +# colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') +# dspy.settings.configure(rm=colbertv2_wiki17_abstracts,reranker=colbertv2_wiki17_abstracts) -colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') -dspy.settings.configure(rm=colbertv2_wiki17_abstracts,reranker=colbertv2_wiki17_abstracts) +# #Define Retrieve Module +# retriever = dspy.RetrieveThenRerank(k=3) -#Define Retrieve Module -retriever = dspy.RetrieveThenRerank(k=3) +# query='When was the first FIFA World Cup held?' -query='When was the first FIFA World Cup held?' +# # Call the retriever on a particular query. +# topK_passages = retriever([query]) -# Call the retriever on a particular query. -topK_passages = retriever([query]) +# for idx, passage in enumerate(topK_passages): +# print(f'{idx+1}]', passage, '\n') + +import os +import dspy +os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" +if __name__ == "__main__": + passages = [ + "The quick brown fox jumps over the lazy dog.", + "She sells seashells by the seashore.", + "I am the master of my fate, I am the captain of my soul.", + "To be or not to be, that is the question.", + "All's fair in love and war.", + "A journey of a thousand miles begins with a single step.", + "Two wrongs don't make a right.", + "The pen is mightier than the sword.", + "Actions speak louder than words.", + "Beauty is in the eye of the beholder.", + "Practice makes perfect.", + "Where there's a will, there's a way.", + "When in Rome, do as the Romans do.", + "The early bird catches the worm.", + "You can't judge a book by its cover.", + "A picture is worth a thousand words.", + "Honesty is the best policy.", + "Don't count your chickens before they 
hatch.", + "Every cloud has a silver lining.", + "If at first you don't succeed, try, try again.", + "Look before you leap.", + "Rome wasn't built in a day.", + "The grass is always greener on the other side.", + "Absence makes the heart grow fonder.", + "Actions speak louder than words.", + "Ask and you shall receive.", + "Better late than never.", + "Don't bite the hand that feeds you.", + "Don't put all your eggs in one basket.", + "Easy come, easy go.", + "Every dog has its day.", + "Good things come to those who wait.", + "It's a piece of cake.", + "It's raining cats and dogs.", + "Kill two birds with one stone.", + "Let sleeping dogs lie.", + "Like father, like son.", + "Make hay while the sun shines.", + "Necessity is the mother of invention.", + "Out of sight, out of mind.", + "Patience is a virtue.", + "Practice what you preach.", + "The best things in life are free.", + "The squeaky wheel gets the grease.", + "There's no place like home.", + "Too many cooks spoil the broth.", + "When the going gets tough, the tough get going.", + "You reap what you sow.", + "A watched pot never boils.", + "Actions speak louder than words.", + "An apple a day keeps the doctor away.", + "Beggars can't be choosers.", + "Curiosity killed the cat.", + "Don't cry over spilled milk.", + "Don't put off until tomorrow what you can do today.", + "Every cloud has a silver lining.", + "Fortune favors the bold.", + "If the shoe fits, wear it.", + "It takes two to tango.", + "Keep your friends close and your enemies closer.", + "Let bygones be bygones.", + "No pain, no gain.", + "Once bitten, twice shy.", + "Practice makes perfect.", + "The apple doesn't fall far from the tree.", + "The early bird catches the worm.", + "The grass is always greener on the other side.", + "The more, the merrier.", + "There's no such thing as a free lunch.", + "To kill two birds with one stone.", + "When in Rome, do as the Romans do.", + "You can't have your cake and eat it too.", + "You can't make an 
omelet without breaking eggs.", + "A friend in need is a friend indeed.", + "A penny saved is a penny earned.", + "Actions speak louder than words.", + "Beauty is in the eye of the beholder.", + "Better late than never.", + "Don't count your chickens before they hatch.", + "Don't put all your eggs in one basket.", + "Every cloud has a silver lining.", + "If at first you don't succeed, try, try again.", + "If you can't beat them, join them.", + "Necessity is the mother of invention.", + "One man's trash is another man's treasure.", + "Practice makes perfect.", + "The early bird catches the worm.", + "The grass is always greener on the other side.", + "There's no place like home.", + "Too many cooks spoil the broth.", + "When in Rome, do as the Romans do.", + "You can't judge a book by its cover.", + "You reap what you sow.", + "A bird in the hand is worth two in the bush.", + "A penny for your thoughts.", + "Actions speak louder than words.", + "All good things must come to an end.", + "Beauty is only skin deep.", + "Don't bite the hand that feeds you.", + "Don't put off until tomorrow what you can do today.", + "Every dog has its day.", + "Fortune favors the bold.", + "If you want something done right, do it yourself.", + "It's better to be safe than sorry.", + "Make hay while the sun shines.", + "Necessity is the mother of invention.", + "Out of sight, out of mind.", + "Practice what you preach.", + "The best things in life are free.", + "The early bird catches the worm." 
+] -for idx, passage in enumerate(topK_passages): - print(f'{idx+1}]', passage, '\n') \ No newline at end of file + col = dspy.ColBERTv2Local() + col.build_index(passages=passages) + searcher = col.get_index(passages=passages[:10]) + res = searcher.get_docs(searcher,query="Software",k=5) + print(res) From 1639bd214b234ff15e0b2bc31c6a24108dec04d8 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sun, 7 Apr 2024 22:39:36 -0400 Subject: [PATCH 05/21] colbert local modifications --- dsp/modules/colbertv2.py | 77 +++++++++++++++++++++++++++++----------- rm_test.py | 10 +++--- 2 files changed, 61 insertions(+), 26 deletions(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index e024b6c53..7bb53c874 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -2,10 +2,6 @@ from typing import Any, Optional, Union, List import requests -import colbert -from colbert import Indexer, Searcher -from colbert.infra import Run, RunConfig, ColBERTConfig -from colbert.data import Queries, Collection from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory from dsp.utils import dotdict import os @@ -81,33 +77,72 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" class ColBERTv2Local: - def __init__(self,checkpoint:str='colbert-ir/colbertv2.0'): + def __init__(self,checkpoint:str='colbert-ir/colbertv2.0',passages:List[str]=[],index_name_or_path:str = "Colbert-RM",experiment_name:str="Colbert-Experiment",load_only:bool=False,nranks:int=1,nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8): + self.checkpoint = checkpoint + self.index_name_or_path = index_name_or_path + self.experiment_name = experiment_name + self.nranks = nranks + self.nbits = nbits + self.DOC_MAXLEN = DOC_MAXLEN + self.INDEX_BSIZE = INDEX_BSIZE + self.KMEANS_ITER = KMEANS_ITER + self.passages = passages + + if not load_only: + print(f"Building the index for experiment {self.experiment_name} 
with index name {self.index_name_or_path}") + self.build_index() + + print(f"Loading the index for experiment {self.experiment_name} with index name {self.index_name_or_path}") + self.searcher = self.get_index() + def build_index(self): - def build_index(self,passages:List[str],nranks:int=1,index_name_or_path:str = "Colbert-RM-",nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8,experiment_name:str="Colbert-Experiment"): + try: + import colbert + except ImportError: + print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") - with Run().context(RunConfig(nranks=nranks, experiment=experiment_name)): - config = ColBERTConfig(doc_maxlen=DOC_MAXLEN, nbits=nbits, kmeans_niters=KMEANS_ITER,index_bsize=INDEX_BSIZE) + from colbert import Indexer + from colbert.infra import Run, RunConfig, ColBERTConfig + with Run().context(RunConfig(nranks=self.nranks, experiment=self.experiment_name)): + config = ColBERTConfig(doc_maxlen=self.DOC_MAXLEN, nbits=self.nbits, kmeans_niters=self.KMEANS_ITER,index_bsize=self.INDEX_BSIZE) indexer = Indexer(checkpoint=self.checkpoint, config=config) - indexer.index(name=index_name_or_path, collection=passages, overwrite=True) + indexer.index(name=self.index_name_or_path, collection=self.passages, overwrite=True) - def get_index(self,index_name_or_path:str = "Colbert-RM-",experiment_name:str="Colbert-Experiment",passages:List[str] = []): - with Run().context(RunConfig(experiment=experiment_name)): - searcher = Searcher(index=index_name_or_path, collection=passages) - self.searcher = searcher + def get_index(self): + try: + import colbert + except ImportError: + print("Colbert not found. 
Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") + + from colbert import Searcher + from colbert.infra import Run, RunConfig + + with Run().context(RunConfig(experiment=self.experiment_name)): + searcher = Searcher(index=self.index_name_or_path, collection=self.passages) return searcher - def get_docs(self,searcher:Searcher,query:str,k:int=7): + def __call__(self,query:str,k:int=7,**kwargs): + try: + import colbert + except ImportError: + print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") + import torch - results = searcher.search( - query, - #Number of passages to receive - k=k) - #Passing the filter function of relevant - # filter_fn=lambda pids: torch.tensor( - # [pid for pid in pids if pid in relevant_ids],dtype=torch.int32).to(device)) + if kwargs.get("filtered_pids"): + filtered_pids = kwargs.get("filtered_pids") + assert type(filtered_pids) == List[int], "The filtered pids should be a list of integers" + device = "cuda" if torch.cuda.is_available() else "cpu" + results = self.searcher.search( + query, + #Number of passages to receive + k=k, + #Passing the filter function of relevant + filter_fn=lambda pids: torch.tensor( + [pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device)) + return results \ No newline at end of file diff --git a/rm_test.py b/rm_test.py index 03dbbd0c7..84343bced 100644 --- a/rm_test.py +++ b/rm_test.py @@ -158,8 +158,8 @@ "The early bird catches the worm." 
] - col = dspy.ColBERTv2Local() - col.build_index(passages=passages) - searcher = col.get_index(passages=passages[:10]) - res = searcher.get_docs(searcher,query="Software",k=5) - print(res) + col = dspy.ColBERTv2Local(passages=passages) + + # searcher = col.get_index(passages=passages[:10]) + # res = searcher.get_docs(searcher,query="Software",k=5) + # print(res) From ec062b656ca1114d45c5b45def38a7b049a8c6b2 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sun, 7 Apr 2024 23:45:26 -0400 Subject: [PATCH 06/21] kwargs filtered ids --- dsp/modules/colbertv2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 7bb53c874..f09065b00 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -144,5 +144,6 @@ def __call__(self,query:str,k:int=7,**kwargs): #Passing the filter function of relevant filter_fn=lambda pids: torch.tensor( [pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device)) - + else: + results = self.searcher.search(query, k=k) return results \ No newline at end of file From 987d923b5b5f81595e00d0150567396f6e080304 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Mon, 8 Apr 2024 01:57:42 -0400 Subject: [PATCH 07/21] colbert return --- dsp/modules/colbertv2.py | 23 ++++++++++++----------- dsp/primitives/search.py | 11 ++--------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index f09065b00..7cb798ff9 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -77,7 +77,7 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" class ColBERTv2Local: - def __init__(self,checkpoint:str='colbert-ir/colbertv2.0',passages:List[str]=[],index_name_or_path:str = "Colbert-RM",experiment_name:str="Colbert-Experiment",load_only:bool=False,nranks:int=1,nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8): + 
def __init__(self,checkpoint:str='colbert-ir/colbertv2.0',passages:List[str]=[],index_name_or_path:str = "Colbert-RM",experiment_name:str="Colbert-Experiment",load_only:bool=False,nranks:int=1,nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8,**kwargs): self.checkpoint = checkpoint @@ -92,12 +92,12 @@ def __init__(self,checkpoint:str='colbert-ir/colbertv2.0',passages:List[str]=[], if not load_only: print(f"Building the index for experiment {self.experiment_name} with index name {self.index_name_or_path}") - self.build_index() + self.build_index(**kwargs) print(f"Loading the index for experiment {self.experiment_name} with index name {self.index_name_or_path}") self.searcher = self.get_index() - def build_index(self): + def build_index(self,**kwargs): try: import colbert @@ -107,9 +107,7 @@ def build_index(self): from colbert import Indexer from colbert.infra import Run, RunConfig, ColBERTConfig with Run().context(RunConfig(nranks=self.nranks, experiment=self.experiment_name)): - config = ColBERTConfig(doc_maxlen=self.DOC_MAXLEN, nbits=self.nbits, kmeans_niters=self.KMEANS_ITER,index_bsize=self.INDEX_BSIZE) - - + config = ColBERTConfig(doc_maxlen=self.DOC_MAXLEN, nbits=self.nbits, kmeans_niters=self.KMEANS_ITER,index_bsize=self.INDEX_BSIZE,**kwargs) indexer = Indexer(checkpoint=self.checkpoint, config=config) indexer.index(name=self.index_name_or_path, collection=self.passages, overwrite=True) @@ -127,10 +125,6 @@ def get_index(self): return searcher def __call__(self,query:str,k:int=7,**kwargs): - try: - import colbert - except ImportError: - print("Colbert not found. 
Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") import torch if kwargs.get("filtered_pids"): @@ -146,4 +140,11 @@ def __call__(self,query:str,k:int=7,**kwargs): [pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device)) else: results = self.searcher.search(query, k=k) - return results \ No newline at end of file + passage_ids = [] + passage_score = [] + passages = [] + for pid,_,score in zip(*results): + passage_ids.append(pid) + passage_score.append(score) + passages.append(self.searcher.collection[pid]) + return passage_ids,passage_score,passages \ No newline at end of file diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 42338cc04..0cd92338c 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -37,15 +37,8 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: passages = [] retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) # passages_cs_scores = dsp.settings.reranker(query,k=k,passages=[psg["long_text"] for psg in retrieved_passages]) - passages_cs_scores = dsp.settings.reranker(query,k=k) - # passages_cs_scores = dsp.settings.reranker(query, retrieved_passages) - for idx in np.argsort(passages_cs_scores)[::-1]: - psg = retrieved_passages[idx] - # passages[psg.long_text] = passages.get(psg.long_text, []) + [ - # passages_cs_scores[idx], - # ] - # passages.append((psg,passages_cs_scores[idx])) - passages.append([passages_cs_scores[idx],psg]) + passage_ids,passage_scores,passages = dsp.settings.reranker(query,k=k) + # all_queries_passages.append(passages) # passages = [(np.average(score), psg) for score,psg in passages] From 9ff5b28356b254ee4989d8acea0817a7dfb37758 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Tue, 9 Apr 2024 00:16:57 -0400 Subject: [PATCH 08/21] colbert retriever and reranker --- dsp/modules/__init__.py | 2 +- dsp/modules/colbertv2.py | 50 +++++-- dsp/primitives/search.py | 13 +- dspy/__init__.py | 
3 +- dspy/retrieve/retrieve.py | 24 ++-- rm_test.py | 275 +++++++++++++++++++------------------- 6 files changed, 202 insertions(+), 165 deletions(-) diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index 235dbfcc7..a9940f824 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -4,7 +4,7 @@ from .cache_utils import * from .clarifai import * from .cohere import * -from .colbertv2 import ColBERTv2, ColBERTv2Local +from .colbertv2 import ColBERTv2, ColBERTv2RetrieverLocal,ColBERTv2RerankerLocal from .databricks import * from .google import * from .gpt3 import * diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 7cb798ff9..04c7fad3b 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -76,7 +76,7 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): colbertv2_post_request = colbertv2_post_request_v2_wrapped os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" -class ColBERTv2Local: +class ColBERTv2RetrieverLocal: def __init__(self,checkpoint:str='colbert-ir/colbertv2.0',passages:List[str]=[],index_name_or_path:str = "Colbert-RM",experiment_name:str="Colbert-Experiment",load_only:bool=False,nranks:int=1,nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8,**kwargs): @@ -124,6 +124,7 @@ def get_index(self): searcher = Searcher(index=self.index_name_or_path, collection=self.passages) return searcher + @CacheMemory.cache def __call__(self,query:str,k:int=7,**kwargs): import torch @@ -140,11 +141,44 @@ def __call__(self,query:str,k:int=7,**kwargs): [pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device)) else: results = self.searcher.search(query, k=k) - passage_ids = [] - passage_score = [] - passages = [] + results = [] for pid,_,score in zip(*results): - passage_ids.append(pid) - passage_score.append(score) - passages.append(self.searcher.collection[pid]) - return passage_ids,passage_score,passages \ No newline at end of file + 
results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid})) + return results + +class ColBERTv2RerankerLocal: + try: + import colbert + except ImportError: + print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") + from colbert.infra.config.config import ColBERTConfig + + def __init__(self,checkpoint_name:str='bert-base-uncased',colbert_config:ColBERTConfig=None): + self.colbert_config = colbert_config + self.checkpoint_name = checkpoint_name + self.colbert_config.checkpoint = checkpoint_name + + # def __call__(self, *args: Any, **kwargs: Any) -> Any: + # return self.forward(*args, **kwargs) + + def __call__(self,query:str,passages:List[str]=[]): + from colbert.modeling.tokenization.doc_tokenization import DocTokenizer + from colbert.modeling.tokenization.query_tokenization import QueryTokenizer + from colbert.modeling.colbert import ColBERT + import numpy as np + assert len(passages) > 0, "Passages should not be empty" + self.colbert_config.nway = len(passages) + query_tokenizer = QueryTokenizer(self.colbert_config,verbose=1) + doc_tokenizer = DocTokenizer(self.colbert_config) + query_ids,query_masks = query_tokenizer.tensorize([query]) + doc_ids,doc_masks = doc_tokenizer.tensorize(passages) + + col = ColBERT(self.checkpoint_name,self.colbert_config) + # col.colbert_config.nway = len(passages) + # tensor_scores = col([query_ids,query_masks],[doc_ids,doc_masks]) + Q = col.query(query_ids,query_masks) + DOC_IDS,DOC_MASKS = col.doc(doc_ids,doc_masks,keep_dims='return_mask') + Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous() + tensor_scores = col.score(Q_duplicated,DOC_IDS,DOC_MASKS) + passage_score_arr = np.array([score.cpu().detach().numpy().tolist() for score in tensor_scores]) + return passage_score_arr \ No newline at end of file diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 0cd92338c..12a6dd828 100644 --- 
a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -36,13 +36,12 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: for query in queries: passages = [] retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) - # passages_cs_scores = dsp.settings.reranker(query,k=k,passages=[psg["long_text"] for psg in retrieved_passages]) - passage_ids,passage_scores,passages = dsp.settings.reranker(query,k=k) - - # all_queries_passages.append(passages) - - # passages = [(np.average(score), psg) for score,psg in passages] - all_queries_passages.append(passages[:k]) + passages_cs_scores = dsp.settings.reranker(query,passages=[psg["long_text"] for psg in retrieved_passages]) + for idx in np.argsort(passages_cs_scores)[::-1][:k]: + curr_passage = retrieved_passages[idx] + curr_passage['rerank_score'] = passages_cs_scores[idx] + passages.append(curr_passage) + all_queries_passages.append(passages) if len(queries) == 1: return all_queries_passages[0] else: diff --git a/dspy/__init__.py b/dspy/__init__.py index e63368294..f8d3baf40 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -18,7 +18,8 @@ Databricks = dsp.Databricks Cohere = dsp.Cohere ColBERTv2 = dsp.ColBERTv2 -ColBERTv2Local = dsp.ColBERTv2Local +ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal +ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal Pyserini = dsp.PyseriniRetriever Clarifai = dsp.ClarifaiLLM Google = dsp.Google diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index e719a8357..cd13ca699 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -83,10 +83,10 @@ def load_state(self, state): for name, value in state.items(): setattr(self, name, value) - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) + # def __call__(self, *args, **kwargs): + # return self.forward(*args, **kwargs) - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> 
Union[Prediction,List[Prediction]]: + def __call__(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [query.strip().split('\n')[0].strip() for query in queries] @@ -97,23 +97,21 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No if isinstance(passages[0],List): pred_returns = [] for query_passages in passages: - passages_dict = {key:[] for key in list(query_passages[0].keys()) if key!="tracking_idx"} - for psg in query_passages: - for key,value in psg.items(): - if key == "tracking_idx": continue + passages_dict = {key:[] for key in list(query_passages[0].keys())} + for docs in query_passages: + for key,value in docs.items(): passages_dict[key].append(value) if "long_text" in passages_dict: - passages_dict["passages"] = passages_dict.pop("long_text") + passages_dict["passages"] = passages_dict.pop("long_text") + pred_returns.append(Prediction(**passages_dict)) return pred_returns elif isinstance(passages[0], Dict): - #passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} passages_dict = {key:[] for key in list(passages[0].keys())} - - for psg in passages: - for key,value in psg.items(): + for docs in passages: + for key,value in docs.items(): passages_dict[key].append(value) if "long_text" in passages_dict: passages_dict["passages"] = passages_dict.pop("long_text") return Prediction(**passages_dict) - \ No newline at end of file + \ No newline at end of file diff --git a/rm_test.py b/rm_test.py index 84343bced..f2364d655 100644 --- a/rm_test.py +++ b/rm_test.py @@ -1,24 +1,29 @@ -# import dspy.retrieve -# from dspy.retrieve.chromadb_rm import ChromadbRM -# from dotenv import load_dotenv -# import chromadb.utils.embedding_functions as embedding_functions -# from chromadb.utils.batch_utils import create_batches -# import os -# import dspy 
+import dspy.retrieve +from dspy.retrieve.chromadb_rm import ChromadbRM +from dotenv import load_dotenv +import chromadb.utils.embedding_functions as embedding_functions +from chromadb.utils.batch_utils import create_batches +import os +import dspy +from colbert.infra.config.config import ColBERTConfig + +load_dotenv(override=True) +emb_fn = embedding_functions.OpenAIEmbeddingFunction( + api_key=os.environ['OPENAI_API_KEY'], + model_name="text-embedding-3-small") +crm = ChromadbRM( + collection_name="rows", + persist_directory="/home/athekunal/DSPy-contributions/Text-to-SQL/India_TABLE", + embedding_function=emb_fn +) +reranker = dspy.ColBERTv2RerankerLocal(checkpoint_name='colbert-ir/colbertv2.0',colbert_config=ColBERTConfig()) +dspy.settings.configure(rm=crm,reranker=reranker) -# load_dotenv(override=True) -# emb_fn = embedding_functions.OpenAIEmbeddingFunction( -# api_key=os.environ['OPENAI_API_KEY'], -# model_name="text-embedding-3-small") +ret_rerank = dspy.RetrieveThenRerank(k=3) -# crm = ChromadbRM( -# collection_name="rows", -# persist_directory="/home/athekunal/DSPy-contributions/Text-to-SQL/India_TABLE", -# embedding_function=emb_fn -# ) -# # reranker = dspy. 
-# dspy.settings.configure(rm=crm) +print(ret_rerank(["Software Internet","Packaging and Container"],k=3)) +print(ret_rerank(["Software Internet"],k=3)) # retriever = dspy.Retrieve(k=2) @@ -41,124 +46,124 @@ # for idx, passage in enumerate(topK_passages): # print(f'{idx+1}]', passage, '\n') -import os -import dspy -os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" -if __name__ == "__main__": - passages = [ - "The quick brown fox jumps over the lazy dog.", - "She sells seashells by the seashore.", - "I am the master of my fate, I am the captain of my soul.", - "To be or not to be, that is the question.", - "All's fair in love and war.", - "A journey of a thousand miles begins with a single step.", - "Two wrongs don't make a right.", - "The pen is mightier than the sword.", - "Actions speak louder than words.", - "Beauty is in the eye of the beholder.", - "Practice makes perfect.", - "Where there's a will, there's a way.", - "When in Rome, do as the Romans do.", - "The early bird catches the worm.", - "You can't judge a book by its cover.", - "A picture is worth a thousand words.", - "Honesty is the best policy.", - "Don't count your chickens before they hatch.", - "Every cloud has a silver lining.", - "If at first you don't succeed, try, try again.", - "Look before you leap.", - "Rome wasn't built in a day.", - "The grass is always greener on the other side.", - "Absence makes the heart grow fonder.", - "Actions speak louder than words.", - "Ask and you shall receive.", - "Better late than never.", - "Don't bite the hand that feeds you.", - "Don't put all your eggs in one basket.", - "Easy come, easy go.", - "Every dog has its day.", - "Good things come to those who wait.", - "It's a piece of cake.", - "It's raining cats and dogs.", - "Kill two birds with one stone.", - "Let sleeping dogs lie.", - "Like father, like son.", - "Make hay while the sun shines.", - "Necessity is the mother of invention.", - "Out of sight, out of mind.", - "Patience is a 
virtue.", - "Practice what you preach.", - "The best things in life are free.", - "The squeaky wheel gets the grease.", - "There's no place like home.", - "Too many cooks spoil the broth.", - "When the going gets tough, the tough get going.", - "You reap what you sow.", - "A watched pot never boils.", - "Actions speak louder than words.", - "An apple a day keeps the doctor away.", - "Beggars can't be choosers.", - "Curiosity killed the cat.", - "Don't cry over spilled milk.", - "Don't put off until tomorrow what you can do today.", - "Every cloud has a silver lining.", - "Fortune favors the bold.", - "If the shoe fits, wear it.", - "It takes two to tango.", - "Keep your friends close and your enemies closer.", - "Let bygones be bygones.", - "No pain, no gain.", - "Once bitten, twice shy.", - "Practice makes perfect.", - "The apple doesn't fall far from the tree.", - "The early bird catches the worm.", - "The grass is always greener on the other side.", - "The more, the merrier.", - "There's no such thing as a free lunch.", - "To kill two birds with one stone.", - "When in Rome, do as the Romans do.", - "You can't have your cake and eat it too.", - "You can't make an omelet without breaking eggs.", - "A friend in need is a friend indeed.", - "A penny saved is a penny earned.", - "Actions speak louder than words.", - "Beauty is in the eye of the beholder.", - "Better late than never.", - "Don't count your chickens before they hatch.", - "Don't put all your eggs in one basket.", - "Every cloud has a silver lining.", - "If at first you don't succeed, try, try again.", - "If you can't beat them, join them.", - "Necessity is the mother of invention.", - "One man's trash is another man's treasure.", - "Practice makes perfect.", - "The early bird catches the worm.", - "The grass is always greener on the other side.", - "There's no place like home.", - "Too many cooks spoil the broth.", - "When in Rome, do as the Romans do.", - "You can't judge a book by its cover.", - "You 
reap what you sow.", - "A bird in the hand is worth two in the bush.", - "A penny for your thoughts.", - "Actions speak louder than words.", - "All good things must come to an end.", - "Beauty is only skin deep.", - "Don't bite the hand that feeds you.", - "Don't put off until tomorrow what you can do today.", - "Every dog has its day.", - "Fortune favors the bold.", - "If you want something done right, do it yourself.", - "It's better to be safe than sorry.", - "Make hay while the sun shines.", - "Necessity is the mother of invention.", - "Out of sight, out of mind.", - "Practice what you preach.", - "The best things in life are free.", - "The early bird catches the worm." -] +# import os +# import dspy +# os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" +# if __name__ == "__main__": +# passages = [ +# "The quick brown fox jumps over the lazy dog.", +# "She sells seashells by the seashore.", +# "I am the master of my fate, I am the captain of my soul.", +# "To be or not to be, that is the question.", +# "All's fair in love and war.", +# "A journey of a thousand miles begins with a single step.", +# "Two wrongs don't make a right.", +# "The pen is mightier than the sword.", +# "Actions speak louder than words.", +# "Beauty is in the eye of the beholder.", +# "Practice makes perfect.", +# "Where there's a will, there's a way.", +# "When in Rome, do as the Romans do.", +# "The early bird catches the worm.", +# "You can't judge a book by its cover.", +# "A picture is worth a thousand words.", +# "Honesty is the best policy.", +# "Don't count your chickens before they hatch.", +# "Every cloud has a silver lining.", +# "If at first you don't succeed, try, try again.", +# "Look before you leap.", +# "Rome wasn't built in a day.", +# "The grass is always greener on the other side.", +# "Absence makes the heart grow fonder.", +# "Actions speak louder than words.", +# "Ask and you shall receive.", +# "Better late than never.", +# "Don't bite the hand that feeds 
you.", +# "Don't put all your eggs in one basket.", +# "Easy come, easy go.", +# "Every dog has its day.", +# "Good things come to those who wait.", +# "It's a piece of cake.", +# "It's raining cats and dogs.", +# "Kill two birds with one stone.", +# "Let sleeping dogs lie.", +# "Like father, like son.", +# "Make hay while the sun shines.", +# "Necessity is the mother of invention.", +# "Out of sight, out of mind.", +# "Patience is a virtue.", +# "Practice what you preach.", +# "The best things in life are free.", +# "The squeaky wheel gets the grease.", +# "There's no place like home.", +# "Too many cooks spoil the broth.", +# "When the going gets tough, the tough get going.", +# "You reap what you sow.", +# "A watched pot never boils.", +# "Actions speak louder than words.", +# "An apple a day keeps the doctor away.", +# "Beggars can't be choosers.", +# "Curiosity killed the cat.", +# "Don't cry over spilled milk.", +# "Don't put off until tomorrow what you can do today.", +# "Every cloud has a silver lining.", +# "Fortune favors the bold.", +# "If the shoe fits, wear it.", +# "It takes two to tango.", +# "Keep your friends close and your enemies closer.", +# "Let bygones be bygones.", +# "No pain, no gain.", +# "Once bitten, twice shy.", +# "Practice makes perfect.", +# "The apple doesn't fall far from the tree.", +# "The early bird catches the worm.", +# "The grass is always greener on the other side.", +# "The more, the merrier.", +# "There's no such thing as a free lunch.", +# "To kill two birds with one stone.", +# "When in Rome, do as the Romans do.", +# "You can't have your cake and eat it too.", +# "You can't make an omelet without breaking eggs.", +# "A friend in need is a friend indeed.", +# "A penny saved is a penny earned.", +# "Actions speak louder than words.", +# "Beauty is in the eye of the beholder.", +# "Better late than never.", +# "Don't count your chickens before they hatch.", +# "Don't put all your eggs in one basket.", +# "Every cloud has a 
silver lining.", +# "If at first you don't succeed, try, try again.", +# "If you can't beat them, join them.", +# "Necessity is the mother of invention.", +# "One man's trash is another man's treasure.", +# "Practice makes perfect.", +# "The early bird catches the worm.", +# "The grass is always greener on the other side.", +# "There's no place like home.", +# "Too many cooks spoil the broth.", +# "When in Rome, do as the Romans do.", +# "You can't judge a book by its cover.", +# "You reap what you sow.", +# "A bird in the hand is worth two in the bush.", +# "A penny for your thoughts.", +# "Actions speak louder than words.", +# "All good things must come to an end.", +# "Beauty is only skin deep.", +# "Don't bite the hand that feeds you.", +# "Don't put off until tomorrow what you can do today.", +# "Every dog has its day.", +# "Fortune favors the bold.", +# "If you want something done right, do it yourself.", +# "It's better to be safe than sorry.", +# "Make hay while the sun shines.", +# "Necessity is the mother of invention.", +# "Out of sight, out of mind.", +# "Practice what you preach.", +# "The best things in life are free.", +# "The early bird catches the worm." 
+# ] - col = dspy.ColBERTv2Local(passages=passages) +# col = dspy.ColBERTv2Local(passages=passages) # searcher = col.get_index(passages=passages[:10]) # res = searcher.get_docs(searcher,query="Software",k=5) From 825a272aeba8fda71056180dc13900388355776e Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Tue, 9 Apr 2024 19:46:40 -0400 Subject: [PATCH 09/21] colbert retriever error fixes --- dsp/modules/colbertv2.py | 5 +- rm_test.py | 170 --------------------------------------- 2 files changed, 2 insertions(+), 173 deletions(-) delete mode 100644 rm_test.py diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 04c7fad3b..9700a2d48 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -124,7 +124,6 @@ def get_index(self): searcher = Searcher(index=self.index_name_or_path, collection=self.passages) return searcher - @CacheMemory.cache def __call__(self,query:str,k:int=7,**kwargs): import torch @@ -140,9 +139,9 @@ def __call__(self,query:str,k:int=7,**kwargs): filter_fn=lambda pids: torch.tensor( [pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device)) else: - results = self.searcher.search(query, k=k) + searcher_results = self.searcher.search(query, k=k) results = [] - for pid,_,score in zip(*results): + for pid,_,score in zip(*searcher_results): results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid})) return results diff --git a/rm_test.py b/rm_test.py deleted file mode 100644 index f2364d655..000000000 --- a/rm_test.py +++ /dev/null @@ -1,170 +0,0 @@ -import dspy.retrieve -from dspy.retrieve.chromadb_rm import ChromadbRM -from dotenv import load_dotenv -import chromadb.utils.embedding_functions as embedding_functions -from chromadb.utils.batch_utils import create_batches -import os -import dspy -from colbert.infra.config.config import ColBERTConfig - -load_dotenv(override=True) -emb_fn = embedding_functions.OpenAIEmbeddingFunction( - api_key=os.environ['OPENAI_API_KEY'], - 
model_name="text-embedding-3-small") - -crm = ChromadbRM( - collection_name="rows", - persist_directory="/home/athekunal/DSPy-contributions/Text-to-SQL/India_TABLE", - embedding_function=emb_fn -) -reranker = dspy.ColBERTv2RerankerLocal(checkpoint_name='colbert-ir/colbertv2.0',colbert_config=ColBERTConfig()) -dspy.settings.configure(rm=crm,reranker=reranker) - -ret_rerank = dspy.RetrieveThenRerank(k=3) - -print(ret_rerank(["Software Internet","Packaging and Container"],k=3)) -print(ret_rerank(["Software Internet"],k=3)) - -# retriever = dspy.Retrieve(k=2) - -# print(retriever(["Software Internet"],by_prob=False,where={"table_name":"capexIndia"})) -# print("-"*100) -# print(retriever(["Software Internet","Packaging"],by_prob=False,where={"table_name":"capexIndia"})) -# import dspy - -# colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts') -# dspy.settings.configure(rm=colbertv2_wiki17_abstracts,reranker=colbertv2_wiki17_abstracts) - -# #Define Retrieve Module -# retriever = dspy.RetrieveThenRerank(k=3) - -# query='When was the first FIFA World Cup held?' - -# # Call the retriever on a particular query. 
-# topK_passages = retriever([query]) - -# for idx, passage in enumerate(topK_passages): -# print(f'{idx+1}]', passage, '\n') - -# import os -# import dspy -# os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" -# if __name__ == "__main__": -# passages = [ -# "The quick brown fox jumps over the lazy dog.", -# "She sells seashells by the seashore.", -# "I am the master of my fate, I am the captain of my soul.", -# "To be or not to be, that is the question.", -# "All's fair in love and war.", -# "A journey of a thousand miles begins with a single step.", -# "Two wrongs don't make a right.", -# "The pen is mightier than the sword.", -# "Actions speak louder than words.", -# "Beauty is in the eye of the beholder.", -# "Practice makes perfect.", -# "Where there's a will, there's a way.", -# "When in Rome, do as the Romans do.", -# "The early bird catches the worm.", -# "You can't judge a book by its cover.", -# "A picture is worth a thousand words.", -# "Honesty is the best policy.", -# "Don't count your chickens before they hatch.", -# "Every cloud has a silver lining.", -# "If at first you don't succeed, try, try again.", -# "Look before you leap.", -# "Rome wasn't built in a day.", -# "The grass is always greener on the other side.", -# "Absence makes the heart grow fonder.", -# "Actions speak louder than words.", -# "Ask and you shall receive.", -# "Better late than never.", -# "Don't bite the hand that feeds you.", -# "Don't put all your eggs in one basket.", -# "Easy come, easy go.", -# "Every dog has its day.", -# "Good things come to those who wait.", -# "It's a piece of cake.", -# "It's raining cats and dogs.", -# "Kill two birds with one stone.", -# "Let sleeping dogs lie.", -# "Like father, like son.", -# "Make hay while the sun shines.", -# "Necessity is the mother of invention.", -# "Out of sight, out of mind.", -# "Patience is a virtue.", -# "Practice what you preach.", -# "The best things in life are free.", -# "The squeaky wheel gets the 
grease.", -# "There's no place like home.", -# "Too many cooks spoil the broth.", -# "When the going gets tough, the tough get going.", -# "You reap what you sow.", -# "A watched pot never boils.", -# "Actions speak louder than words.", -# "An apple a day keeps the doctor away.", -# "Beggars can't be choosers.", -# "Curiosity killed the cat.", -# "Don't cry over spilled milk.", -# "Don't put off until tomorrow what you can do today.", -# "Every cloud has a silver lining.", -# "Fortune favors the bold.", -# "If the shoe fits, wear it.", -# "It takes two to tango.", -# "Keep your friends close and your enemies closer.", -# "Let bygones be bygones.", -# "No pain, no gain.", -# "Once bitten, twice shy.", -# "Practice makes perfect.", -# "The apple doesn't fall far from the tree.", -# "The early bird catches the worm.", -# "The grass is always greener on the other side.", -# "The more, the merrier.", -# "There's no such thing as a free lunch.", -# "To kill two birds with one stone.", -# "When in Rome, do as the Romans do.", -# "You can't have your cake and eat it too.", -# "You can't make an omelet without breaking eggs.", -# "A friend in need is a friend indeed.", -# "A penny saved is a penny earned.", -# "Actions speak louder than words.", -# "Beauty is in the eye of the beholder.", -# "Better late than never.", -# "Don't count your chickens before they hatch.", -# "Don't put all your eggs in one basket.", -# "Every cloud has a silver lining.", -# "If at first you don't succeed, try, try again.", -# "If you can't beat them, join them.", -# "Necessity is the mother of invention.", -# "One man's trash is another man's treasure.", -# "Practice makes perfect.", -# "The early bird catches the worm.", -# "The grass is always greener on the other side.", -# "There's no place like home.", -# "Too many cooks spoil the broth.", -# "When in Rome, do as the Romans do.", -# "You can't judge a book by its cover.", -# "You reap what you sow.", -# "A bird in the hand is worth two in 
the bush.", -# "A penny for your thoughts.", -# "Actions speak louder than words.", -# "All good things must come to an end.", -# "Beauty is only skin deep.", -# "Don't bite the hand that feeds you.", -# "Don't put off until tomorrow what you can do today.", -# "Every dog has its day.", -# "Fortune favors the bold.", -# "If you want something done right, do it yourself.", -# "It's better to be safe than sorry.", -# "Make hay while the sun shines.", -# "Necessity is the mother of invention.", -# "Out of sight, out of mind.", -# "Practice what you preach.", -# "The best things in life are free.", -# "The early bird catches the worm." -# ] - -# col = dspy.ColBERTv2Local(passages=passages) - - # searcher = col.get_index(passages=passages[:10]) - # res = searcher.get_docs(searcher,query="Software",k=5) - # print(res) From c25e9c44ed3202e6b770bc07225d5c5ed39fa5ef Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Tue, 9 Apr 2024 20:14:38 -0400 Subject: [PATCH 10/21] colbert config changes in __init__ --- dsp/modules/colbertv2.py | 42 +++--- .../integrations/colbert/colbert_local.ipynb | 127 ++++++++++++++++++ 2 files changed, 149 insertions(+), 20 deletions(-) create mode 100644 examples/integrations/colbert/colbert_local.ipynb diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 9700a2d48..25289cc43 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -77,27 +77,30 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" class ColBERTv2RetrieverLocal: - def __init__(self,checkpoint:str='colbert-ir/colbertv2.0',passages:List[str]=[],index_name_or_path:str = "Colbert-RM",experiment_name:str="Colbert-Experiment",load_only:bool=False,nranks:int=1,nbits:int=2,DOC_MAXLEN:int=300,INDEX_BSIZE:int=256,KMEANS_ITER:int=8,**kwargs): - - + from colbert.infra import Run, RunConfig, ColBERTConfig + def 
__init__(self,passages:List[str],load_only:bool=False,checkpoint:str='colbert-ir/colbertv2.0',colbert_config:ColBERTConfig=ColBERTConfig()): + """Colbertv2 retriever module + + Args: + passages (List[str]): list of passages + load_only (bool, optional): whether to load the index or . Defaults to False. + checkpoint (str, optional): checkpoint for generating embeddings. Defaults to 'colbert-ir/colbertv2.0'. + colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to ColBERTConfig(). + """ self.checkpoint = checkpoint - self.index_name_or_path = index_name_or_path - self.experiment_name = experiment_name - self.nranks = nranks - self.nbits = nbits - self.DOC_MAXLEN = DOC_MAXLEN - self.INDEX_BSIZE = INDEX_BSIZE - self.KMEANS_ITER = KMEANS_ITER + self.colbert_config = colbert_config + self.checkpoint = checkpoint + self.colbert_config.checkpoint = checkpoint self.passages = passages if not load_only: - print(f"Building the index for experiment {self.experiment_name} with index name {self.index_name_or_path}") - self.build_index(**kwargs) + print(f"Building the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}") + self.build_index() - print(f"Loading the index for experiment {self.experiment_name} with index name {self.index_name_or_path}") + print(f"Loading the index for experiment {self.experiment} with index name {self.index_name}") self.searcher = self.get_index() - def build_index(self,**kwargs): + def build_index(self): try: import colbert @@ -105,11 +108,10 @@ def build_index(self,**kwargs): print("Colbert not found. 
Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") from colbert import Indexer - from colbert.infra import Run, RunConfig, ColBERTConfig - with Run().context(RunConfig(nranks=self.nranks, experiment=self.experiment_name)): - config = ColBERTConfig(doc_maxlen=self.DOC_MAXLEN, nbits=self.nbits, kmeans_niters=self.KMEANS_ITER,index_bsize=self.INDEX_BSIZE,**kwargs) - indexer = Indexer(checkpoint=self.checkpoint, config=config) - indexer.index(name=self.index_name_or_path, collection=self.passages, overwrite=True) + from colbert.infra import Run, RunConfig + with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)): + indexer = Indexer(checkpoint=self.checkpoint, config=self.colbert_config) + indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True) def get_index(self): try: @@ -152,7 +154,7 @@ class ColBERTv2RerankerLocal: print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") from colbert.infra.config.config import ColBERTConfig - def __init__(self,checkpoint_name:str='bert-base-uncased',colbert_config:ColBERTConfig=None): + def __init__(self,checkpoint_name:str='bert-base-uncased',colbert_config:ColBERTConfig=ColBERTConfig()): self.colbert_config = colbert_config self.checkpoint_name = checkpoint_name self.colbert_config.checkpoint = checkpoint_name diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb new file mode 100644 index 000000000..060ec7a61 --- /dev/null +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', 
ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, similarity='cosine', bsize=32, accumsteps=1, lr=3e-06, maxsteps=500000, save_every=None, warmup=None, warmup_bert=None, relu=False, nway=2, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name=None, query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=128, doc_maxlen=220, mask_punctuation=True, checkpoint=None, triples=None, collection=None, queries=None, index_name=None, overwrite=False, root='/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments', experiment='default', index_root=None, name='2024-04/09/19.53.42', rank=0, nranks=1, amp=True, gpus=1, avoid_fork_if_possible=False)\n" + ] + } + ], + "source": [ + "from colbert.infra.config import ColBERTConfig\n", + "\n", + "print(ColBERTConfig())" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "query_token_id --> [unused0]\n", + "doc_token_id --> [unused1]\n", + "query_token --> [Q]\n", + "doc_token --> [D]\n", + "ncells --> None\n", + "centroid_score_threshold --> None\n", + "ndocs --> None\n", + "load_index_with_mmap --> False\n", + "index_path --> None\n", + "index_bsize --> 64\n", + "nbits --> 1\n", + "kmeans_niters --> 4\n", + "resume --> False\n", + "similarity --> cosine\n", + "bsize --> 32\n", + "accumsteps --> 1\n", + "lr --> 3e-06\n", + "maxsteps --> 500000\n", + "save_every --> None\n", + "warmup --> None\n", + "warmup_bert --> None\n", + "relu --> False\n", + "nway --> 2\n", + "use_ib_negatives --> False\n", + "reranker --> False\n", + "distillation_alpha --> 1.0\n", + "ignore_scores --> False\n", + "model_name --> None\n", + "query_maxlen --> 32\n", + "attend_to_mask_tokens --> False\n", + "interaction --> colbert\n", + "dim --> 128\n", + "doc_maxlen 
--> 220\n", + "mask_punctuation --> True\n", + "checkpoint --> None\n", + "triples --> None\n", + "collection --> None\n", + "queries --> None\n", + "index_name --> None\n", + "overwrite --> False\n", + "root --> /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments\n", + "experiment --> default\n", + "index_root --> None\n", + "name --> 2024-04/09/19.53.42\n", + "rank --> 0\n", + "nranks --> 1\n", + "amp --> True\n", + "gpus --> 1\n", + "avoid_fork_if_possible --> False\n", + "assigned --> {}\n" + ] + } + ], + "source": [ + "for k,v in ColBERTConfig().__dict__.items():\n", + " print(f\"{k} --> {v}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "passages = [\"It's a piece of cake.\", \"Don't put off until tomorrow what you can do today.\", 'To kill two birds with one stone.', 'Actions speak louder than words.', 'Honesty is the best policy.', 'If you want something done right, do it yourself.', 'The best things in life are free.', \"Don't count your chickens before they hatch.\", 'She sells seashells by the seashore.', 'Practice makes perfect.', \"Where there's a will, there's a way.\", 'Absence makes the heart grow fonder.', 'When the going gets tough, the tough get going.', 'A journey of a thousand miles begins with a single step.', \"You can't have your cake and eat it too.\", \"If you can't beat them, join them.\", 'Keep your friends close and your enemies closer.', \"Don't put all your eggs in one basket.\", \"All's fair in love and war.\", 'Every dog has its day.', 'All good things must come to an end.', 'Once bitten, twice shy.', \"The apple doesn't fall far from the tree.\", 'A penny saved is a penny earned.', \"Don't bite the hand that feeds you.\", 'You reap what you sow.', 'An apple a day keeps the doctor away.', \"One man's trash is another man's treasure.\", 'The squeaky wheel gets the grease.', 'A picture is worth a thousand words.', 'Fortune favors the 
bold.', 'Practice what you preach.', 'A watched pot never boils.', 'No pain, no gain.', \"You can't make an omelet without breaking eggs.\", \"There's no place like home.\", 'Ask and you shall receive.', 'Let sleeping dogs lie.', 'If the shoe fits, wear it.', 'Every cloud has a silver lining.', 'Look before you leap.', 'The more, the merrier.', 'The grass is always greener on the other side.', 'Beauty is only skin deep.', \"Two wrongs don't make a right.\", 'Beauty is in the eye of the beholder.', 'Necessity is the mother of invention.', 'Out of sight, out of mind.', 'Patience is a virtue.', 'Curiosity killed the cat.', \"If at first you don't succeed, try, try again.\", \"Beggars can't be choosers.\", 'Too many cooks spoil the broth.', 'Easy come, easy go.', \"Don't cry over spilled milk.\", \"There's no such thing as a free lunch.\", 'A bird in the hand is worth two in the bush.', 'Good things come to those who wait.', 'The quick brown fox jumps over the lazy dog.', 'It takes two to tango.', 'A friend in need is a friend indeed.', 'Like father, like son.', 'Let bygones be bygones.', 'Kill two birds with one stone.', 'A penny for your thoughts.', 'I am the master of my fate, I am the captain of my soul.', 'The pen is mightier than the sword.', 'When in Rome, do as the Romans do.', \"Rome wasn't built in a day.\", \"You can't judge a book by its cover.\", \"It's raining cats and dogs.\", 'Make hay while the sun shines.', \"It's better to be safe than sorry.\", 'The early bird catches the worm.', 'To be or not to be, that is the question.', 'Better late than never.']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + 
"nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From ab5b12ecb8cf1d0e01a0d89371912976ccc50f3b Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Tue, 9 Apr 2024 21:01:16 -0400 Subject: [PATCH 11/21] colbert notebook --- dsp/modules/colbertv2.py | 24 +- .../integrations/colbert/colbert_local.ipynb | 631 +++++++++++++++++- 2 files changed, 643 insertions(+), 12 deletions(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 25289cc43..c5b0d7968 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -78,17 +78,19 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): class ColBERTv2RetrieverLocal: from colbert.infra import Run, RunConfig, ColBERTConfig - def __init__(self,passages:List[str],load_only:bool=False,checkpoint:str='colbert-ir/colbertv2.0',colbert_config:ColBERTConfig=ColBERTConfig()): + def __init__(self,passages:List[str],load_only:bool=False,index_name:str="colbert_rm",checkpoint:str='colbert-ir/colbertv2.0',colbert_config:ColBERTConfig=ColBERTConfig()): """Colbertv2 retriever module Args: passages (List[str]): list of passages load_only (bool, optional): whether to load the index or . Defaults to False. + index_name (str, optional): name of the index. Defaults to "colbert_rm". checkpoint (str, optional): checkpoint for generating embeddings. Defaults to 'colbert-ir/colbertv2.0'. colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to ColBERTConfig(). 
""" self.checkpoint = checkpoint self.colbert_config = colbert_config + self.colbert_config.index_name = index_name self.checkpoint = checkpoint self.colbert_config.checkpoint = checkpoint self.passages = passages @@ -97,7 +99,7 @@ def __init__(self,passages:List[str],load_only:bool=False,checkpoint:str='colber print(f"Building the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}") self.build_index() - print(f"Loading the index for experiment {self.experiment} with index name {self.index_name}") + print(f"Loading the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}") self.searcher = self.get_index() def build_index(self): @@ -122,8 +124,8 @@ def get_index(self): from colbert import Searcher from colbert.infra import Run, RunConfig - with Run().context(RunConfig(experiment=self.experiment_name)): - searcher = Searcher(index=self.index_name_or_path, collection=self.passages) + with Run().context(RunConfig(experiment=self.colbert_config.experiment)): + searcher = Searcher(index=self.colbert_config.index_name, collection=self.passages) return searcher def __call__(self,query:str,k:int=7,**kwargs): @@ -143,7 +145,7 @@ def __call__(self,query:str,k:int=7,**kwargs): else: searcher_results = self.searcher.search(query, k=k) results = [] - for pid,_,score in zip(*searcher_results): + for pid,rank,score in zip(*searcher_results): results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid})) return results @@ -154,10 +156,16 @@ class ColBERTv2RerankerLocal: print("Colbert not found. 
Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") from colbert.infra.config.config import ColBERTConfig - def __init__(self,checkpoint_name:str='bert-base-uncased',colbert_config:ColBERTConfig=ColBERTConfig()): + def __init__(self,checkpoint:str='bert-base-uncased',colbert_config:ColBERTConfig=ColBERTConfig()): + """_summary_ + + Args: + checkpoint_name (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'. + colbert_config (ColBERTConfig, optional): Colbert config. Defaults to ColBERTConfig(). + """ self.colbert_config = colbert_config - self.checkpoint_name = checkpoint_name - self.colbert_config.checkpoint = checkpoint_name + self.checkpoint_name = checkpoint + self.colbert_config.checkpoint = checkpoint # def __call__(self, *args: Any, **kwargs: Any) -> Any: # return self.forward(*args, **kwargs) diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb index 060ec7a61..157cd6ca9 100644 --- a/examples/integrations/colbert/colbert_local.ipynb +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -1,15 +1,37 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## IN THIS NOTEBOOK, WE WILL EXPLORE THE COLBERT AS A RERANKER AND RETRIEVER IN LOCAL MODE. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* If you want to build a server from your colbert local index, please refer [here](https://github.com/stanford-futuredata/ColBERT/blob/main/server.py)" + ] + }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/athekunal/DSPy-contributions/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ - "ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, similarity='cosine', bsize=32, accumsteps=1, lr=3e-06, maxsteps=500000, save_every=None, warmup=None, warmup_bert=None, relu=False, nway=2, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name=None, query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=128, doc_maxlen=220, mask_punctuation=True, checkpoint=None, triples=None, collection=None, queries=None, index_name=None, overwrite=False, root='/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments', experiment='default', index_root=None, name='2024-04/09/19.53.42', rank=0, nranks=1, amp=True, gpus=1, avoid_fork_if_possible=False)\n" + "ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, similarity='cosine', bsize=32, accumsteps=1, lr=3e-06, maxsteps=500000, save_every=None, warmup=None, warmup_bert=None, relu=False, nway=2, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name=None, query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=128, doc_maxlen=220, mask_punctuation=True, checkpoint=None, triples=None, collection=None, queries=None, index_name=None, overwrite=False, root='/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments', experiment='default', index_root=None, 
name='2024-04/09/20.41.14', rank=0, nranks=1, amp=True, gpus=1, avoid_fork_if_possible=False)\n" ] } ], @@ -19,9 +41,16 @@ "print(ColBERTConfig())" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Let's review the colbert config class" + ] + }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -71,7 +100,7 @@ "root --> /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments\n", "experiment --> default\n", "index_root --> None\n", - "name --> 2024-04/09/19.53.42\n", + "name --> 2024-04/09/20.41.14\n", "rank --> 0\n", "nranks --> 1\n", "amp --> True\n", @@ -88,13 +117,607 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "passages = [\"It's a piece of cake.\", \"Don't put off until tomorrow what you can do today.\", 'To kill two birds with one stone.', 'Actions speak louder than words.', 'Honesty is the best policy.', 'If you want something done right, do it yourself.', 'The best things in life are free.', \"Don't count your chickens before they hatch.\", 'She sells seashells by the seashore.', 'Practice makes perfect.', \"Where there's a will, there's a way.\", 'Absence makes the heart grow fonder.', 'When the going gets tough, the tough get going.', 'A journey of a thousand miles begins with a single step.', \"You can't have your cake and eat it too.\", \"If you can't beat them, join them.\", 'Keep your friends close and your enemies closer.', \"Don't put all your eggs in one basket.\", \"All's fair in love and war.\", 'Every dog has its day.', 'All good things must come to an end.', 'Once bitten, twice shy.', \"The apple doesn't fall far from the tree.\", 'A penny saved is a penny earned.', \"Don't bite the hand that feeds you.\", 'You reap what you sow.', 'An apple a day keeps the doctor away.', \"One man's trash is another man's treasure.\", 'The squeaky wheel gets the grease.', 'A picture 
is worth a thousand words.', 'Fortune favors the bold.', 'Practice what you preach.', 'A watched pot never boils.', 'No pain, no gain.', \"You can't make an omelet without breaking eggs.\", \"There's no place like home.\", 'Ask and you shall receive.', 'Let sleeping dogs lie.', 'If the shoe fits, wear it.', 'Every cloud has a silver lining.', 'Look before you leap.', 'The more, the merrier.', 'The grass is always greener on the other side.', 'Beauty is only skin deep.', \"Two wrongs don't make a right.\", 'Beauty is in the eye of the beholder.', 'Necessity is the mother of invention.', 'Out of sight, out of mind.', 'Patience is a virtue.', 'Curiosity killed the cat.', \"If at first you don't succeed, try, try again.\", \"Beggars can't be choosers.\", 'Too many cooks spoil the broth.', 'Easy come, easy go.', \"Don't cry over spilled milk.\", \"There's no such thing as a free lunch.\", 'A bird in the hand is worth two in the bush.', 'Good things come to those who wait.', 'The quick brown fox jumps over the lazy dog.', 'It takes two to tango.', 'A friend in need is a friend indeed.', 'Like father, like son.', 'Let bygones be bygones.', 'Kill two birds with one stone.', 'A penny for your thoughts.', 'I am the master of my fate, I am the captain of my soul.', 'The pen is mightier than the sword.', 'When in Rome, do as the Romans do.', \"Rome wasn't built in a day.\", \"You can't judge a book by its cover.\", \"It's raining cats and dogs.\", 'Make hay while the sun shines.', \"It's better to be safe than sorry.\", 'The early bird catches the worm.', 'To be or not to be, that is the question.', 'Better late than never.']" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## This tutorial is running from the examples/integrations/tutorials folder, hence we need to add the system path for dspy" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../../..\")" 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## COLBERT AS RETRIEVER" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Building the index for experiment default with index name colbert-ir-index\n", + "\n", + "\n", + "[Apr 09, 20:41:33] #> Creating directory /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index \n", + "\n", + "\n", + "#> Starting...\n", + "nranks = 1 \t num_gpus = 1 \t device=0\n", + "{\n", + " \"query_token_id\": \"[unused0]\",\n", + " \"doc_token_id\": \"[unused1]\",\n", + " \"query_token\": \"[Q]\",\n", + " \"doc_token\": \"[D]\",\n", + " \"ncells\": null,\n", + " \"centroid_score_threshold\": null,\n", + " \"ndocs\": null,\n", + " \"load_index_with_mmap\": false,\n", + " \"index_path\": null,\n", + " \"index_bsize\": 64,\n", + " \"nbits\": 1,\n", + " \"kmeans_niters\": 20,\n", + " \"resume\": false,\n", + " \"similarity\": \"cosine\",\n", + " \"bsize\": 64,\n", + " \"accumsteps\": 1,\n", + " \"lr\": 1e-5,\n", + " \"maxsteps\": 400000,\n", + " \"save_every\": null,\n", + " \"warmup\": 20000,\n", + " \"warmup_bert\": null,\n", + " \"relu\": false,\n", + " \"nway\": 64,\n", + " \"use_ib_negatives\": true,\n", + " \"reranker\": false,\n", + " \"distillation_alpha\": 1.0,\n", + " \"ignore_scores\": false,\n", + " \"model_name\": null,\n", + " \"query_maxlen\": 32,\n", + " \"attend_to_mask_tokens\": false,\n", + " \"interaction\": \"colbert\",\n", + " \"dim\": 128,\n", + " \"doc_maxlen\": 180,\n", + " \"mask_punctuation\": true,\n", + " \"checkpoint\": \"colbert-ir\\/colbertv2.0\",\n", + " \"triples\": \"\\/future\\/u\\/okhattab\\/root\\/unit\\/experiments\\/2021.10\\/downstream.distillation.round2.2_score\\/round2.nway6.cosine.ib\\/examples.64.json\",\n", + " \"collection\": [\n", + " \"It's a piece of cake.\",\n", + " \"Don't put off until tomorrow 
what you can do today.\",\n", + " \"To kill two birds with one stone.\",\n", + " \"Actions speak louder than words.\",\n", + " \"Honesty is the best policy.\",\n", + " \"If you want something done right, do it yourself.\",\n", + " \"The best things in life are free.\",\n", + " \"Don't count your chickens before they hatch.\",\n", + " \"She sells seashells by the seashore.\",\n", + " \"Practice makes perfect.\",\n", + " \"Where there's a will, there's a way.\",\n", + " \"Absence makes the heart grow fonder.\",\n", + " \"When the going gets tough, the tough get going.\",\n", + " \"A journey of a thousand miles begins with a single step.\",\n", + " \"You can't have your cake and eat it too.\",\n", + " \"If you can't beat them, join them.\",\n", + " \"Keep your friends close and your enemies closer.\",\n", + " \"Don't put all your eggs in one basket.\",\n", + " \"All's fair in love and war.\",\n", + " \"Every dog has its day.\",\n", + " \"All good things must come to an end.\",\n", + " \"Once bitten, twice shy.\",\n", + " \"The apple doesn't fall far from the tree.\",\n", + " \"A penny saved is a penny earned.\",\n", + " \"Don't bite the hand that feeds you.\",\n", + " \"You reap what you sow.\",\n", + " \"An apple a day keeps the doctor away.\",\n", + " \"One man's trash is another man's treasure.\",\n", + " \"The squeaky wheel gets the grease.\",\n", + " \"A picture is worth a thousand words.\",\n", + " \"Fortune favors the bold.\",\n", + " \"Practice what you preach.\",\n", + " \"A watched pot never boils.\",\n", + " \"No pain, no gain.\",\n", + " \"You can't make an omelet without breaking eggs.\",\n", + " \"There's no place like home.\",\n", + " \"Ask and you shall receive.\",\n", + " \"Let sleeping dogs lie.\",\n", + " \"If the shoe fits, wear it.\",\n", + " \"Every cloud has a silver lining.\",\n", + " \"Look before you leap.\",\n", + " \"The more, the merrier.\",\n", + " \"The grass is always greener on the other side.\",\n", + " \"Beauty is only skin 
deep.\",\n", + " \"Two wrongs don't make a right.\",\n", + " \"Beauty is in the eye of the beholder.\",\n", + " \"Necessity is the mother of invention.\",\n", + " \"Out of sight, out of mind.\",\n", + " \"Patience is a virtue.\",\n", + " \"Curiosity killed the cat.\",\n", + " \"If at first you don't succeed, try, try again.\",\n", + " \"Beggars can't be choosers.\",\n", + " \"Too many cooks spoil the broth.\",\n", + " \"Easy come, easy go.\",\n", + " \"Don't cry over spilled milk.\",\n", + " \"There's no such thing as a free lunch.\",\n", + " \"A bird in the hand is worth two in the bush.\",\n", + " \"Good things come to those who wait.\",\n", + " \"The quick brown fox jumps over the lazy dog.\",\n", + " \"It takes two to tango.\",\n", + " \"A friend in need is a friend indeed.\",\n", + " \"Like father, like son.\",\n", + " \"Let bygones be bygones.\",\n", + " \"Kill two birds with one stone.\",\n", + " \"A penny for your thoughts.\",\n", + " \"I am the master of my fate, I am the captain of my soul.\",\n", + " \"The pen is mightier than the sword.\",\n", + " \"When in Rome, do as the Romans do.\",\n", + " \"Rome wasn't built in a day.\",\n", + " \"You can't judge a book by its cover.\",\n", + " \"It's raining cats and dogs.\",\n", + " \"Make hay while the sun shines.\",\n", + " \"It's better to be safe than sorry.\",\n", + " \"The early bird catches the worm.\",\n", + " \"To be or not to be, that is the question.\",\n", + " \"Better late than never.\"\n", + " ],\n", + " \"queries\": \"\\/future\\/u\\/okhattab\\/data\\/MSMARCO\\/queries.train.tsv\",\n", + " \"index_name\": \"colbert-ir-index\",\n", + " \"overwrite\": false,\n", + " \"root\": \"\\/home\\/athekunal\\/DSPy-contributions\\/dspy\\/examples\\/integrations\\/colbert\\/experiments\",\n", + " \"experiment\": \"default\",\n", + " \"index_root\": null,\n", + " \"name\": \"2024-04\\/09\\/20.41.14\",\n", + " \"rank\": 0,\n", + " \"nranks\": 1,\n", + " \"amp\": true,\n", + " \"gpus\": 1,\n", + " 
\"avoid_fork_if_possible\": false\n", + "}\n", + "[Apr 09, 20:41:37] [0] \t\t # of sampled PIDs = 76 \t sampled_pids[:3] = [53, 1, 38]\n", + "[Apr 09, 20:41:37] [0] \t\t #> Encoding 76 passages..\n", + "[Apr 09, 20:41:38] [0] \t\t avg_doclen_est = 10.078947067260742 \t len(local_sample) = 76\n", + "[Apr 09, 20:41:38] [0] \t\t Creating 256 partitions.\n", + "[Apr 09, 20:41:38] [0] \t\t *Estimated* 765 embeddings.\n", + "[Apr 09, 20:41:38] [0] \t\t #> Saving the indexing plan to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index/plan.json ..\n", + "Clustering 728 points in 128D to 256 clusters, redo 1 times, 20 iterations\n", + " Preprocessing in 0.00 s\n", + " Iteration 19 (0.05 s, search 0.05 s): objective=155.376 imbalance=1.400 nsplit=0 \n", + "[Apr 09, 20:41:39] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", + "ninja: no work to do.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING clustering 728 points to 256 centroids: please provide at least 9984 training points\n", + "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/decompress_residuals_cpp/build.ninja...\n", + "Building extension module decompress_residuals_cpp...\n", + "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n", + "Loading extension module decompress_residuals_cpp...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Apr 09, 20:41:39] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", + "ninja: no work to do.\n", + "[0.053, 0.05, 0.061, 0.053, 0.044, 0.06, 0.047, 0.049, 0.064, 0.05, 0.051, 0.057, 0.052, 0.042, 0.061, 0.069, 0.052, 0.051, 0.063, 0.043, 0.04, 0.052, 0.043, 0.042, 0.035, 0.055, 0.07, 0.048, 0.051, 0.04, 0.057, 0.045, 0.054, 0.052, 0.042, 0.051, 0.048, 0.047, 0.056, 0.059, 0.05, 0.063, 0.054, 0.06, 0.046, 0.051, 0.04, 0.071, 0.04, 0.049, 0.056, 0.043, 0.048, 0.051, 0.045, 0.052, 0.041, 0.073, 0.039, 0.045, 0.052, 0.056, 0.053, 0.06, 0.041, 0.053, 0.054, 0.052, 0.051, 0.05, 0.061, 0.053, 0.035, 0.05, 0.049, 0.057, 0.045, 0.044, 0.05, 0.05, 0.041, 0.048, 0.043, 0.049, 0.05, 0.039, 0.056, 0.055, 0.048, 0.045, 0.044, 0.041, 0.046, 0.044, 0.046, 0.064, 0.056, 0.054, 0.058, 0.04, 0.043, 0.045, 0.051, 0.058, 0.06, 0.043, 0.057, 0.043, 0.053, 0.056, 0.047, 0.039, 0.057, 0.044, 0.055, 0.063, 0.041, 0.047, 0.049, 0.051, 0.046, 0.042, 0.053, 0.045, 0.044, 0.053, 0.053, 0.046]\n", + "[Apr 09, 20:41:39] #> Got bucket_cutoffs_quantiles = tensor([0.5000], device='cuda:0') and bucket_weights_quantiles = tensor([0.2500, 0.7500], device='cuda:0')\n", + "[Apr 09, 20:41:39] #> Got bucket_cutoffs = tensor([0.0007], device='cuda:0') and bucket_weights = tensor([-0.0378, 0.0386], device='cuda:0')\n", + "[Apr 09, 20:41:39] avg_residual = 0.050323486328125\n", + "[Apr 09, 20:41:40] [0] \t\t #> Encoding 76 passages..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/packbits_cpp/build.ninja...\n", + 
"Building extension module packbits_cpp...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", + "Loading extension module packbits_cpp...\n", + "0it [00:00, ?it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Apr 09, 20:41:40] [0] \t\t #> Saving chunk 0: \t 76 passages and 766 embeddings. From #0 onward.\n", + "[Apr 09, 20:41:40] [0] \t\t #> Checking all files were saved...\n", + "[Apr 09, 20:41:40] [0] \t\t Found all files!\n", + "[Apr 09, 20:41:40] [0] \t\t #> Building IVF...\n", + "[Apr 09, 20:41:40] [0] \t\t #> Loading codes...\n", + "[Apr 09, 20:41:40] [0] \t\t Sorting codes...\n", + "[Apr 09, 20:41:40] [0] \t\t Getting unique codes...\n", + "[Apr 09, 20:41:40] #> Optimizing IVF to store map from centroids to list of pids..\n", + "[Apr 09, 20:41:40] #> Building the emb2pid mapping..\n", + "[Apr 09, 20:41:40] len(emb2pid) = 766\n", + "[Apr 09, 20:41:40] #> Saved optimized IVF to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index/ivf.pid.pt\n", + "[Apr 09, 20:41:40] [0] \t\t #> Saving the indexing metadata to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index/metadata.json ..\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1it [00:00, 13.04it/s]\n", + "100%|██████████| 1/1 [00:00<00:00, 3336.76it/s]\n", + "100%|██████████| 256/256 [00:00<00:00, 267565.87it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "#> Joined...\n", + "Loading the index for experiment default with index name colbert-ir-index\n", + "[Apr 09, 20:41:43] #> Loading codec...\n", + "[Apr 09, 20:41:43] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using 
/home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/decompress_residuals_cpp/build.ninja...\n", + "Building extension module decompress_residuals_cpp...\n", + "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ninja: no work to do.\n", + "[Apr 09, 20:41:43] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading extension module decompress_residuals_cpp...\n", + "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", + "Detected CUDA files, patching ldflags\n", + "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/packbits_cpp/build.ninja...\n", + "Building extension module packbits_cpp...\n", + "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ninja: no work to do.\n", + "[Apr 09, 20:41:43] #> Loading IVF...\n", + "[Apr 09, 20:41:43] #> Loading doclens...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading extension module packbits_cpp...\n", + "100%|██████████| 1/1 [00:00<00:00, 2012.62it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Apr 09, 20:41:43] #> Loading codes and residuals...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "100%|██████████| 1/1 [00:00<00:00, 722.16it/s]\n" + ] + } + ], + "source": [ + "import dspy\n", + "colbert_retriever = dspy.ColBERTv2RetrieverLocal(\n", + " checkpoint='colbert-ir/colbertv2.0',passages = passages,load_only=False,\n", + " index_name='colbert-ir-index'\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#CONFIGURE COLBERT IN DSPY\n", + "dspy.settings.configure(rm=colbert_retriever)\n", + "\n", + "retrieved_docs = dspy.Retrieve(k=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/../../../dsp/primitives/search.py:74: UserWarning: If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.\n", + " warnings.warn(\"If you want to use the Reranker, please use dspy.RetrieveThenRerank. 
The reranking is ignored here.\")\n", + "/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/../../../dsp/primitives/search.py:13: UserWarning: If you want to use the Reranker, please use dspy.RetrieveThenRerank\n", + " warnings.warn(\"If you want to use the Reranker, please use dspy.RetrieveThenRerank\")\n" + ] + } + ], + "source": [ + "pred = retrieved_docs(\n", + " \"What is the meaning of life?\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[33, 6, 47, 74, 48],\n", + " passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n", + ")" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## COLBERT AS RERANKER" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "colbert_config = ColBERTConfig()\n", + "colbert_config.index_name = 'colbert-ir-index'\n", + "colbert_reranker = dspy.ColBERTv2RerankerLocal(\n", + " checkpoint='colbert-ir/colbertv2.0',colbert_config=colbert_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "dspy.settings.configure(rm=colbert_retriever,reranker=colbert_reranker)\n", + "\n", + "retrieve_rerank = dspy.RetrieveThenRerank(k=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "pred = retrieve_rerank(\n", + " [\"What is the meaning of life?\",\"Meaning of pain?\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Prediction(\n", + " 
score=[nan, nan, nan, nan, nan],\n", + " pid=[6, 48, 74, 47, 33],\n", + " rerank_score=[15.8359375, 14.2109375, 12.5703125, 11.7890625, 9.1796875],\n", + " passages=['The best things in life are free.', 'Patience is a virtue.', 'To be or not to be, that is the question.', 'Out of sight, out of mind.', 'No pain, no gain.']\n", + " ),\n", + " Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[33, 0, 47, 74, 16],\n", + " rerank_score=[19.828125, 12.2890625, 11.171875, 9.09375, 6.8984375],\n", + " passages=['No pain, no gain.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Keep your friends close and your enemies closer.']\n", + " )]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## YOU CAN ALSO COLBERT RERANKER AS STANDALONE MODEL" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import tabulate\n", + "\n", + "scores_arr = colbert_reranker(\n", + " \"What is the meaning of life and pain?\",\n", + " passages\n", + ")\n", + "\n", + "tabulate_data = []\n", + "for idx in np.argsort(scores_arr)[::-1]:\n", + " # print(f\"Passage = {passages[idx]} --> Score = {scores_arr[idx]}\")\n", + " tabulate_data.append([passages[idx],scores_arr[idx]])\n", + "\n", + "table = tabulate.tabulate(tabulate_data,tablefmt=\"html\",headers={'sentence','score'})" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
score sentence
No pain, no gain. 16.4844
The best things in life are free. 12.5156
Patience is a virtue. 11.7578
It's better to be safe than sorry. 11.2109
A friend in need is a friend indeed. 10.4609
Out of sight, out of mind. 10.3516
Once bitten, twice shy. 10
It's a piece of cake. 10
All's fair in love and war. 9.92188
All good things must come to an end. 9.92188
A penny for your thoughts. 9.72656
When the going gets tough, the tough get going. 9.60938
Beauty is in the eye of the beholder. 9.49219
Necessity is the mother of invention. 9.45312
Beauty is only skin deep. 9.44531
The more, the merrier. 9.21094
A picture is worth a thousand words. 8.90625
Ask and you shall receive. 8.67969
Easy come, easy go. 8.64062
Better late than never. 8.61719
You reap what you sow. 8.57812
An apple a day keeps the doctor away. 8.54688
Good things come to those who wait. 8.39844
To be or not to be, that is the question. 8.34375
Practice makes perfect. 8.27344
Where there's a will, there's a way. 8.21094
Keep your friends close and your enemies closer. 8.19531
Like father, like son. 7.76172
Honesty is the best policy. 7.57422
To kill two birds with one stone. 7.51953
A penny saved is a penny earned. 7.44922
The grass is always greener on the other side. 7.38281
If at first you don't succeed, try, try again. 7.32812
A journey of a thousand miles begins with a single step. 7.26953
Actions speak louder than words. 7.05469
There's no place like home. 7.01562
Practice what you preach. 7.00781
One man's trash is another man's treasure. 6.89062
Every dog has its day. 6.81641
There's no such thing as a free lunch. 6.78125
Absence makes the heart grow fonder. 6.72656
When in Rome, do as the Romans do. 6.60156
I am the master of my fate, I am the captain of my soul. 6.54688
If you want something done right, do it yourself. 6.52344
Look before you leap. 6.39844
You can't judge a book by its cover. 6.24219
The pen is mightier than the sword. 6.07422
Let bygones be bygones. 6.00781
Two wrongs don't make a right. 5.96094
Rome wasn't built in a day. 5.64453
It's raining cats and dogs. 5.5
Let sleeping dogs lie. 5.28125
The early bird catches the worm. 5.27344
Make hay while the sun shines. 5.05469
Don't bite the hand that feeds you. 5.05078
Fortune favors the bold. 5.01953
If you can't beat them, join them. 4.97656
Every cloud has a silver lining. 4.73438
The apple doesn't fall far from the tree. 4.72266
You can't have your cake and eat it too. 4.63672
A bird in the hand is worth two in the bush. 4.60156
Kill two birds with one stone. 4.14062
It takes two to tango. 3.99219
Don't put off until tomorrow what you can do today. 3.78711
Curiosity killed the cat. 3.50781
A watched pot never boils. 3.31641
Too many cooks spoil the broth. 3.14258
If the shoe fits, wear it. 3.10547
Don't cry over spilled milk. 3.10547
The quick brown fox jumps over the lazy dog. 3.08984
Beggars can't be choosers. 2.90625
The squeaky wheel gets the grease. 2.86328
She sells seashells by the seashore. 2.77148
Don't put all your eggs in one basket. 2.37109
Don't count your chickens before they hatch. 1.82227
You can't make an omelet without breaking eggs. 1.43945
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import HTML, display\n", + "display(HTML(table))" + ] + }, { "cell_type": "code", "execution_count": null, From f6a9293bdf1162c6df03066da7b937080a31d2db Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Wed, 10 Apr 2024 00:26:11 -0400 Subject: [PATCH 12/21] import errors for colbert --- dsp/modules/colbertv2.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index c5b0d7968..7bc558040 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -77,8 +77,8 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" class ColBERTv2RetrieverLocal: - from colbert.infra import Run, RunConfig, ColBERTConfig def __init__(self,passages:List[str],load_only:bool=False,index_name:str="colbert_rm",checkpoint:str='colbert-ir/colbertv2.0',colbert_config:ColBERTConfig=ColBERTConfig()): + from colbert.infra import Run, RunConfig, ColBERTConfig """Colbertv2 retriever module Args: @@ -150,13 +150,13 @@ def __call__(self,query:str,k:int=7,**kwargs): return results class ColBERTv2RerankerLocal: - try: - import colbert - except ImportError: - print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") - from colbert.infra.config.config import ColBERTConfig def __init__(self,checkpoint:str='bert-base-uncased',colbert_config:ColBERTConfig=ColBERTConfig()): + try: + import colbert + except ImportError: + print("Colbert not found. 
Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") + from colbert.infra.config.config import ColBERTConfig """_summary_ Args: From 197a2c2b57f04f61e1ab03b0d1c871cea68a6b4d Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Wed, 10 Apr 2024 00:32:09 -0400 Subject: [PATCH 13/21] improt dspy fixes and linting fixes --- dsp/modules/__init__.py | 2 +- dsp/modules/colbertv2.py | 15 +++++++-------- dsp/primitives/search.py | 3 ++- dspy/retrieve/retrieve.py | 2 +- examples/integrations/colbert/colbert_local.ipynb | 3 ++- 5 files changed, 13 insertions(+), 12 deletions(-) diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py index a6c01e956..99f22cf40 100644 --- a/dsp/modules/__init__.py +++ b/dsp/modules/__init__.py @@ -4,7 +4,7 @@ from .cache_utils import * from .clarifai import * from .cohere import * -from .colbertv2 import ColBERTv2, ColBERTv2RetrieverLocal,ColBERTv2RerankerLocal +from .colbertv2 import ColBERTv2, ColBERTv2RerankerLocal, ColBERTv2RetrieverLocal from .databricks import * from .google import * from .googlevertexai import * diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 7bc558040..69c42d32c 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -1,10 +1,11 @@ import functools -from typing import Any, Optional, Union, List +import os +from typing import Any, List, Optional, Union import requests + from dsp.modules.cache_utils import CacheMemory, NotebookCacheMemory from dsp.utils import dotdict -import os # TODO: Ideally, this takes the name of the index and looks up its port. 
@@ -77,8 +78,7 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" class ColBERTv2RetrieverLocal: - def __init__(self,passages:List[str],load_only:bool=False,index_name:str="colbert_rm",checkpoint:str='colbert-ir/colbertv2.0',colbert_config:ColBERTConfig=ColBERTConfig()): - from colbert.infra import Run, RunConfig, ColBERTConfig + def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False,index_name:str="colbert_rm",checkpoint:str='colbert-ir/colbertv2.0'): """Colbertv2 retriever module Args: @@ -151,12 +151,11 @@ def __call__(self,query:str,k:int=7,**kwargs): class ColBERTv2RerankerLocal: - def __init__(self,checkpoint:str='bert-base-uncased',colbert_config:ColBERTConfig=ColBERTConfig()): + def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'): try: import colbert except ImportError: print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") - from colbert.infra.config.config import ColBERTConfig """_summary_ Args: @@ -171,10 +170,10 @@ def __init__(self,checkpoint:str='bert-base-uncased',colbert_config:ColBERTConfi # return self.forward(*args, **kwargs) def __call__(self,query:str,passages:List[str]=[]): + import numpy as np + from colbert.modeling.colbert import ColBERT from colbert.modeling.tokenization.doc_tokenization import DocTokenizer from colbert.modeling.tokenization.query_tokenization import QueryTokenizer - from colbert.modeling.colbert import ColBERT - import numpy as np assert len(passages) > 0, "Passages should not be empty" self.colbert_config.nway = len(passages) query_tokenizer = QueryTokenizer(self.colbert_config,verbose=1) diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 12a6dd828..09622f1a9 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -1,5 +1,6 @@ -from collections.abc import Iterable import warnings +from collections.abc 
import Iterable + import numpy as np import dsp diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index cd13ca699..881c02e06 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,5 +1,5 @@ import random -from typing import List, Optional, Union, Dict, Any +from typing import Dict, List, Optional, Union import dsp from dspy.predict.parameter import Parameter diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb index 157cd6ca9..591887f9b 100644 --- a/examples/integrations/colbert/colbert_local.ipynb +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -448,9 +448,10 @@ ], "source": [ "import dspy\n", + "colbert_config = ColBERTConfig()\n", "colbert_retriever = dspy.ColBERTv2RetrieverLocal(\n", " checkpoint='colbert-ir/colbertv2.0',passages = passages,load_only=False,\n", - " index_name='colbert-ir-index'\n", + " index_name='colbert-ir-index',colbert_config=colbert_config\n", ")" ] }, From 81d142f2aea3d162a00051c4efabed3eabf96740 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sat, 13 Apr 2024 16:48:10 -0400 Subject: [PATCH 14/21] PR fixes for colbert --- dsp/modules/colbertv2.py | 42 +-- dsp/primitives/search.py | 33 +-- dspy/retrieve/retrieve.py | 9 +- .../integrations/colbert/colbert_local.ipynb | 279 +++++++----------- 4 files changed, 133 insertions(+), 230 deletions(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 69c42d32c..0cd353a73 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -75,24 +75,23 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs): colbertv2_post_request = colbertv2_post_request_v2_wrapped -os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = "True" class ColBERTv2RetrieverLocal: - def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False,index_name:str="colbert_rm",checkpoint:str='colbert-ir/colbertv2.0'): + def 
__init__(self,passages:List[str],colbert_config=None,load_only:bool=False): """Colbertv2 retriever module Args: passages (List[str]): list of passages - load_only (bool, optional): whether to load the index or . Defaults to False. - index_name (str, optional): name of the index. Defaults to "colbert_rm". - checkpoint (str, optional): checkpoint for generating embeddings. Defaults to 'colbert-ir/colbertv2.0'. - colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to ColBERTConfig(). + colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to None. + load_only (bool, optional): whether to load the index or build and then load. Defaults to False. """ - self.checkpoint = checkpoint + assert colbert_config is not None, "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it" self.colbert_config = colbert_config - self.colbert_config.index_name = index_name - self.checkpoint = checkpoint - self.colbert_config.checkpoint = checkpoint + + assert self.colbert_config.checkpoint is not None, "Please pass a valid checkpoint like colbert-ir/colbertv2.0, which you can modify in the ColBERTConfig with attribute name checkpoint" + self.passages = passages + + assert self.colbert_config.index_name is not None, "Please pass a valid index_name, which you can modify in the ColBERTConfig with attribute name index_name" self.passages = passages if not load_only: @@ -112,7 +111,7 @@ def build_index(self): from colbert import Indexer from colbert.infra import Run, RunConfig with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)): - indexer = Indexer(checkpoint=self.checkpoint, config=self.colbert_config) + indexer = Indexer(checkpoint=self.colbert_config.checkpoint, config=self.colbert_config) indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True) def 
get_index(self): @@ -128,7 +127,10 @@ def get_index(self): searcher = Searcher(index=self.colbert_config.index_name, collection=self.passages) return searcher - def __call__(self,query:str,k:int=7,**kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.forward(*args, **kwargs) + + def forward(self,query:str,k:int=7,**kwargs): import torch if kwargs.get("filtered_pids"): @@ -146,7 +148,7 @@ def __call__(self,query:str,k:int=7,**kwargs): searcher_results = self.searcher.search(query, k=k) results = [] for pid,rank,score in zip(*searcher_results): - results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid})) + results.append(dotdict({'long_text':self.searcher.collection[pid],'pid':pid})) return results class ColBERTv2RerankerLocal: @@ -159,22 +161,24 @@ def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'): """_summary_ Args: + colbert_config (ColBERTConfig, optional): Colbert config. Defaults to None. checkpoint_name (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'. - colbert_config (ColBERTConfig, optional): Colbert config. Defaults to ColBERTConfig(). 
""" self.colbert_config = colbert_config self.checkpoint_name = checkpoint self.colbert_config.checkpoint = checkpoint - # def __call__(self, *args: Any, **kwargs: Any) -> Any: - # return self.forward(*args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.forward(*args, **kwargs) + + def forward(self,query:str,passages:List[str]=[]): + assert len(passages) > 0, "Passages should not be empty" - def __call__(self,query:str,passages:List[str]=[]): import numpy as np from colbert.modeling.colbert import ColBERT from colbert.modeling.tokenization.doc_tokenization import DocTokenizer from colbert.modeling.tokenization.query_tokenization import QueryTokenizer - assert len(passages) > 0, "Passages should not be empty" + self.colbert_config.nway = len(passages) query_tokenizer = QueryTokenizer(self.colbert_config,verbose=1) doc_tokenizer = DocTokenizer(self.colbert_config) @@ -182,8 +186,6 @@ def __call__(self,query:str,passages:List[str]=[]): doc_ids,doc_masks = doc_tokenizer.tensorize(passages) col = ColBERT(self.checkpoint_name,self.colbert_config) - # col.colbert_config.nway = len(passages) - # tensor_scores = col([query_ids,query_masks],[doc_ids,doc_masks]) Q = col.query(query_ids,query_masks) DOC_IDS,DOC_MASKS = col.doc(doc_ids,doc_masks,keep_dims='return_mask') Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous() diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 09622f1a9..81acef82d 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -11,20 +11,12 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: if not dsp.settings.rm: raise AssertionError("No RM is loaded.") if not dsp.settings.reranker: - warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank") + warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank",DeprecationWarning) passages = dsp.settings.rm(query, k=k, **kwargs) if not isinstance(passages, Iterable): # 
it's not an iterable yet; make it one. # TODO: we should unify the type signatures of dspy.Retriever passages = [passages] - # passages = [psg.long_text for psg in passages] - - # if dsp.settings.reranker: - # passages_tracking_idx = {str(idx):psg for idx, psg in enumerate(passages)} - # passages_long_text = [psg.long_text for psg in passages] - # passages_cs_scores = dsp.settings.reranker(query, passages_long_text) - # passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] - # passages = [passages_long_text[idx] for idx in passages_cs_scores_sorted] return passages @@ -48,23 +40,6 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: else: return all_queries_passages -# def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: -# if not (dsp.settings.rm and dsp.settings.reranker): -# raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") -# queries = [q for q in queries if q] -# passages = {} -# for query in queries: -# retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) -# passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages]) -# for idx in np.argsort(passages_cs_scores)[::-1]: -# psg = retrieved_passages[idx] -# passages[psg.long_text] = passages.get(psg.long_text, []) + [ -# passages_cs_scores[idx], -# ] - -# passages = [(np.average(score), text) for text, score in passages.items()] -# return [text for _, text in sorted(passages, reverse=True)[:k]] - def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: """Retrieves passages from the RM for each query in queries and returns the top k passages based on the probability or score. 
@@ -72,7 +47,7 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) if not dsp.settings.rm: raise AssertionError("No RM is loaded.") if not dsp.settings.reranker: - warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.") + warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.",DeprecationWarning) queries = [q for q in queries if q] @@ -82,17 +57,13 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) for q in queries: passages = {} retrieved_passages = dsp.settings.rm(q, k=k * 3,**kwargs) - # for idx,psg in enumerate(retrieved_passages): - # retrieved_passages[idx]["tracking_idx"] = idx for idx,psg in enumerate(retrieved_passages): if by_prob: passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.prob else: passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.score retrieved_passages[idx]["tracking_idx"] = idx - # passages = [(score, text) for text, score in passages.items()] passages = sorted(passages.items(), key=lambda item: item[1])[:k] - # passages = sorted(passages, reverse=True)[:k] req_indices = [psg[0][0] for psg in passages] passages = [rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices] all_queries_passages.append(passages) diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 881c02e06..b12283443 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -59,8 +59,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No if "long_text" in passages_dict: passages_dict["passages"] = passages_dict.pop("long_text") return Prediction(**passages_dict) - # elif isinstance(passages,List): - # return Prediction(passages=passages) + # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. 
class RetrieveThenRerank(Parameter): @@ -83,10 +82,10 @@ def load_state(self, state): for name, value in state.items(): setattr(self, name, value) - # def __call__(self, *args, **kwargs): - # return self.forward(*args, **kwargs) + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) - def __call__(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: + def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [query.strip().split('\n')[0].strip() for query in queries] diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb index 591887f9b..2835497cc 100644 --- a/examples/integrations/colbert/colbert_local.ipynb +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -26,19 +26,21 @@ "/home/athekunal/DSPy-contributions/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, similarity='cosine', bsize=32, accumsteps=1, lr=3e-06, maxsteps=500000, save_every=None, warmup=None, warmup_bert=None, relu=False, nway=2, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name=None, query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=128, doc_maxlen=220, mask_punctuation=True, checkpoint=None, triples=None, collection=None, queries=None, index_name=None, overwrite=False, root='/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments', experiment='default', index_root=None, name='2024-04/09/20.41.14', rank=0, nranks=1, amp=True, gpus=1, avoid_fork_if_possible=False)\n" - ] } ], "source": [ - "from colbert.infra.config import ColBERTConfig\n", - "\n", - "print(ColBERTConfig())" + "from colbert.infra.config import ColBERTConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "# You can set this environment variable for debugging purposes\n", + "os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = \"True\"" ] }, { @@ -50,7 +52,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -100,7 +102,7 @@ "root --> /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments\n", "experiment --> default\n", "index_root --> None\n", - "name --> 2024-04/09/20.41.14\n", + "name --> 2024-04/13/16.46.49\n", "rank --> 0\n", "nranks --> 1\n", "amp --> True\n", @@ -111,13 
+113,14 @@ } ], "source": [ + "# You can view the different attributes of the colbert config by uncommenting cell below\n", "for k,v in ColBERTConfig().__dict__.items():\n", " print(f\"{k} --> {v}\")" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -133,7 +136,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -150,17 +153,17 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Building the index for experiment default with index name colbert-ir-index\n", + "Building the index for experiment Colbert-Experiment with index name Colbert-RM\n", "\n", "\n", - "[Apr 09, 20:41:33] #> Creating directory /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index \n", + "[Apr 13, 16:46:52] #> Creating directory /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM \n", "\n", "\n", "#> Starting...\n", @@ -281,29 +284,24 @@ " \"Better late than never.\"\n", " ],\n", " \"queries\": \"\\/future\\/u\\/okhattab\\/data\\/MSMARCO\\/queries.train.tsv\",\n", - " \"index_name\": \"colbert-ir-index\",\n", + " \"index_name\": \"Colbert-RM\",\n", " \"overwrite\": false,\n", " \"root\": \"\\/home\\/athekunal\\/DSPy-contributions\\/dspy\\/examples\\/integrations\\/colbert\\/experiments\",\n", - " \"experiment\": \"default\",\n", + " \"experiment\": \"Colbert-Experiment\",\n", " \"index_root\": null,\n", - " \"name\": \"2024-04\\/09\\/20.41.14\",\n", + " \"name\": \"2024-04\\/13\\/16.46.49\",\n", " \"rank\": 0,\n", " \"nranks\": 1,\n", " \"amp\": true,\n", " \"gpus\": 1,\n", " \"avoid_fork_if_possible\": false\n", "}\n", - "[Apr 09, 20:41:37] [0] \t\t # of sampled PIDs = 76 \t sampled_pids[:3] = [53, 1, 38]\n", - 
"[Apr 09, 20:41:37] [0] \t\t #> Encoding 76 passages..\n", - "[Apr 09, 20:41:38] [0] \t\t avg_doclen_est = 10.078947067260742 \t len(local_sample) = 76\n", - "[Apr 09, 20:41:38] [0] \t\t Creating 256 partitions.\n", - "[Apr 09, 20:41:38] [0] \t\t *Estimated* 765 embeddings.\n", - "[Apr 09, 20:41:38] [0] \t\t #> Saving the indexing plan to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index/plan.json ..\n", - "Clustering 728 points in 128D to 256 clusters, redo 1 times, 20 iterations\n", - " Preprocessing in 0.00 s\n", - " Iteration 19 (0.05 s, search 0.05 s): objective=155.376 imbalance=1.400 nsplit=0 \n", - "[Apr 09, 20:41:39] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", - "ninja: no work to do.\n" + "[Apr 13, 16:46:56] [0] \t\t # of sampled PIDs = 76 \t sampled_pids[:3] = [53, 1, 38]\n", + "[Apr 13, 16:46:56] [0] \t\t #> Encoding 76 passages..\n", + "[Apr 13, 16:46:57] [0] \t\t avg_doclen_est = 10.078947067260742 \t len(local_sample) = 76\n", + "[Apr 13, 16:46:57] [0] \t\t Creating 256 partitions.\n", + "[Apr 13, 16:46:57] [0] \t\t *Estimated* 765 embeddings.\n", + "[Apr 13, 16:46:57] [0] \t\t #> Saving the indexing plan to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM/plan.json ..\n" ] }, { @@ -316,26 +314,7 @@ "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/decompress_residuals_cpp/build.ninja...\n", "Building extension module decompress_residuals_cpp...\n", "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n", - "Loading extension module decompress_residuals_cpp...\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Apr 09, 20:41:39] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", - "ninja: no work to do.\n", - "[0.053, 0.05, 0.061, 0.053, 0.044, 0.06, 0.047, 0.049, 0.064, 0.05, 0.051, 0.057, 0.052, 0.042, 0.061, 0.069, 0.052, 0.051, 0.063, 0.043, 0.04, 0.052, 0.043, 0.042, 0.035, 0.055, 0.07, 0.048, 0.051, 0.04, 0.057, 0.045, 0.054, 0.052, 0.042, 0.051, 0.048, 0.047, 0.056, 0.059, 0.05, 0.063, 0.054, 0.06, 0.046, 0.051, 0.04, 0.071, 0.04, 0.049, 0.056, 0.043, 0.048, 0.051, 0.045, 0.052, 0.041, 0.073, 0.039, 0.045, 0.052, 0.056, 0.053, 0.06, 0.041, 0.053, 0.054, 0.052, 0.051, 0.05, 0.061, 0.053, 0.035, 0.05, 0.049, 0.057, 0.045, 0.044, 0.05, 0.05, 0.041, 0.048, 0.043, 0.049, 0.05, 0.039, 0.056, 0.055, 0.048, 0.045, 0.044, 0.041, 0.046, 0.044, 0.046, 0.064, 0.056, 0.054, 0.058, 0.04, 0.043, 0.045, 0.051, 0.058, 0.06, 0.043, 0.057, 0.043, 0.053, 0.056, 0.047, 0.039, 0.057, 0.044, 0.055, 0.063, 0.041, 0.047, 0.049, 0.051, 0.046, 0.042, 0.053, 0.045, 0.044, 0.053, 0.053, 0.046]\n", - "[Apr 09, 20:41:39] #> Got bucket_cutoffs_quantiles = tensor([0.5000], device='cuda:0') and bucket_weights_quantiles = tensor([0.2500, 0.7500], device='cuda:0')\n", - "[Apr 09, 20:41:39] #> Got bucket_cutoffs = tensor([0.0007], device='cuda:0') and bucket_weights = tensor([-0.0378, 0.0386], device='cuda:0')\n", - "[Apr 09, 20:41:39] avg_residual = 0.050323486328125\n", - "[Apr 09, 20:41:40] [0] \t\t #> Encoding 76 passages..\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ + "Loading extension module decompress_residuals_cpp...\n", "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", "Detected CUDA files, patching ldflags\n", "Emitting ninja build file 
/home/athekunal/.cache/torch_extensions/py310_cu117/packbits_cpp/build.ninja...\n", @@ -349,27 +328,39 @@ "name": "stdout", "output_type": "stream", "text": [ - "[Apr 09, 20:41:40] [0] \t\t #> Saving chunk 0: \t 76 passages and 766 embeddings. From #0 onward.\n", - "[Apr 09, 20:41:40] [0] \t\t #> Checking all files were saved...\n", - "[Apr 09, 20:41:40] [0] \t\t Found all files!\n", - "[Apr 09, 20:41:40] [0] \t\t #> Building IVF...\n", - "[Apr 09, 20:41:40] [0] \t\t #> Loading codes...\n", - "[Apr 09, 20:41:40] [0] \t\t Sorting codes...\n", - "[Apr 09, 20:41:40] [0] \t\t Getting unique codes...\n", - "[Apr 09, 20:41:40] #> Optimizing IVF to store map from centroids to list of pids..\n", - "[Apr 09, 20:41:40] #> Building the emb2pid mapping..\n", - "[Apr 09, 20:41:40] len(emb2pid) = 766\n", - "[Apr 09, 20:41:40] #> Saved optimized IVF to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index/ivf.pid.pt\n", - "[Apr 09, 20:41:40] [0] \t\t #> Saving the indexing metadata to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/default/indexes/colbert-ir-index/metadata.json ..\n" + "Clustering 728 points in 128D to 256 clusters, redo 1 times, 20 iterations\n", + " Preprocessing in 0.00 s\n", + " Iteration 19 (0.05 s, search 0.05 s): objective=155.376 imbalance=1.400 nsplit=0 \n", + "[Apr 13, 16:46:58] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", + "ninja: no work to do.\n", + "[Apr 13, 16:46:58] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", + "ninja: no work to do.\n", + "[0.053, 0.05, 0.061, 0.053, 0.044, 0.06, 0.047, 0.049, 0.064, 0.05, 0.051, 0.057, 0.052, 0.042, 0.061, 0.069, 0.052, 0.051, 0.063, 0.043, 0.04, 0.052, 0.043, 0.042, 0.035, 0.055, 0.07, 0.048, 0.051, 0.04, 0.057, 0.045, 0.054, 0.052, 0.042, 0.051, 0.048, 0.047, 0.056, 0.059, 0.05, 0.063, 0.054, 
0.06, 0.046, 0.051, 0.04, 0.071, 0.04, 0.049, 0.056, 0.043, 0.048, 0.051, 0.045, 0.052, 0.041, 0.073, 0.039, 0.045, 0.052, 0.056, 0.053, 0.06, 0.041, 0.053, 0.054, 0.052, 0.051, 0.05, 0.061, 0.053, 0.035, 0.05, 0.049, 0.057, 0.045, 0.044, 0.05, 0.05, 0.041, 0.048, 0.043, 0.049, 0.05, 0.039, 0.056, 0.055, 0.048, 0.045, 0.044, 0.041, 0.046, 0.044, 0.046, 0.064, 0.056, 0.054, 0.058, 0.04, 0.043, 0.045, 0.051, 0.058, 0.06, 0.043, 0.057, 0.043, 0.053, 0.056, 0.047, 0.039, 0.057, 0.044, 0.055, 0.063, 0.041, 0.047, 0.049, 0.051, 0.046, 0.042, 0.053, 0.045, 0.044, 0.053, 0.053, 0.046]\n", + "[Apr 13, 16:46:58] #> Got bucket_cutoffs_quantiles = tensor([0.5000], device='cuda:0') and bucket_weights_quantiles = tensor([0.2500, 0.7500], device='cuda:0')\n", + "[Apr 13, 16:46:58] #> Got bucket_cutoffs = tensor([0.0007], device='cuda:0') and bucket_weights = tensor([-0.0378, 0.0386], device='cuda:0')\n", + "[Apr 13, 16:46:58] avg_residual = 0.050323486328125\n", + "[Apr 13, 16:46:58] [0] \t\t #> Encoding 76 passages..\n", + "[Apr 13, 16:46:58] [0] \t\t #> Saving chunk 0: \t 76 passages and 766 embeddings. 
From #0 onward.\n", + "[Apr 13, 16:46:58] [0] \t\t #> Checking all files were saved...\n", + "[Apr 13, 16:46:58] [0] \t\t Found all files!\n", + "[Apr 13, 16:46:58] [0] \t\t #> Building IVF...\n", + "[Apr 13, 16:46:58] [0] \t\t #> Loading codes...\n", + "[Apr 13, 16:46:58] [0] \t\t Sorting codes...\n", + "[Apr 13, 16:46:58] [0] \t\t Getting unique codes...\n", + "[Apr 13, 16:46:58] #> Optimizing IVF to store map from centroids to list of pids..\n", + "[Apr 13, 16:46:58] #> Building the emb2pid mapping..\n", + "[Apr 13, 16:46:58] len(emb2pid) = 766\n", + "[Apr 13, 16:46:58] #> Saved optimized IVF to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM/ivf.pid.pt\n", + "[Apr 13, 16:46:58] [0] \t\t #> Saving the indexing metadata to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM/metadata.json ..\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "1it [00:00, 13.04it/s]\n", - "100%|██████████| 1/1 [00:00<00:00, 3336.76it/s]\n", - "100%|██████████| 256/256 [00:00<00:00, 267565.87it/s]\n" + "1it [00:00, 15.27it/s]\n", + "100%|██████████| 1/1 [00:00<00:00, 3551.49it/s]\n", + "100%|██████████| 256/256 [00:00<00:00, 281526.44it/s]\n" ] }, { @@ -377,9 +368,9 @@ "output_type": "stream", "text": [ "#> Joined...\n", - "Loading the index for experiment default with index name colbert-ir-index\n", - "[Apr 09, 20:41:43] #> Loading codec...\n", - "[Apr 09, 20:41:43] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" + "Loading the index for experiment Colbert-Experiment with index name Colbert-RM\n", + "[Apr 13, 16:47:02] #> Loading codec...\n", + "[Apr 13, 16:47:02] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" ] }, { @@ -398,7 +389,7 @@ "output_type": "stream", "text": [ "ninja: no work to do.\n", - "[Apr 
09, 20:41:43] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" + "[Apr 13, 16:47:02] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" ] }, { @@ -418,8 +409,8 @@ "output_type": "stream", "text": [ "ninja: no work to do.\n", - "[Apr 09, 20:41:43] #> Loading IVF...\n", - "[Apr 09, 20:41:43] #> Loading doclens...\n" + "[Apr 13, 16:47:02] #> Loading IVF...\n", + "[Apr 13, 16:47:02] #> Loading doclens...\n" ] }, { @@ -427,14 +418,14 @@ "output_type": "stream", "text": [ "Loading extension module packbits_cpp...\n", - "100%|██████████| 1/1 [00:00<00:00, 2012.62it/s]" + "100%|██████████| 1/1 [00:00<00:00, 4969.55it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "[Apr 09, 20:41:43] #> Loading codes and residuals...\n" + "[Apr 13, 16:47:02] #> Loading codes and residuals...\n" ] }, { @@ -442,22 +433,25 @@ "output_type": "stream", "text": [ "\n", - "100%|██████████| 1/1 [00:00<00:00, 722.16it/s]\n" + "100%|██████████| 1/1 [00:00<00:00, 805.05it/s]\n" ] } ], "source": [ "import dspy\n", "colbert_config = ColBERTConfig()\n", + "colbert_config.index_name = \"Colbert-RM\"\n", + "colbert_config.experiment = \"Colbert-Experiment\"\n", + "colbert_config.checkpoint = \"colbert-ir/colbertv2.0\"\n", "colbert_retriever = dspy.ColBERTv2RetrieverLocal(\n", - " checkpoint='colbert-ir/colbertv2.0',passages = passages,load_only=False,\n", - " index_name='colbert-ir-index',colbert_config=colbert_config\n", + " passages = passages,load_only=False,\n", + " colbert_config=colbert_config\n", ")" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ @@ -469,17 +463,22 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "metadata": {}, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - 
"/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/../../../dsp/primitives/search.py:74: UserWarning: If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.\n", - " warnings.warn(\"If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.\")\n", - "/home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/../../../dsp/primitives/search.py:13: UserWarning: If you want to use the Reranker, please use dspy.RetrieveThenRerank\n", - " warnings.warn(\"If you want to use the Reranker, please use dspy.RetrieveThenRerank\")\n" + "\n", + "#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==\n", + "#> Input: . What is the meaning of life?, \t\t True, \t\t None\n", + "#> Output IDs: torch.Size([32]), tensor([ 101, 1, 2054, 2003, 1996, 3574, 1997, 2166, 1029, 102, 103, 103,\n", + " 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,\n", + " 103, 103, 103, 103, 103, 103, 103, 103], device='cuda:0')\n", + "#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')\n", + "\n" ] } ], @@ -491,20 +490,19 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Prediction(\n", - " score=[nan, nan, nan, nan, nan],\n", " pid=[33, 6, 47, 74, 48],\n", " passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n", ")" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -522,7 +520,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -534,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, 
"metadata": {}, "outputs": [], "source": [ @@ -545,7 +543,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -556,27 +554,25 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Prediction(\n", - " score=[nan, nan, nan, nan, nan],\n", " pid=[6, 48, 74, 47, 33],\n", " rerank_score=[15.8359375, 14.2109375, 12.5703125, 11.7890625, 9.1796875],\n", " passages=['The best things in life are free.', 'Patience is a virtue.', 'To be or not to be, that is the question.', 'Out of sight, out of mind.', 'No pain, no gain.']\n", " ),\n", " Prediction(\n", - " score=[nan, nan, nan, nan, nan],\n", " pid=[33, 0, 47, 74, 16],\n", " rerank_score=[19.828125, 12.2890625, 11.171875, 9.09375, 6.8984375],\n", " passages=['No pain, no gain.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Keep your friends close and your enemies closer.']\n", " )]" ] }, - "execution_count": 18, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -594,7 +590,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -603,7 +599,8 @@ "\n", "scores_arr = colbert_reranker(\n", " \"What is the meaning of life and pain?\",\n", - " passages\n", + " # Pass a subset of passages\n", + " passages[:10]\n", ")\n", "\n", "tabulate_data = []\n", @@ -616,7 +613,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -624,85 +621,19 @@ "text/html": [ "\n", "\n", - "\n", + "\n", "\n", "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - 
"\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", "\n", "
score sentence
score sentence
No pain, no gain. 16.4844
The best things in life are free. 12.5156
Patience is a virtue. 11.7578
It's better to be safe than sorry. 11.2109
A friend in need is a friend indeed. 10.4609
Out of sight, out of mind. 10.3516
Once bitten, twice shy. 10
It's a piece of cake. 10
All's fair in love and war. 9.92188
All good things must come to an end. 9.92188
A penny for your thoughts. 9.72656
When the going gets tough, the tough get going. 9.60938
Beauty is in the eye of the beholder. 9.49219
Necessity is the mother of invention. 9.45312
Beauty is only skin deep. 9.44531
The more, the merrier. 9.21094
A picture is worth a thousand words. 8.90625
Ask and you shall receive. 8.67969
Easy come, easy go. 8.64062
Better late than never. 8.61719
You reap what you sow. 8.57812
An apple a day keeps the doctor away. 8.54688
Good things come to those who wait. 8.39844
To be or not to be, that is the question. 8.34375
Practice makes perfect. 8.27344
Where there's a will, there's a way. 8.21094
Keep your friends close and your enemies closer. 8.19531
Like father, like son. 7.76172
Honesty is the best policy. 7.57422
To kill two birds with one stone. 7.51953
A penny saved is a penny earned. 7.44922
The grass is always greener on the other side. 7.38281
If at first you don't succeed, try, try again. 7.32812
A journey of a thousand miles begins with a single step. 7.26953
Actions speak louder than words. 7.05469
There's no place like home. 7.01562
Practice what you preach. 7.00781
One man's trash is another man's treasure. 6.89062
Every dog has its day. 6.81641
There's no such thing as a free lunch. 6.78125
Absence makes the heart grow fonder. 6.72656
When in Rome, do as the Romans do. 6.60156
I am the master of my fate, I am the captain of my soul. 6.54688
If you want something done right, do it yourself. 6.52344
Look before you leap. 6.39844
You can't judge a book by its cover. 6.24219
The pen is mightier than the sword. 6.07422
Let bygones be bygones. 6.00781
Two wrongs don't make a right. 5.96094
Rome wasn't built in a day. 5.64453
It's raining cats and dogs. 5.5
Let sleeping dogs lie. 5.28125
The early bird catches the worm. 5.27344
Make hay while the sun shines. 5.05469
Don't bite the hand that feeds you. 5.05078
Fortune favors the bold. 5.01953
If you can't beat them, join them. 4.97656
Every cloud has a silver lining. 4.73438
The apple doesn't fall far from the tree. 4.72266
You can't have your cake and eat it too. 4.63672
A bird in the hand is worth two in the bush. 4.60156
Kill two birds with one stone. 4.14062
It takes two to tango. 3.99219
Don't put off until tomorrow what you can do today. 3.78711
Curiosity killed the cat. 3.50781
A watched pot never boils. 3.31641
Too many cooks spoil the broth. 3.14258
If the shoe fits, wear it. 3.10547
Don't cry over spilled milk. 3.10547
The quick brown fox jumps over the lazy dog. 3.08984
Beggars can't be choosers. 2.90625
The squeaky wheel gets the grease. 2.86328
She sells seashells by the seashore. 2.77148
Don't put all your eggs in one basket. 2.37109
Don't count your chickens before they hatch. 1.82227
You can't make an omelet without breaking eggs. 1.43945
The best things in life are free. 12.5156
It's a piece of cake. 10
Practice makes perfect. 8.27344
Honesty is the best policy. 7.57422
To kill two birds with one stone. 7.51953
Actions speak louder than words. 7.05469
If you want something done right, do it yourself. 6.52344
Don't put off until tomorrow what you can do today. 3.78711
She sells seashells by the seashore. 2.77148
Don't count your chickens before they hatch. 1.82227
" ], From b73753c055cde6c7e8b8c87f6497297a3d902672 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sat, 13 Apr 2024 16:50:54 -0400 Subject: [PATCH 15/21] making the linting gods happy --- dsp/modules/colbertv2.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 0cd353a73..58739688d 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -1,5 +1,4 @@ import functools -import os from typing import Any, List, Optional, Union import requests From 0ec1ded54d224c5215cf7f6c8da70c81e2a248aa Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Sun, 14 Apr 2024 03:02:42 -0400 Subject: [PATCH 16/21] remove unnecessary outputs --- .../integrations/colbert/colbert_local.ipynb | 65 ++----------------- 1 file changed, 4 insertions(+), 61 deletions(-) diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb index 2835497cc..84555da70 100644 --- a/examples/integrations/colbert/colbert_local.ipynb +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -52,70 +52,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "query_token_id --> [unused0]\n", - "doc_token_id --> [unused1]\n", - "query_token --> [Q]\n", - "doc_token --> [D]\n", - "ncells --> None\n", - "centroid_score_threshold --> None\n", - "ndocs --> None\n", - "load_index_with_mmap --> False\n", - "index_path --> None\n", - "index_bsize --> 64\n", - "nbits --> 1\n", - "kmeans_niters --> 4\n", - "resume --> False\n", - "similarity --> cosine\n", - "bsize --> 32\n", - "accumsteps --> 1\n", - "lr --> 3e-06\n", - "maxsteps --> 500000\n", - "save_every --> None\n", - "warmup --> None\n", - "warmup_bert --> None\n", - "relu --> False\n", - "nway --> 2\n", - "use_ib_negatives --> False\n", - "reranker --> False\n", - "distillation_alpha --> 1.0\n", - "ignore_scores --> False\n", - 
"model_name --> None\n", - "query_maxlen --> 32\n", - "attend_to_mask_tokens --> False\n", - "interaction --> colbert\n", - "dim --> 128\n", - "doc_maxlen --> 220\n", - "mask_punctuation --> True\n", - "checkpoint --> None\n", - "triples --> None\n", - "collection --> None\n", - "queries --> None\n", - "index_name --> None\n", - "overwrite --> False\n", - "root --> /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments\n", - "experiment --> default\n", - "index_root --> None\n", - "name --> 2024-04/13/16.46.49\n", - "rank --> 0\n", - "nranks --> 1\n", - "amp --> True\n", - "gpus --> 1\n", - "avoid_fork_if_possible --> False\n", - "assigned --> {}\n" - ] - } - ], + "outputs": [], "source": [ "# You can view the different attributes of the colbert config by uncommenting cell below\n", - "for k,v in ColBERTConfig().__dict__.items():\n", - " print(f\"{k} --> {v}\")" + "# for k,v in ColBERTConfig().__dict__.items():\n", + "# print(f\"{k} --> {v}\")" ] }, { From 685df2a24e1633992505c2e534630f7207931df6 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Tue, 16 Apr 2024 20:55:24 -0400 Subject: [PATCH 17/21] colbertv2 docs --- docs/api/retrieval_model_clients/ColBERTv2.md | 78 +++++++++++++++++++ dsp/modules/colbertv2.py | 4 +- 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/docs/api/retrieval_model_clients/ColBERTv2.md b/docs/api/retrieval_model_clients/ColBERTv2.md index 2dd31bef8..a8fea9492 100644 --- a/docs/api/retrieval_model_clients/ColBERTv2.md +++ b/docs/api/retrieval_model_clients/ColBERTv2.md @@ -49,3 +49,81 @@ retrieval_response = colbertv2_wiki17_abstracts('When was the first FIFA World C for result in retrieval_response: print("Text:", result['text'], "\n") ``` + +# dspy.ColBERTv2RetrieverLocal + +This is taken from the official documentation of [Colbertv2](https://github.com/stanford-futuredata/ColBERT/tree/main) following the [paper](https://arxiv.org/abs/2112.01488). 
+ +You can install Colbertv2 by following the instructions from [here](https://github.com/stanford-futuredata/ColBERT?tab=readme-ov-file#installation) + +### Constructor +The constructor initializes the ColBERTv2 as a local retriever object. You can initialize a server instance from your ColBERTv2 local instance using the code snippet from [here](https://github.com/stanford-futuredata/ColBERT/blob/main/server.py) + +```python +class ColBERTv2RetrieverLocal: + def __init__( + self, + passages:List[str], + colbert_config=None, + load_only:bool=False): +``` + +**Parameters** +- `passages` (_List[str]_): List of passages to be indexed +- `colbert_config` (_ColBERTConfig_, _Optional_): colbert config for building and searching. Defaults to None. +- `load_only` (_Boolean_): whether to load the index or build and then load. Defaults to False. + +The `colbert_config` object is required for ColBERTv2, and it can be imported with `from colbert.infra.config import ColBERTConfig`. You can find the descriptions of config attributes from [here](https://github.com/stanford-futuredata/ColBERT/blob/main/colbert/infra/config/settings.py) + +### Methods + +#### `forward(self, query:str, k:int, **kwargs) -> Union[list[str], list[dotdict]]` + +It retrieves relevant passages from the index based on the query. If you already have a local index, then you can pass the `load_only` flag as `True` and change the `index` attribute of ColBERTConfig to the local path. Also, make sure to change the `checkpoint` attribute of ColBERTConfig to the embedding model that you used to build the index. + +**Parameters:** +- `query` (_str_): Query string used for retrieval. +- `k` (_int_, _optional_): Number of passages to retrieve.
Defaults to 7 + +It returns a `Prediction` object for each query + +```python +Prediction( + pid=[33, 6, 47, 74, 48], + passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.'] +) +``` +# dspy.ColBERTv2RerankerLocal + +You can also use ColBERTv2 as a reranker in DSPy. + +### Constructor + +```python +class ColBERTv2RerankerLocal: + + def __init__( + self, + colbert_config=None, + checkpoint:str='bert-base-uncased'): +``` + +**Parameters** +- `colbert_config` (_ColBERTConfig_, _Optional_): colbert config for building and searching. Defaults to None. +- `checkpoint` (_str_): Embedding model used to embed the documents and query + +### Methods +#### `forward(self,query:str,passages:List[str])` + +Based on a query and list of passages, it reranks the passages and returns the scores along with the passages ordered in descending order based on the similarity scores. + +**Parameters:** +- `query` (_str_): Query string used for reranking. +- `passages` (_List[str]_): List of passages to be reranked + +It returns the similarity scores array and you can link it to the passages by + +```python +for idx in np.argsort(scores_arr)[::-1]: + print(f"Passage = {passages[idx]} --> Score = {scores_arr[idx]}") +``` \ No newline at end of file diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 58739688d..04926967e 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -164,7 +164,7 @@ def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'): checkpoint_name (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'.
""" self.colbert_config = colbert_config - self.checkpoint_name = checkpoint + self.checkpoint = checkpoint self.colbert_config.checkpoint = checkpoint def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -184,7 +184,7 @@ def forward(self,query:str,passages:List[str]=[]): query_ids,query_masks = query_tokenizer.tensorize([query]) doc_ids,doc_masks = doc_tokenizer.tensorize(passages) - col = ColBERT(self.checkpoint_name,self.colbert_config) + col = ColBERT(self.checkpoint,self.colbert_config) Q = col.query(query_ids,query_masks) DOC_IDS,DOC_MASKS = col.doc(doc_ids,doc_masks,keep_dims='return_mask') Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous() From 9cb522bb6bd05249fa27c7ff258df6bdba5cfc32 Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Mon, 29 Apr 2024 02:22:05 -0400 Subject: [PATCH 18/21] Colbert PR fixes --- dsp/modules/colbertv2.py | 2 +- dsp/primitives/search.py | 107 ++++- dspy/retrieve/retrieve.py | 33 +- .../integrations/colbert/colbert_local.ipynb | 397 ++++-------------- 4 files changed, 199 insertions(+), 340 deletions(-) diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py index 04926967e..67b246c5e 100644 --- a/dsp/modules/colbertv2.py +++ b/dsp/modules/colbertv2.py @@ -147,7 +147,7 @@ def forward(self,query:str,k:int=7,**kwargs): searcher_results = self.searcher.search(query, k=k) results = [] for pid,rank,score in zip(*searcher_results): - results.append(dotdict({'long_text':self.searcher.collection[pid],'pid':pid})) + results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid})) return results class ColBERTv2RerankerLocal: diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 81acef82d..6d9b79e51 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -1,17 +1,35 @@ -import warnings +import logging from collections.abc import Iterable import numpy as np import dsp +logger = logging.getLogger(__name__) +# def retrieve(query: str, k: int, **kwargs) -> 
list[str]: +# """Retrieves passages from the RM for the query and returns the top k passages.""" +# if not dsp.settings.rm: +# raise AssertionError("No RM is loaded.") +# passages = dsp.settings.rm(query, k=k, **kwargs) +# if not isinstance(passages, Iterable): +# # it's not an iterable yet; make it one. +# # TODO: we should unify the type signatures of dspy.Retriever +# passages = [passages] +# passages = [psg.long_text for psg in passages] + +# if dsp.settings.reranker: +# passages_cs_scores = dsp.settings.reranker(query, passages) +# passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] +# passages = [passages[idx] for idx in passages_cs_scores_sorted] + + +# return passages def retrieve(query: str, k: int, **kwargs) -> list[str]: """Retrieves passages from the RM for the query and returns the top k passages.""" + if not dsp.settings.rm: raise AssertionError("No RM is loaded.") - if not dsp.settings.reranker: - warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank",DeprecationWarning) passages = dsp.settings.rm(query, k=k, **kwargs) if not isinstance(passages, Iterable): # it's not an iterable yet; make it one. 
@@ -21,18 +39,37 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: return passages -def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: +# def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: +# if not (dsp.settings.rm and dsp.settings.reranker): +# raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") +# queries = [q for q in queries if q] +# passages = {} +# for query in queries: +# retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) +# passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages]) +# for idx in np.argsort(passages_cs_scores)[::-1]: +# psg = retrieved_passages[idx] +# passages[psg.long_text] = passages.get(psg.long_text, []) + [ +# passages_cs_scores[idx], +# ] + + +# passages = [(np.average(score), text) for text, score in passages.items()] +# return [text for _, text in sorted(passages, reverse=True)[:k]] +def retrieveRerankEnsemble(queries: list[str], k: int, **kwargs) -> list[str]: if not (dsp.settings.rm and dsp.settings.reranker): raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") queries = [q for q in queries if q] all_queries_passages = [] for query in queries: passages = [] - retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) - passages_cs_scores = dsp.settings.reranker(query,passages=[psg["long_text"] for psg in retrieved_passages]) + retrieved_passages = dsp.settings.rm(query, k=k * 3, **kwargs) + passages_cs_scores = dsp.settings.reranker( + query, passages=[psg["long_text"] for psg in retrieved_passages] + ) for idx in np.argsort(passages_cs_scores)[::-1][:k]: curr_passage = retrieved_passages[idx] - curr_passage['rerank_score'] = passages_cs_scores[idx] + curr_passage["rerank_score"] = passages_cs_scores[idx] passages.append(curr_passage) all_queries_passages.append(passages) if len(queries) == 1: @@ -40,15 +77,49 @@ def retrieveRerankEnsemble(queries: list[str], k: 
int,**kwargs) -> list[str]: else: return all_queries_passages -def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: + +# def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: +# """Retrieves passages from the RM for each query in queries and returns the top k passages +# based on the probability or score. +# """ +# if not dsp.settings.rm: +# raise AssertionError("No RM is loaded.") +# if dsp.settings.reranker: +# return retrieveRerankEnsemble(queries, k) + +# queries = [q for q in queries if q] + +# if len(queries) == 1: +# return retrieve(queries[0], k, **kwargs) + +# passages = {} +# for q in queries: +# for psg in dsp.settings.rm(q, k=k * 3,**kwargs): +# if by_prob: +# passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.prob +# else: +# passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.score + +# passages = [(score, text) for text, score in passages.items()] +# passages = sorted(passages, reverse=True)[:k] +# passages = [text for _, text in passages] + + +# return passages +def retrieveEnsemble( + queries: list[str], k: int, by_prob: bool = True, **kwargs +) -> list[str]: """Retrieves passages from the RM for each query in queries and returns the top k passages based on the probability or score. """ + if not dsp.settings.rm: raise AssertionError("No RM is loaded.") if not dsp.settings.reranker: - warnings.warn("If you want to use the Reranker, please use dspy.RetrieveThenRerank. The reranking is ignored here.",DeprecationWarning) - + logger.warn( + "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error." 
+ ) + queries = [q for q in queries if q] if len(queries) == 1: @@ -56,15 +127,21 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) all_queries_passages = [] for q in queries: passages = {} - retrieved_passages = dsp.settings.rm(q, k=k * 3,**kwargs) - for idx,psg in enumerate(retrieved_passages): + retrieved_passages = dsp.settings.rm(q, k=k * 3, **kwargs) + for idx, psg in enumerate(retrieved_passages): if by_prob: - passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.prob + passages[(idx, psg.long_text)] = ( + passages.get(psg.long_text, 0.0) + psg.prob + ) else: - passages[(idx,psg.long_text)] = passages.get(psg.long_text, 0.0) + psg.score + passages[(idx, psg.long_text)] = ( + passages.get(psg.long_text, 0.0) + psg.score + ) retrieved_passages[idx]["tracking_idx"] = idx passages = sorted(passages.items(), key=lambda item: item[1])[:k] req_indices = [psg[0][0] for psg in passages] - passages = [rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices] + passages = [ + rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices + ] all_queries_passages.append(passages) return all_queries_passages diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index b12283443..0dceff8de 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -5,6 +5,14 @@ from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction +def single_query_passage(passages): + passages_dict = {key:[] for key in list(passages[0].keys())} + for docs in passages: + for key,value in docs.items(): + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + return Prediction(**passages_dict) class Retrieve(Parameter): name = "Search" @@ -30,6 +38,14 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) def forward(self, query_or_queries: Union[str, List[str]], k: 
Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: + # queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries + # queries = [query.strip().split('\n')[0].strip() for query in queries] + + # # print(queries) + # # TODO: Consider removing any quote-like markers that surround the query too. + # k = k if k is not None else self.k + # passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) + # return Prediction(passages=passages) queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries queries = [query.strip().split('\n')[0].strip() for query in queries] @@ -51,14 +67,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No return pred_returns elif isinstance(passages[0], Dict): #passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} - passages_dict = {key:[] for key in list(passages[0].keys())} - - for psg in passages: - for key,value in psg.items(): - passages_dict[key].append(value) - if "long_text" in passages_dict: - passages_dict["passages"] = passages_dict.pop("long_text") - return Prediction(**passages_dict) + return single_query_passage(passages=passages) # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. 
@@ -106,11 +115,5 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No pred_returns.append(Prediction(**passages_dict)) return pred_returns elif isinstance(passages[0], Dict): - passages_dict = {key:[] for key in list(passages[0].keys())} - for docs in passages: - for key,value in docs.items(): - passages_dict[key].append(value) - if "long_text" in passages_dict: - passages_dict["passages"] = passages_dict.pop("long_text") - return Prediction(**passages_dict) + return single_query_passage(passages=passages) \ No newline at end of file diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb index 84555da70..f5eb881a2 100644 --- a/examples/integrations/colbert/colbert_local.ipynb +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -16,25 +16,16 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/athekunal/DSPy-contributions/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "from colbert.infra.config import ColBERTConfig" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -63,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -74,12 +65,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## This tutorial is running from the examples/integrations/tutorials folder, hence we need to add the system path for dspy" + "## This tutorial is running from the `examples/integrations/tutorials folder`, hence we need to add the system path for dspy\n", + "\n", + "* If you have installed the dspy package, then you don't need to run the below cell" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -96,290 +89,9 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Building the index for experiment Colbert-Experiment with index name Colbert-RM\n", - "\n", - "\n", - "[Apr 13, 16:46:52] #> Creating directory /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM \n", - "\n", - "\n", - "#> Starting...\n", - "nranks = 1 \t num_gpus = 1 \t device=0\n", - "{\n", - " \"query_token_id\": \"[unused0]\",\n", - " \"doc_token_id\": \"[unused1]\",\n", - " \"query_token\": \"[Q]\",\n", - " \"doc_token\": \"[D]\",\n", - " \"ncells\": null,\n", - " \"centroid_score_threshold\": null,\n", - " \"ndocs\": null,\n", - " \"load_index_with_mmap\": false,\n", - " \"index_path\": null,\n", - " \"index_bsize\": 64,\n", - " \"nbits\": 1,\n", - " \"kmeans_niters\": 20,\n", - " \"resume\": false,\n", - " 
\"similarity\": \"cosine\",\n", - " \"bsize\": 64,\n", - " \"accumsteps\": 1,\n", - " \"lr\": 1e-5,\n", - " \"maxsteps\": 400000,\n", - " \"save_every\": null,\n", - " \"warmup\": 20000,\n", - " \"warmup_bert\": null,\n", - " \"relu\": false,\n", - " \"nway\": 64,\n", - " \"use_ib_negatives\": true,\n", - " \"reranker\": false,\n", - " \"distillation_alpha\": 1.0,\n", - " \"ignore_scores\": false,\n", - " \"model_name\": null,\n", - " \"query_maxlen\": 32,\n", - " \"attend_to_mask_tokens\": false,\n", - " \"interaction\": \"colbert\",\n", - " \"dim\": 128,\n", - " \"doc_maxlen\": 180,\n", - " \"mask_punctuation\": true,\n", - " \"checkpoint\": \"colbert-ir\\/colbertv2.0\",\n", - " \"triples\": \"\\/future\\/u\\/okhattab\\/root\\/unit\\/experiments\\/2021.10\\/downstream.distillation.round2.2_score\\/round2.nway6.cosine.ib\\/examples.64.json\",\n", - " \"collection\": [\n", - " \"It's a piece of cake.\",\n", - " \"Don't put off until tomorrow what you can do today.\",\n", - " \"To kill two birds with one stone.\",\n", - " \"Actions speak louder than words.\",\n", - " \"Honesty is the best policy.\",\n", - " \"If you want something done right, do it yourself.\",\n", - " \"The best things in life are free.\",\n", - " \"Don't count your chickens before they hatch.\",\n", - " \"She sells seashells by the seashore.\",\n", - " \"Practice makes perfect.\",\n", - " \"Where there's a will, there's a way.\",\n", - " \"Absence makes the heart grow fonder.\",\n", - " \"When the going gets tough, the tough get going.\",\n", - " \"A journey of a thousand miles begins with a single step.\",\n", - " \"You can't have your cake and eat it too.\",\n", - " \"If you can't beat them, join them.\",\n", - " \"Keep your friends close and your enemies closer.\",\n", - " \"Don't put all your eggs in one basket.\",\n", - " \"All's fair in love and war.\",\n", - " \"Every dog has its day.\",\n", - " \"All good things must come to an end.\",\n", - " \"Once bitten, twice shy.\",\n", - " \"The 
apple doesn't fall far from the tree.\",\n", - " \"A penny saved is a penny earned.\",\n", - " \"Don't bite the hand that feeds you.\",\n", - " \"You reap what you sow.\",\n", - " \"An apple a day keeps the doctor away.\",\n", - " \"One man's trash is another man's treasure.\",\n", - " \"The squeaky wheel gets the grease.\",\n", - " \"A picture is worth a thousand words.\",\n", - " \"Fortune favors the bold.\",\n", - " \"Practice what you preach.\",\n", - " \"A watched pot never boils.\",\n", - " \"No pain, no gain.\",\n", - " \"You can't make an omelet without breaking eggs.\",\n", - " \"There's no place like home.\",\n", - " \"Ask and you shall receive.\",\n", - " \"Let sleeping dogs lie.\",\n", - " \"If the shoe fits, wear it.\",\n", - " \"Every cloud has a silver lining.\",\n", - " \"Look before you leap.\",\n", - " \"The more, the merrier.\",\n", - " \"The grass is always greener on the other side.\",\n", - " \"Beauty is only skin deep.\",\n", - " \"Two wrongs don't make a right.\",\n", - " \"Beauty is in the eye of the beholder.\",\n", - " \"Necessity is the mother of invention.\",\n", - " \"Out of sight, out of mind.\",\n", - " \"Patience is a virtue.\",\n", - " \"Curiosity killed the cat.\",\n", - " \"If at first you don't succeed, try, try again.\",\n", - " \"Beggars can't be choosers.\",\n", - " \"Too many cooks spoil the broth.\",\n", - " \"Easy come, easy go.\",\n", - " \"Don't cry over spilled milk.\",\n", - " \"There's no such thing as a free lunch.\",\n", - " \"A bird in the hand is worth two in the bush.\",\n", - " \"Good things come to those who wait.\",\n", - " \"The quick brown fox jumps over the lazy dog.\",\n", - " \"It takes two to tango.\",\n", - " \"A friend in need is a friend indeed.\",\n", - " \"Like father, like son.\",\n", - " \"Let bygones be bygones.\",\n", - " \"Kill two birds with one stone.\",\n", - " \"A penny for your thoughts.\",\n", - " \"I am the master of my fate, I am the captain of my soul.\",\n", - " \"The pen is mightier 
than the sword.\",\n", - " \"When in Rome, do as the Romans do.\",\n", - " \"Rome wasn't built in a day.\",\n", - " \"You can't judge a book by its cover.\",\n", - " \"It's raining cats and dogs.\",\n", - " \"Make hay while the sun shines.\",\n", - " \"It's better to be safe than sorry.\",\n", - " \"The early bird catches the worm.\",\n", - " \"To be or not to be, that is the question.\",\n", - " \"Better late than never.\"\n", - " ],\n", - " \"queries\": \"\\/future\\/u\\/okhattab\\/data\\/MSMARCO\\/queries.train.tsv\",\n", - " \"index_name\": \"Colbert-RM\",\n", - " \"overwrite\": false,\n", - " \"root\": \"\\/home\\/athekunal\\/DSPy-contributions\\/dspy\\/examples\\/integrations\\/colbert\\/experiments\",\n", - " \"experiment\": \"Colbert-Experiment\",\n", - " \"index_root\": null,\n", - " \"name\": \"2024-04\\/13\\/16.46.49\",\n", - " \"rank\": 0,\n", - " \"nranks\": 1,\n", - " \"amp\": true,\n", - " \"gpus\": 1,\n", - " \"avoid_fork_if_possible\": false\n", - "}\n", - "[Apr 13, 16:46:56] [0] \t\t # of sampled PIDs = 76 \t sampled_pids[:3] = [53, 1, 38]\n", - "[Apr 13, 16:46:56] [0] \t\t #> Encoding 76 passages..\n", - "[Apr 13, 16:46:57] [0] \t\t avg_doclen_est = 10.078947067260742 \t len(local_sample) = 76\n", - "[Apr 13, 16:46:57] [0] \t\t Creating 256 partitions.\n", - "[Apr 13, 16:46:57] [0] \t\t *Estimated* 765 embeddings.\n", - "[Apr 13, 16:46:57] [0] \t\t #> Saving the indexing plan to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM/plan.json ..\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING clustering 728 points to 256 centroids: please provide at least 9984 training points\n", - "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", - "Detected CUDA files, patching ldflags\n", - "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/decompress_residuals_cpp/build.ninja...\n", - 
"Building extension module decompress_residuals_cpp...\n", - "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", - "Loading extension module decompress_residuals_cpp...\n", - "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", - "Detected CUDA files, patching ldflags\n", - "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/packbits_cpp/build.ninja...\n", - "Building extension module packbits_cpp...\n", - "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n", - "Loading extension module packbits_cpp...\n", - "0it [00:00, ?it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Clustering 728 points in 128D to 256 clusters, redo 1 times, 20 iterations\n", - " Preprocessing in 0.00 s\n", - " Iteration 19 (0.05 s, search 0.05 s): objective=155.376 imbalance=1.400 nsplit=0 \n", - "[Apr 13, 16:46:58] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", - "ninja: no work to do.\n", - "[Apr 13, 16:46:58] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n", - "ninja: no work to do.\n", - "[0.053, 0.05, 0.061, 0.053, 0.044, 0.06, 0.047, 0.049, 0.064, 0.05, 0.051, 0.057, 0.052, 0.042, 0.061, 0.069, 0.052, 0.051, 0.063, 0.043, 0.04, 0.052, 0.043, 0.042, 0.035, 0.055, 0.07, 0.048, 0.051, 0.04, 0.057, 0.045, 0.054, 0.052, 0.042, 0.051, 0.048, 0.047, 0.056, 0.059, 0.05, 0.063, 0.054, 0.06, 0.046, 0.051, 0.04, 0.071, 0.04, 0.049, 0.056, 0.043, 0.048, 0.051, 0.045, 0.052, 0.041, 0.073, 0.039, 0.045, 0.052, 0.056, 0.053, 0.06, 0.041, 0.053, 0.054, 0.052, 0.051, 0.05, 0.061, 0.053, 0.035, 0.05, 0.049, 0.057, 0.045, 0.044, 0.05, 0.05, 0.041, 0.048, 0.043, 0.049, 0.05, 0.039, 0.056, 0.055, 0.048, 0.045, 0.044, 0.041, 0.046, 0.044, 0.046, 0.064, 0.056, 0.054, 
0.058, 0.04, 0.043, 0.045, 0.051, 0.058, 0.06, 0.043, 0.057, 0.043, 0.053, 0.056, 0.047, 0.039, 0.057, 0.044, 0.055, 0.063, 0.041, 0.047, 0.049, 0.051, 0.046, 0.042, 0.053, 0.045, 0.044, 0.053, 0.053, 0.046]\n", - "[Apr 13, 16:46:58] #> Got bucket_cutoffs_quantiles = tensor([0.5000], device='cuda:0') and bucket_weights_quantiles = tensor([0.2500, 0.7500], device='cuda:0')\n", - "[Apr 13, 16:46:58] #> Got bucket_cutoffs = tensor([0.0007], device='cuda:0') and bucket_weights = tensor([-0.0378, 0.0386], device='cuda:0')\n", - "[Apr 13, 16:46:58] avg_residual = 0.050323486328125\n", - "[Apr 13, 16:46:58] [0] \t\t #> Encoding 76 passages..\n", - "[Apr 13, 16:46:58] [0] \t\t #> Saving chunk 0: \t 76 passages and 766 embeddings. From #0 onward.\n", - "[Apr 13, 16:46:58] [0] \t\t #> Checking all files were saved...\n", - "[Apr 13, 16:46:58] [0] \t\t Found all files!\n", - "[Apr 13, 16:46:58] [0] \t\t #> Building IVF...\n", - "[Apr 13, 16:46:58] [0] \t\t #> Loading codes...\n", - "[Apr 13, 16:46:58] [0] \t\t Sorting codes...\n", - "[Apr 13, 16:46:58] [0] \t\t Getting unique codes...\n", - "[Apr 13, 16:46:58] #> Optimizing IVF to store map from centroids to list of pids..\n", - "[Apr 13, 16:46:58] #> Building the emb2pid mapping..\n", - "[Apr 13, 16:46:58] len(emb2pid) = 766\n", - "[Apr 13, 16:46:58] #> Saved optimized IVF to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM/ivf.pid.pt\n", - "[Apr 13, 16:46:58] [0] \t\t #> Saving the indexing metadata to /home/athekunal/DSPy-contributions/dspy/examples/integrations/colbert/experiments/Colbert-Experiment/indexes/Colbert-RM/metadata.json ..\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "1it [00:00, 15.27it/s]\n", - "100%|██████████| 1/1 [00:00<00:00, 3551.49it/s]\n", - "100%|██████████| 256/256 [00:00<00:00, 281526.44it/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "#> Joined...\n", - "Loading 
the index for experiment Colbert-Experiment with index name Colbert-RM\n", - "[Apr 13, 16:47:02] #> Loading codec...\n", - "[Apr 13, 16:47:02] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", - "Detected CUDA files, patching ldflags\n", - "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/decompress_residuals_cpp/build.ninja...\n", - "Building extension module decompress_residuals_cpp...\n", - "Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ninja: no work to do.\n", - "[Apr 13, 16:47:02] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading extension module decompress_residuals_cpp...\n", - "Using /home/athekunal/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...\n", - "Detected CUDA files, patching ldflags\n", - "Emitting ninja build file /home/athekunal/.cache/torch_extensions/py310_cu117/packbits_cpp/build.ninja...\n", - "Building extension module packbits_cpp...\n", - "Allowing ninja to set a default number of workers... 
(overridable by setting the environment variable MAX_JOBS=N)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "ninja: no work to do.\n", - "[Apr 13, 16:47:02] #> Loading IVF...\n", - "[Apr 13, 16:47:02] #> Loading doclens...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Loading extension module packbits_cpp...\n", - "100%|██████████| 1/1 [00:00<00:00, 4969.55it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[Apr 13, 16:47:02] #> Loading codes and residuals...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - "100%|██████████| 1/1 [00:00<00:00, 805.05it/s]\n" - ] - } - ], + "outputs": [], "source": [ "import dspy\n", "colbert_config = ColBERTConfig()\n", @@ -394,7 +106,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -406,9 +118,16 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. 
In the future this will raise an error.\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -433,19 +152,20 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", " pid=[33, 6, 47, 74, 48],\n", " passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n", ")" ] }, - "execution_count": 9, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -454,6 +174,54 @@ "pred" ] }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error.\n" + ] + } + ], + "source": [ + "multiple_pred = retrieved_docs(\n", + " [\"What is the meaning of life?\",\"Meaning of pain?\"],by_prob=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[33, 6, 47, 74, 48],\n", + " passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n", + " ),\n", + " Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[16, 0, 47, 74, 26],\n", + " passages=['Keep your friends close and your enemies closer.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'An apple a day keeps the doctor away.']\n", + " )]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "multiple_pred" + ] + }, { 
"cell_type": "markdown", "metadata": {}, @@ -463,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -475,7 +243,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -486,7 +254,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -497,25 +265,27 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", " pid=[6, 48, 74, 47, 33],\n", " rerank_score=[15.8359375, 14.2109375, 12.5703125, 11.7890625, 9.1796875],\n", " passages=['The best things in life are free.', 'Patience is a virtue.', 'To be or not to be, that is the question.', 'Out of sight, out of mind.', 'No pain, no gain.']\n", " ),\n", " Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", " pid=[33, 0, 47, 74, 16],\n", " rerank_score=[19.828125, 12.2890625, 11.171875, 9.09375, 6.8984375],\n", " passages=['No pain, no gain.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Keep your friends close and your enemies closer.']\n", " )]" ] }, - "execution_count": 13, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -533,7 +303,16 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install tabulate" + ] + }, + { + "cell_type": "code", + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -556,7 +335,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 19, "metadata": {}, "outputs": [ { From ec4b9b3036e79dc5946f62239f109f1cacec1cbd Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Mon, 29 Apr 2024 02:23:20 -0400 Subject: [PATCH 19/21] linting 
fixes --- dsp/primitives/search.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 6d9b79e51..922d142aa 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -65,7 +65,7 @@ def retrieveRerankEnsemble(queries: list[str], k: int, **kwargs) -> list[str]: passages = [] retrieved_passages = dsp.settings.rm(query, k=k * 3, **kwargs) passages_cs_scores = dsp.settings.reranker( - query, passages=[psg["long_text"] for psg in retrieved_passages] + query, passages=[psg["long_text"] for psg in retrieved_passages], ) for idx in np.argsort(passages_cs_scores)[::-1][:k]: curr_passage = retrieved_passages[idx] @@ -107,7 +107,7 @@ def retrieveRerankEnsemble(queries: list[str], k: int, **kwargs) -> list[str]: # return passages def retrieveEnsemble( - queries: list[str], k: int, by_prob: bool = True, **kwargs + queries: list[str], k: int, by_prob: bool = True, **kwargs, ) -> list[str]: """Retrieves passages from the RM for each query in queries and returns the top k passages based on the probability or score. @@ -117,7 +117,7 @@ def retrieveEnsemble( raise AssertionError("No RM is loaded.") if not dsp.settings.reranker: logger.warn( - "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error." + "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. 
In the future this will raise an error.", ) queries = [q for q in queries if q] From 326ce0172dcfe219acf6127b085232cde43349ad Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Mon, 29 Apr 2024 02:26:35 -0400 Subject: [PATCH 20/21] more linting fixes --- dspy/retrieve/retrieve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 0dceff8de..c1395a9c1 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -5,6 +5,7 @@ from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction + def single_query_passage(passages): passages_dict = {key:[] for key in list(passages[0].keys())} for docs in passages: From b5913fc6677e8c213172b704ef94b6383ff87e6a Mon Sep 17 00:00:00 2001 From: Athe-kunal Date: Fri, 7 Jun 2024 23:31:37 -0400 Subject: [PATCH 21/21] fixing previous cache breaks with separate funcs --- dsp/primitives/search.py | 126 +++++++++++++++++++------------------- dspy/retrieve/retrieve.py | 120 +++++++++++++++++++++++------------- 2 files changed, 142 insertions(+), 104 deletions(-) diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index 922d142aa..8f2b215b9 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -7,26 +7,26 @@ logger = logging.getLogger(__name__) -# def retrieve(query: str, k: int, **kwargs) -> list[str]: -# """Retrieves passages from the RM for the query and returns the top k passages.""" -# if not dsp.settings.rm: -# raise AssertionError("No RM is loaded.") -# passages = dsp.settings.rm(query, k=k, **kwargs) -# if not isinstance(passages, Iterable): -# # it's not an iterable yet; make it one. 
-# # TODO: we should unify the type signatures of dspy.Retriever -# passages = [passages] -# passages = [psg.long_text for psg in passages] - -# if dsp.settings.reranker: -# passages_cs_scores = dsp.settings.reranker(query, passages) -# passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] -# passages = [passages[idx] for idx in passages_cs_scores_sorted] - - -# return passages def retrieve(query: str, k: int, **kwargs) -> list[str]: """Retrieves passages from the RM for the query and returns the top k passages.""" + if not dsp.settings.rm: + raise AssertionError("No RM is loaded.") + passages = dsp.settings.rm(query, k=k, **kwargs) + if not isinstance(passages, Iterable): + # it's not an iterable yet; make it one. + # TODO: we should unify the type signatures of dspy.Retriever + passages = [passages] + passages = [psg.long_text for psg in passages] + + if dsp.settings.reranker: + passages_cs_scores = dsp.settings.reranker(query, passages) + passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] + passages = [passages[idx] for idx in passages_cs_scores_sorted] + + + return passages +def retrievewithMetadata(query: str, k: int, **kwargs) -> list[str]: + """Retrieves passages from the RM for the query and returns the top k passages.""" if not dsp.settings.rm: raise AssertionError("No RM is loaded.") @@ -39,24 +39,25 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: return passages -# def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: -# if not (dsp.settings.rm and dsp.settings.reranker): -# raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") -# queries = [q for q in queries if q] -# passages = {} -# for query in queries: -# retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) -# passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages]) -# for idx in np.argsort(passages_cs_scores)[::-1]: -# psg = retrieved_passages[idx] -# 
passages[psg.long_text] = passages.get(psg.long_text, []) + [ -# passages_cs_scores[idx], -# ] - - -# passages = [(np.average(score), text) for text, score in passages.items()] -# return [text for _, text in sorted(passages, reverse=True)[:k]] -def retrieveRerankEnsemble(queries: list[str], k: int, **kwargs) -> list[str]: +def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: + if not (dsp.settings.rm and dsp.settings.reranker): + raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") + queries = [q for q in queries if q] + passages = {} + for query in queries: + retrieved_passages = dsp.settings.rm(query, k=k*3,**kwargs) + passages_cs_scores = dsp.settings.reranker(query, [psg.long_text for psg in retrieved_passages]) + for idx in np.argsort(passages_cs_scores)[::-1]: + psg = retrieved_passages[idx] + passages[psg.long_text] = passages.get(psg.long_text, []) + [ + passages_cs_scores[idx], + ] + + + passages = [(np.average(score), text) for text, score in passages.items()] + return [text for _, text in sorted(passages, reverse=True)[:k]] + +def retrieveRerankEnsemblewithMetadata(queries: list[str], k: int, **kwargs) -> list[str]: if not (dsp.settings.rm and dsp.settings.reranker): raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") queries = [q for q in queries if q] @@ -78,35 +79,36 @@ def retrieveRerankEnsemble(queries: list[str], k: int, **kwargs) -> list[str]: return all_queries_passages -# def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: -# """Retrieves passages from the RM for each query in queries and returns the top k passages -# based on the probability or score. 
-# """ -# if not dsp.settings.rm: -# raise AssertionError("No RM is loaded.") -# if dsp.settings.reranker: -# return retrieveRerankEnsemble(queries, k) +def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: + """Retrieves passages from the RM for each query in queries and returns the top k passages + based on the probability or score. + """ + if not dsp.settings.rm: + raise AssertionError("No RM is loaded.") + if dsp.settings.reranker: + return retrieveRerankEnsemble(queries, k) + + queries = [q for q in queries if q] -# queries = [q for q in queries if q] + if len(queries) == 1: + return retrieve(queries[0], k, **kwargs) -# if len(queries) == 1: -# return retrieve(queries[0], k, **kwargs) + passages = {} + for q in queries: + for psg in dsp.settings.rm(q, k=k * 3,**kwargs): + if by_prob: + passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.prob + else: + passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.score -# passages = {} -# for q in queries: -# for psg in dsp.settings.rm(q, k=k * 3,**kwargs): -# if by_prob: -# passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.prob -# else: -# passages[psg.long_text] = passages.get(psg.long_text, 0.0) + psg.score + passages = [(score, text) for text, score in passages.items()] + passages = sorted(passages, reverse=True)[:k] + passages = [text for _, text in passages] -# passages = [(score, text) for text, score in passages.items()] -# passages = sorted(passages, reverse=True)[:k] -# passages = [text for _, text in passages] + return passages -# return passages -def retrieveEnsemble( +def retrieveEnsemblewithMetadata( queries: list[str], k: int, by_prob: bool = True, **kwargs, ) -> list[str]: """Retrieves passages from the RM for each query in queries and returns the top k passages @@ -116,9 +118,7 @@ def retrieveEnsemble( if not dsp.settings.rm: raise AssertionError("No RM is loaded.") if not dsp.settings.reranker: - logger.warn( - 
"DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error.", - ) + return retrieveRerankEnsemblewithMetadata(queries=queries,k=k) queries = [q for q in queries if q] @@ -144,4 +144,4 @@ def retrieveEnsemble( rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices ] all_queries_passages.append(passages) - return all_queries_passages + return all_queries_passages \ No newline at end of file diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index c1395a9c1..6c50e2bbf 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -7,14 +7,15 @@ def single_query_passage(passages): - passages_dict = {key:[] for key in list(passages[0].keys())} + passages_dict = {key: [] for key in list(passages[0].keys())} for docs in passages: - for key,value in docs.items(): + for key, value in docs.items(): passages_dict[key].append(value) if "long_text" in passages_dict: passages_dict["passages"] = passages_dict.pop("long_text") return Prediction(**passages_dict) + class Retrieve(Parameter): name = "Search" input_variable = "query" @@ -38,7 +39,14 @@ def load_state(self, state): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: + def forward( + self, + query_or_queries: Union[str, List[str]], + k: Optional[int] = None, + by_prob: bool = True, + with_metadata: bool = False, + **kwargs, + ) -> Union[List[str], Prediction, List[Prediction]]: # queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries # queries = [query.strip().split('\n')[0].strip() for query in queries] @@ -47,31 +55,48 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No # k = k if k is not None else self.k # passages = 
dsp.retrieveEnsemble(queries, k=k,**kwargs) # return Prediction(passages=passages) - queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries - queries = [query.strip().split('\n')[0].strip() for query in queries] + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + queries = [query.strip().split("\n")[0].strip() for query in queries] # print(queries) # TODO: Consider removing any quote-like markers that surround the query too. k = k if k is not None else self.k - passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) - if isinstance(passages[0],List): - pred_returns = [] - for query_passages in passages: - passages_dict = {key:[] for key in list(query_passages[0].keys()) if key!="tracking_idx"} - for psg in query_passages: - for key,value in psg.items(): - if key == "tracking_idx": continue - passages_dict[key].append(value) - if "long_text" in passages_dict: - passages_dict["passages"] = passages_dict.pop("long_text") - pred_returns.append(Prediction(**passages_dict)) - return pred_returns - elif isinstance(passages[0], Dict): - #passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} - return single_query_passage(passages=passages) - + if not with_metadata: + passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs) + return Prediction(passages=passages) + else: + passages = dsp.retrieveEnsemblewithMetadata( + queries, k=k, by_prob=by_prob, **kwargs + ) + if isinstance(passages[0], List): + pred_returns = [] + for query_passages in passages: + passages_dict = { + key: [] + for key in list(query_passages[0].keys()) + if key != "tracking_idx" + } + for psg in query_passages: + for key, value in psg.items(): + if key == "tracking_idx": + continue + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + pred_returns.append(Prediction(**passages_dict)) + return 
pred_returns + elif isinstance(passages[0], Dict): + # passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} + return single_query_passage(passages=passages) + + # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. + class RetrieveThenRerank(Parameter): name = "Search" input_variable = "query" @@ -95,26 +120,39 @@ def load_state(self, state): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Union[Prediction,List[Prediction]]: - queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries - queries = [query.strip().split('\n')[0].strip() for query in queries] + def forward( + self, + query_or_queries: Union[str, List[str]], + k: Optional[int] = None, + with_metadata: bool = False, + **kwargs, + ) -> Union[List[str], Prediction, List[Prediction]]: + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + queries = [query.strip().split("\n")[0].strip() for query in queries] # print(queries) # TODO: Consider removing any quote-like markers that surround the query too. 
k = k if k is not None else self.k - passages = dsp.retrieveRerankEnsemble(queries, k=k,**kwargs) - if isinstance(passages[0],List): - pred_returns = [] - for query_passages in passages: - passages_dict = {key:[] for key in list(query_passages[0].keys())} - for docs in query_passages: - for key,value in docs.items(): - passages_dict[key].append(value) - if "long_text" in passages_dict: - passages_dict["passages"] = passages_dict.pop("long_text") - - pred_returns.append(Prediction(**passages_dict)) - return pred_returns - elif isinstance(passages[0], Dict): - return single_query_passage(passages=passages) - \ No newline at end of file + if not with_metadata: + passages = dsp.retrieveRerankEnsemble(queries, k=k, **kwargs) + return passages + else: + passages = dsp.retrieveRerankEnsemblewithMetadata(queries, k=k, **kwargs) + if isinstance(passages[0], List): + pred_returns = [] + for query_passages in passages: + passages_dict = {key: [] for key in list(query_passages[0].keys())} + for docs in query_passages: + for key, value in docs.items(): + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + + pred_returns.append(Prediction(**passages_dict)) + return pred_returns + elif isinstance(passages[0], Dict): + return single_query_passage(passages=passages)