Refactor communication between Pipeline Components #1321

Merged Sep 10, 2021 (55 commits)
Changes from 1 commit
Commits
324e615  Add POC for extractive-qa pipeline (oryx1729, Aug 5, 2021)
09eabbc  Remove kwargs from run (oryx1729, Aug 16, 2021)
b431e3d  Remove kwargs from reader run (oryx1729, Aug 16, 2021)
73739e4  Add handling of debug information from nodes (oryx1729, Aug 16, 2021)
97d6ab1  Fix type hints (oryx1729, Aug 16, 2021)
4de5f13  Fix standard pipelines (oryx1729, Aug 16, 2021)
0b1e8e1  Handle null params (oryx1729, Aug 16, 2021)
8d0af9f  Refactor run() for all components (oryx1729, Aug 16, 2021)
9e078eb  Fix type hint (oryx1729, Aug 16, 2021)
9406cda  Fix EvalDocuments (oryx1729, Aug 16, 2021)
525486e  Fix typing (oryx1729, Aug 17, 2021)
4466c54  Fix EvalAnswers (oryx1729, Aug 17, 2021)
7108c5f  Fix Summarizer test (oryx1729, Aug 17, 2021)
79cbb2e  Fix EvalAnswers (oryx1729, Aug 17, 2021)
a4ab9b4  Fix test (oryx1729, Aug 17, 2021)
de1cd30  Fix Ray test (oryx1729, Aug 17, 2021)
50f90be  Fix QueryClassifier (oryx1729, Aug 17, 2021)
76d82ed  Fix TransformersQueryClassifier (oryx1729, Aug 17, 2021)
bbb7886  Fix SklearnQueryClassifier (oryx1729, Aug 17, 2021)
d6ab5c1  Add support for more types as Pipeline inputs (oryx1729, Aug 18, 2021)
e2b750a  Fix eval test (oryx1729, Aug 18, 2021)
2d7ddbf  Fix RayPipeline (oryx1729, Aug 18, 2021)
94cdd59  Fix QuestionGenerator (oryx1729, Aug 18, 2021)
af03490  Cast Reader.run_batch() results to dict (oryx1729, Aug 18, 2021)
e8a9eb4  Fix Retriever.run() (oryx1729, Aug 18, 2021)
2bb3603  Fix translator (oryx1729, Aug 18, 2021)
d2a3763  Fix JoinDocuments (oryx1729, Aug 18, 2021)
0a9855c  Allows dicts as run params (oryx1729, Aug 18, 2021)
6658042  Fix typing (oryx1729, Aug 18, 2021)
5014db5  Update Pipeline tests (oryx1729, Aug 18, 2021)
b71c791  Adjust Pipeline tests (oryx1729, Aug 18, 2021)
b0b045d  Refactor REST APIs (oryx1729, Aug 19, 2021)
9d05f97  Subclass dict for primitives (oryx1729, Aug 19, 2021)
032b977  Revert dict cast for primitives (oryx1729, Aug 19, 2021)
28ae659  Update tests for rest_api (oryx1729, Aug 19, 2021)
9498fea  Fix pipeline test (oryx1729, Aug 19, 2021)
9fa4252  Fix Eval (oryx1729, Aug 19, 2021)
a960bf7  Add tests for invalid input to Pipelines (oryx1729, Aug 19, 2021)
d673516  Add docstring for _dispatch_run() (oryx1729, Aug 20, 2021)
266ba37  Adapt UI query endpoint (oryx1729, Aug 20, 2021)
6eeb4ac  Update tutorials (oryx1729, Aug 20, 2021)
fb82bb3  Fix filters dict access in query API (oryx1729, Aug 20, 2021)
b383000  Update tutorial (oryx1729, Sep 2, 2021)
45c1d55  Add type hints for run() in eval.py (oryx1729, Sep 2, 2021)
cfbcc05  Fix docstring (oryx1729, Sep 2, 2021)
7bd549c  Add explicit args for run() in BaseComponent (oryx1729, Sep 2, 2021)
8cf409a  Update docstrings for standard pipelines (oryx1729, Sep 2, 2021)
1c28554  Add missing import (oryx1729, Sep 2, 2021)
ef46880  Add test for _debug (oryx1729, Sep 2, 2021)
f97acb4  Update example in README (oryx1729, Sep 2, 2021)
c8ff595  Update Pipelines README (oryx1729, Sep 9, 2021)
d2c9755  Update Pipeline Tutorial (oryx1729, Sep 9, 2021)
ca1c214  Remove kwargs from crawler run() (oryx1729, Sep 9, 2021)
49f5d8a  Remove kwargs from FileTypeClassifier run() (oryx1729, Sep 9, 2021)
32b6aae  Fix QueryClassifier in tutorial (oryx1729, Sep 9, 2021)
Refactor run() for all components
oryx1729 committed Sep 2, 2021
commit 8d0af9f203b364bac76312a9bb48fd3bff7d5db5
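Taken together, the diffs in this commit apply one pattern: a component's run() no longer accepts a catch-all **kwargs that it has to echo back into its output; it declares explicit arguments and returns only what it produces, plus the name of its outgoing edge. A minimal before/after sketch of that contract (MyNode and its placeholder logic are illustrative, not part of this PR; the outgoing_edges attribute follows haystack's BaseComponent convention):

    from typing import List, Optional
    from haystack import BaseComponent, Document

    class MyNode(BaseComponent):
        outgoing_edges = 1

        # Before: nodes accepted **kwargs and re-emitted it so that unrelated
        # parameters could flow through the pipeline graph:
        #     def run(self, query: str, documents: List[Document], **kwargs):
        #         results = ...  # component-specific work
        #         return {"query": query, "documents": results, **kwargs}, "output_1"

        # After: explicit arguments in, own output only; the Pipeline does the routing.
        def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
            results = documents[:top_k] if top_k else documents  # placeholder logic
            return {"documents": results}, "output_1"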
8 changes: 2 additions & 6 deletions haystack/classifier/base.py
@@ -26,7 +26,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
     def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         pass
 
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, **kwargs):  # type: ignore
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):  # type: ignore
         self.query_count += 1
         if documents:
             predict = self.timing(self.predict, "query_time")
@@ -36,11 +36,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
 
         document_ids = [doc.id for doc in results]
         logger.debug(f"Retrieved documents with IDs: {document_ids}")
-        output = {
-            "query": query,
-            "documents": results,
-            **kwargs
-        }
+        output = {"documents": results}
 
         return output, "output_1"
7 changes: 3 additions & 4 deletions haystack/classifier/farm.py
@@ -31,13 +31,12 @@ class FARMClassifier(BaseClassifier):
     retriever = ElasticsearchRetriever(document_store=document_store)
     classifier = FARMClassifier(model_name_or_path="deepset/bert-base-german-cased-sentiment-Germeval17")
     p = Pipeline()
-    p.add_node(component=retriever, name="ESRetriever", inputs=["Query"])
-    p.add_node(component=classifier, name="Classifier", inputs=["ESRetriever"])
+    p.add_node(component=retriever, name="Retriever", inputs=["Query"])
+    p.add_node(component=classifier, name="Classifier", inputs=["Retriever"])
 
     res = p_extractive.run(
         query="Who is the father of Arya Stark?",
-        top_k_retriever=10,
-        top_k_reader=5
+        params={"Retriever": {"top_k": 10}, "Classifier": {"top_k": 5}}
     )
 
     print(res["documents"][0].to_dict()["meta"]["classification"]["label"])
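The updated docstring shows the new calling convention: instead of flat arguments such as top_k_retriever, parameters are grouped under the name each node was given in add_node(). A short sketch of how such a call maps onto the graph (values mirror the docstring; treating top-level non-node keys like "filters" as shared parameters is an assumption based on the removed ExtractiveQAPipeline wrapper further down):

    res = p.run(
        query="Who is the father of Arya Stark?",
        params={
            "Retriever": {"top_k": 10},   # routed to the node added with name="Retriever"
            "Classifier": {"top_k": 5},   # routed to the node added with name="Classifier"
        },
    )
    print(res["documents"][0].to_dict()["meta"]["classification"]["label"])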
4 changes: 2 additions & 2 deletions haystack/document_store/base.py
@@ -286,9 +286,9 @@ def delete_all_documents(self, index: Optional[str] = None, filters: Optional[Di
     def delete_documents(self, index: Optional[str] = None, filters: Optional[Dict[str, List[str]]] = None):
         pass
 
-    def run(self, documents: List[dict], index: Optional[str] = None, **kwargs):  # type: ignore
+    def run(self, documents: List[dict], index: Optional[str] = None):  # type: ignore
         self.write_documents(documents=documents, index=index)
-        return kwargs, "output_1"
+        return {}, "output_1"
 
     @abstractmethod
     def get_documents_by_id(self, ids: List[str], index: Optional[str] = None,
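With the kwargs echo gone, a DocumentStore used as the terminal node of an indexing pipeline writes the incoming documents and returns an empty dict rather than reflecting its inputs. A minimal sketch (the concrete store instance, e.g. an InMemoryDocumentStore, and the index name are illustrative):

    output, edge = document_store.run(documents=[{"text": "some text"}], index="document")
    assert output == {}
    assert edge == "output_1"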
26 changes: 13 additions & 13 deletions haystack/eval.py
@@ -22,7 +22,7 @@ class EvalDocuments:
     a look at our evaluation tutorial for more info about open vs closed domain eval (
     https://haystack.deepset.ai/tutorials/evaluation).
     """
-    def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documents: int=10, name="EvalDocuments"):
+    def __init__(self, debug: bool=False, open_domain: bool=True, top_k: int=10, name="EvalDocuments"):
         """
         :param open_domain: When True, a document is considered correctly retrieved so long as the answer string can be found within it.
                             When False, correct retrieval is evaluated based on document_id.
@@ -35,7 +35,7 @@ def __init__(self, debug: bool=False, open_domain: bool=True, top_k_eval_documen
         self.debug = debug
         self.log: List = []
         self.open_domain = open_domain
-        self.top_k_eval_documents = top_k_eval_documents
+        self.top_k = top_k
         self.name = name
         self.too_few_docs_warning = False
         self.top_k_used = 0
@@ -53,25 +53,25 @@ def init_counts(self):
         self.reciprocal_rank_sum = 0.0
         self.has_answer_reciprocal_rank_sum = 0.0
 
-    def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None, **kwargs):
+    def run(self, documents, labels: dict, top_k: Optional[int]=None, **kwargs):
         """Run this node on one sample and its labels"""
         self.query_count += 1
         retriever_labels = get_label(labels, kwargs["node_id"])
-        if not top_k_eval_documents:
-            top_k_eval_documents = self.top_k_eval_documents
+        if not top_k:
+            top_k = self.top_k
 
         if not self.top_k_used:
-            self.top_k_used = top_k_eval_documents
-        elif self.top_k_used != top_k_eval_documents:
+            self.top_k_used = top_k
+        elif self.top_k_used != top_k:
             logger.warning(f"EvalDocuments was last run with top_k_eval_documents={self.top_k_used} but is "
-                           f"being run again with top_k_eval_documents={self.top_k_eval_documents}. "
+                           f"being run again with top_k={self.top_k}. "
                            f"The evaluation counter is being reset from this point so that the evaluation "
                            f"metrics are interpretable.")
             self.init_counts()
 
-        if len(documents) < top_k_eval_documents and not self.too_few_docs_warning:
-            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k_eval_documents "
-                           f"(currently set to {top_k_eval_documents}).")
+        if len(documents) < top_k and not self.too_few_docs_warning:
+            logger.warning(f"EvalDocuments is being provided less candidate documents than top_k "
+                           f"(currently set to {top_k}).")
             self.too_few_docs_warning = True
 
         # TODO retriever_labels is currently a Multilabel object but should eventually be a RetrieverLabel object
@@ -89,7 +89,7 @@ def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None,
         # If there are answer span annotations in the labels
         else:
             self.has_answer_count += 1
-            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k_eval_documents)
+            retrieved_reciprocal_rank = self.reciprocal_rank_retrieved(retriever_labels, documents, top_k)
             self.reciprocal_rank_sum += retrieved_reciprocal_rank
             correct_retrieval = True if retrieved_reciprocal_rank > 0 else False
             self.has_answer_correct += int(correct_retrieval)
@@ -101,7 +101,7 @@ def run(self, documents, labels: dict, top_k_eval_documents: Optional[int]=None,
         self.recall = self.correct_retrieval_count / self.query_count
         self.mean_reciprocal_rank = self.reciprocal_rank_sum / self.query_count
 
-        self.top_k_used = top_k_eval_documents
+        self.top_k_used = top_k
 
         if self.debug:
             self.log.append({"documents": documents, "labels": labels, "correct_retrieval": correct_retrieval, "retrieved_reciprocal_rank": retrieved_reciprocal_rank, **kwargs})
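The eval.py changes are essentially a rename: EvalDocuments' top_k_eval_documents becomes top_k in both the constructor and run(), with the warning and counter-reset logic carried over. Under the new params convention, a per-run override would presumably be addressed by node name, as in this sketch (the pipeline wiring is illustrative; "EvalDocuments" is the class's default name argument):

    eval_docs = EvalDocuments(top_k=10, open_domain=True)
    # ... added to an evaluation pipeline, e.g.:
    # pipeline.add_node(component=eval_docs, name="EvalDocuments", inputs=["Retriever"])

    # Override the configured top_k for a single run:
    output = pipeline.run(
        query="Who is the father of Arya Stark?",
        params={"EvalDocuments": {"top_k": 5}},
    )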
4 changes: 2 additions & 2 deletions haystack/file_converter/base.py
@@ -91,7 +91,7 @@ def validate_language(self, text: str) -> bool:
     def run(self, file_paths: Union[Path, List[Path]],  # type: ignore
             meta: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None,  # type: ignore
             remove_numeric_tables: Optional[bool] = None,  # type: ignore
-            valid_languages: Optional[List[str]] = None, **kwargs):  # type: ignore
+            valid_languages: Optional[List[str]] = None):  # type: ignore
 
         if isinstance(file_paths, Path):
             file_paths = [file_paths]
@@ -110,7 +110,7 @@ def run(self, file_paths: Union[Path, List[Path]],  # type: ignore
                 )
             )
 
-        result = {"documents": documents, **kwargs}
+        result = {"documents": documents}
         return result, "output_1"
5 changes: 2 additions & 3 deletions haystack/generator/base.py
@@ -23,12 +23,11 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -
         """
         pass
 
-    def run(self, query: str, documents: List[Document], top_k_generator: Optional[int] = None, **kwargs):  # type: ignore
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):  # type: ignore
 
         if documents:
-            results = self.predict(query=query, documents=documents, top_k=top_k_generator)
+            results = self.predict(query=query, documents=documents, top_k=top_k)
         else:
            results = {"answers": []}
 
-        results.update(**kwargs)
         return results, "output_1"
6 changes: 2 additions & 4 deletions haystack/graph_retriever/base.py
@@ -15,9 +15,7 @@ def retrieve(self, query: str, top_k: int):
     def eval(self):
         raise NotImplementedError
 
-    def run(self, query: str, top_k: int, **kwargs):  # type: ignore
+    def run(self, query: str, top_k: int):  # type: ignore
         answers = self.retrieve(query=query, top_k=top_k)
-        results = {"query": query,
-                   "answers": answers,
-                   **kwargs}
+        results = {"answers": answers}
         return results, "output_1"
2 changes: 1 addition & 1 deletion haystack/knowledge_graph/base.py
@@ -8,7 +8,7 @@ class BaseKnowledgeGraph(BaseComponent):
 
     def run(self, sparql_query: str, index: Optional[str] = None, **kwargs):  # type: ignore
         result = self.query(sparql_query=sparql_query, index=index)
-        output = {"sparql_result": result, **kwargs}
+        output = {"sparql_result": result}
         return output, "output_1"
 
     def query(self, sparql_query: str, index: Optional[str] = None):
71 changes: 16 additions & 55 deletions haystack/pipeline.py
@@ -254,15 +254,15 @@ def set_node(self, name: str, component):
         """
         self.graph.nodes[name]["component"] = component
 
-    def run(self, query: Optional[str] = None, file: Optional[str] = None, params: Optional[dict] = None):  # type: ignore
+    def run(self, query: Optional[str] = None, file_paths: Optional[List[str]] = None, params: Optional[dict] = None):  # type: ignore
         node_output = None
         queue = {
             self.root_node: {"root_node": self.root_node, "params": params}
         }  # ordered dict with "node_id" -> "input" mapping that acts as a FIFO queue
         if query:
             queue[self.root_node]["query"] = query
-        if file:
-            queue[self.root_node]["file"] = file
+        if file_paths:
+            queue[self.root_node]["file_paths"] = file_paths
         i = 0  # the first item is popped off the queue unless it is a "join" node with unprocessed predecessors
         while queue:
             node_id = list(queue.keys())[i]
@@ -505,6 +505,10 @@ def draw(self, path: Path = Path("pipeline.png")):
         """
         self.pipeline.draw(path)
 
+    def run(self, query: str, params: Optional[dict] = None):
+        output = self.pipeline.run(query=query, params=params)
+        return output
+
 
 class ExtractiveQAPipeline(BaseStandardPipeline):
     def __init__(self, reader: BaseReader, retriever: BaseRetriever):
@@ -518,15 +522,6 @@ def __init__(self, reader: BaseReader, retriever: BaseRetriever):
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
         self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])
 
-    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: int = 10, top_k_reader: int = 10):
-        params = {
-            "filters": filters,
-            "Retriever": {"top_k": top_k_retriever},
-            "Reader": {"top_k": top_k_reader},
-        }
-        output = self.pipeline.run(query=query, params=params)
-        return output
-
 
 class DocumentSearchPipeline(BaseStandardPipeline):
     def __init__(self, retriever: BaseRetriever):
@@ -538,8 +533,7 @@ def __init__(self, retriever: BaseRetriever):
         self.pipeline = Pipeline()
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
 
-    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: Optional[int] = None):
-        params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}}
+    def run(self, query: str, params: Optional[dict] = None):
         output = self.pipeline.run(query=query, params=params)
         document_dicts = [doc.to_dict() for doc in output["documents"]]
         output["documents"] = document_dicts
@@ -558,60 +552,28 @@ def __init__(self, generator: BaseGenerator, retriever: BaseRetriever):
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
         self.pipeline.add_node(component=generator, name="Generator", inputs=["Retriever"])
 
-    def run(
-        self,
-        query: str,
-        filters: Optional[Dict] = None,
-        top_k_retriever: Optional[int] = None,
-        top_k_generator: Optional[int] = None
-    ):
-        params = {
-            "filters": filters,
-            "Retriever": {"top_k": top_k_retriever},
-            "Generator": {"top_k": top_k_generator},
-        }
-        output = self.pipeline.run(query=query, params=params)
-        return output
-
 
 class SearchSummarizationPipeline(BaseStandardPipeline):
-    def __init__(self, summarizer: BaseSummarizer, retriever: BaseRetriever):
+    def __init__(self, summarizer: BaseSummarizer, retriever: BaseRetriever, return_in_answer_format: bool = False):
         """
         Initialize a Pipeline that retrieves documents for a query and then summarizes those documents.
 
         :param summarizer: Summarizer instance
         :param retriever: Retriever instance
+        :param return_in_answer_format: Whether the results should be returned as documents (False) or in the answer
+                                        format used in other QA pipelines (True). With the latter, you can use this
+                                        pipeline as a "drop-in replacement" for other QA pipelines.
         """
         self.pipeline = Pipeline()
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
         self.pipeline.add_node(component=summarizer, name="Summarizer", inputs=["Retriever"])
+        self.return_in_answer_format = return_in_answer_format
 
-    def run(
-        self,
-        query: str,
-        filters: Optional[Dict] = None,
-        top_k_retriever: Optional[int] = None,
-        generate_single_summary: Optional[bool] = None,
-        return_in_answer_format: bool = False,
-    ):
-        """
-        :param query: Your search query
-        :param filters:
-        :param top_k_retriever: Number of top docs the retriever should pass to the summarizer.
-                                The higher this value, the slower your pipeline.
-        :param generate_single_summary: Whether to generate single summary from all retrieved docs (True) or one per doc (False).
-        :param return_in_answer_format: Whether the results should be returned as documents (False) or in the answer format used in other QA pipelines (True).
-                                        With the latter, you can use this pipeline as a "drop-in replacement" for other QA pipelines.
-        """
-        params = {
-            "filters": filters,
-            "Retriever": {"top_k": top_k_retriever},
-            "Summarizer": {"generate_single_summary": generate_single_summary},
-        }
+    def run(self, query: str, params: Optional[dict] = None):
        output = self.pipeline.run(query=query, params=params)
 
         # Convert to answer format to allow "drop-in replacement" for other QA pipelines
-        if return_in_answer_format:
+        if self.return_in_answer_format:
             results: Dict = {"query": query, "answers": []}

[Inline review comment from a project member on the line above]: We could probably also use the new Doc2Answer node instead here. But I think it's not really the scope of this PR

             docs = deepcopy(output["documents"])
             for doc in docs:
@@ -642,8 +604,7 @@ def __init__(self, retriever: BaseRetriever):
         self.pipeline = Pipeline()
         self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
 
-    def run(self, query: str, filters: Optional[Dict] = None, top_k_retriever: Optional[int] = None):
-        params = {"filters": filters, "Retriever": {"top_k": top_k_retriever}}
+    def run(self, query: str, params: Optional[dict] = None):
         output = self.pipeline.run(query=query, params=params)
         documents = output["documents"]
 
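Net effect of the pipeline.py changes: Pipeline.run() gains file_paths (replacing file) so indexing pipelines share the same entry point, the standard pipelines inherit a single run(query, params) from BaseStandardPipeline instead of each re-wrapping flat keyword arguments, and SearchSummarizationPipeline's return_in_answer_format moves from a per-call flag to a constructor flag. A hedged usage sketch (the retriever, summarizer, and indexing_pipeline instances are assumed to exist):

    # Query pipelines: one shared signature, parameters addressed per node.
    pipe = DocumentSearchPipeline(retriever=retriever)
    res = pipe.run(query="Who is the father of Arya Stark?",
                   params={"Retriever": {"top_k": 10}})

    # The output format is now fixed at construction time rather than per call:
    summary_pipe = SearchSummarizationPipeline(summarizer=summarizer,
                                               retriever=retriever,
                                               return_in_answer_format=True)

    # Indexing-style runs go through the renamed argument (illustrative path):
    # indexing_pipeline.run(file_paths=["data/sample.pdf"])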
3 changes: 1 addition & 2 deletions haystack/preprocessor/base.py
@@ -47,7 +47,6 @@ def run(  # type: ignore
             split_length: Optional[int] = None,
             split_overlap: Optional[int] = None,
             split_respect_sentence_boundary: Optional[bool] = None,
-            **kwargs,
     ):
         documents = self.process(
             documents=documents,
@@ -59,5 +58,5 @@ def run(  # type: ignore
             split_overlap=split_overlap,
             split_respect_sentence_boundary=split_respect_sentence_boundary,
         )
-        result = {"documents": documents, **kwargs}
+        result = {"documents": documents}
         return result, "output_1"
8 changes: 4 additions & 4 deletions haystack/question_generator/question_generator.py
@@ -1,7 +1,8 @@
 from transformers import AutoModelForSeq2SeqLM
 from transformers import AutoTokenizer
-from haystack import BaseComponent
+from haystack import BaseComponent, Document
 from haystack.preprocessor import PreProcessor
+from typing import List
 
 
 class QuestionGenerator(BaseComponent):
@@ -50,16 +51,15 @@ def __init__(self,
         self.preprocessor = PreProcessor()
         self.prompt = prompt
 
-    def run(self, **kwargs):
-        documents = kwargs["documents"]
+    def run(self, documents: List[Document]):
         generated_questions = []
         for d in documents:
             questions = self.generate(d.text)
             curr_dict = {"document_id": d.id,
                          "document_sample": d.text[:200],
                          "questions": questions}
             generated_questions.append(curr_dict)
-        output = {"generated_questions": generated_questions, **kwargs}
+        output = {"generated_questions": generated_questions}
         return output, "output_1"
 
     def generate(self, text):
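QuestionGenerator.run() now takes a typed documents list instead of digging it out of **kwargs, which is why Document and List join the imports at the top of the file. A sketch of calling the node directly (constructor defaults are assumed; the output keys follow curr_dict above):

    from haystack import Document

    qg = QuestionGenerator()  # assuming the class's default model settings
    docs = [Document(text="Arya Stark is a daughter of Eddard Stark.")]
    output, edge = qg.run(documents=docs)
    # edge == "output_1"; each entry has "document_id", "document_sample", "questions"
    print(output["generated_questions"][0]["questions"])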
10 changes: 3 additions & 7 deletions haystack/ranker/base.py
@@ -26,21 +26,17 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
     def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
         pass
 
-    def run(self, query: str, documents: List[Document], top_k_ranker: Optional[int] = None, **kwargs):  # type: ignore
+    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):  # type: ignore
         self.query_count += 1
         if documents:
             predict = self.timing(self.predict, "query_time")
-            results = predict(query=query, documents=documents, top_k=top_k_ranker)
+            results = predict(query=query, documents=documents, top_k=top_k)
         else:
             results = []
 
         document_ids = [doc.id for doc in results]
         logger.debug(f"Retrieved documents with IDs: {document_ids}")
-        output = {
-            "query": query,
-            "documents": results,
-            **kwargs
-        }
+        output = {"documents": results}
 
         return output, "output_1"
4 changes: 2 additions & 2 deletions haystack/reader/base.py
@@ -67,7 +67,7 @@ def run(self, query: str, documents: List[Document], top_k: Optional[int] = None
 
         return results, "output_1"
 
-    def run_batch(self, query_doc_list: List[Dict], top_k_reader: Optional[int] = None):
+    def run_batch(self, query_doc_list: List[Dict], top_k: Optional[int] = None):
         """ A unoptimized implementation of running Reader queries in batch """
         self.query_count += len(query_doc_list)
         results = []
@@ -76,7 +76,7 @@ def run_batch(self, query_doc_list: List[Dict], top_k: Optional[int] = No
             q = qd["queries"]
             docs = qd["docs"]
             predict = self.timing(self.predict, "query_time")
-            result = predict(query=q, documents=docs, top_k=top_k_reader)
+            result = predict(query=q, documents=docs, top_k=top_k)
             results.append(result)
         else:
             results = [{"answers": [], "query": ""}]