Add RCIReader for TableQA (deepset-ai#1909)
* Add RCIReader

* Add latest docstring and tutorial changes

* Add Doc Strings

* Add latest docstring and tutorial changes

* Add Tests

* Add Doc Strings

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
bogdankostic and github-actions[bot] committed Jan 3, 2022
1 parent 6e8e3c6 commit 45df18c
Showing 6 changed files with 302 additions and 33 deletions.
72 changes: 72 additions & 0 deletions docs/_src/api/api/reader.md
@@ -609,3 +609,75 @@ WARNING: The answer scores are not reliable, as they are always extremely high,

Dict containing query and answers

<a name="table.RCIReader"></a>
## RCIReader

```python
class RCIReader(BaseReader)
```

Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
See the original paper for more details:
Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
(https://aclanthology.org/2021.naacl-main.96/)

Each row and each column is given a score with regard to the query by two separate models. The score of each cell
is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
the cell with the highest score.
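
For intuition, here is a minimal sketch of this scoring scheme with made-up logits (plain NumPy, not part of the Haystack API):

```python
import numpy as np

# Hypothetical relevance logits produced by the row and column models
row_logits = np.array([0.2, 3.1, -0.5])  # one score per table row
column_logits = np.array([1.4, -0.3])    # one score per table column

# Score of each cell = row score + column score (outer sum via broadcasting)
cell_scores = row_logits[:, None] + column_logits[None, :]

# The predicted answer is the cell with the highest combined score
best_row, best_col = np.unravel_index(cell_scores.argmax(), cell_scores.shape)
print(best_row, best_col)  # -> 1 0
```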

<a name="table.RCIReader.__init__"></a>
#### \_\_init\_\_

```python
| __init__(row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row", column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col", row_model_version: Optional[str] = None, column_model_version: Optional[str] = None, row_tokenizer: Optional[str] = None, column_tokenizer: Optional[str] = None, use_gpu: bool = True, top_k: int = 10, max_seq_len: int = 256)
```

Load an RCI model from Transformers.
Available models include:

- ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
- ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``



**Arguments**:

- `row_model_name_or_path`: Directory of a saved row scoring model or the name of a public model
- `column_model_name_or_path`: Directory of a saved column scoring model or the name of a public model
- `row_model_version`: The version of the row model to use from the HuggingFace model hub.
Can be a tag name, branch name, or commit hash.
- `column_model_version`: The version of the column model to use from the HuggingFace model hub.
Can be a tag name, branch name, or commit hash.
- `row_tokenizer`: Name of the tokenizer for the row model (usually the same as model)
- `column_tokenizer`: Name of the tokenizer for the column model (usually the same as model)
- `use_gpu`: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
- `top_k`: The maximum number of answers to return
- `max_seq_len`: Max sequence length of one input table for the model. If the number of tokens of
query + table exceeds max_seq_len, the table will be truncated by removing rows until the
input size fits the model.
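
A minimal construction sketch (the model pairs are those listed above; the `top_k` value is arbitrary):

```python
from haystack.nodes import RCIReader

# Default: the WikiSQL-finetuned row/column model pair
reader = RCIReader()

# Or select the WTQ-finetuned pair explicitly
reader = RCIReader(
    row_model_name_or_path="michaelrglass/albert-base-rci-wtq-row",
    column_model_name_or_path="michaelrglass/albert-base-rci-wtq-col",
    top_k=5,
)
```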

<a name="table.RCIReader.predict"></a>
#### predict

```python
| predict(query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict
```

Use the loaded RCI models to find answers for a query in the supplied list of Documents
of content_type ``'table'``.

Returns a dictionary containing the query and a list of Answer objects sorted by (descending) score.
The existing RCI models on the HF model hub don't allow aggregation; therefore, the answer will always be
composed of a single cell.

**Arguments**:

- `query`: Query string
- `documents`: List of Documents in which to search for the answer. Documents should be
of content_type ``'table'``.
- `top_k`: The maximum number of answers to return

**Returns**:

Dict containing query and answers
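
For illustration, a small end-to-end sketch (the table contents are made up):

```python
import pandas as pd

from haystack.nodes import RCIReader
from haystack.schema import Document

table = pd.DataFrame({"Actor": ["Brad Pitt", "Leonardo DiCaprio"], "Age": ["57", "46"]})
document = Document(content=table, content_type="table")

reader = RCIReader()
prediction = reader.predict(query="Who is 57 years old?", documents=[document], top_k=1)
print(prediction["answers"][0].answer)  # e.g. "Brad Pitt"

# Each Answer also carries the full cell-score grid in its meta field,
# which can be plotted as a heatmap.
table_scores = prediction["answers"][0].meta["table_scores"]
```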

2 changes: 1 addition & 1 deletion haystack/nodes/__init__.py
@@ -23,7 +23,7 @@
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
from haystack.nodes.question_generator import QuestionGenerator
from haystack.nodes.ranker import BaseRanker, SentenceTransformersRanker
from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader
from haystack.nodes.reader import BaseReader, FARMReader, TransformersReader, TableReader, RCIReader
from haystack.nodes.retriever import (
    BaseRetriever,
    DensePassageRetriever,
2 changes: 1 addition & 1 deletion haystack/nodes/reader/__init__.py
@@ -1,4 +1,4 @@
from haystack.nodes.reader.base import BaseReader
from haystack.nodes.reader.farm import FARMReader
from haystack.nodes.reader.transformers import TransformersReader
from haystack.nodes.reader.table import TableReader
from haystack.nodes.reader.table import TableReader, RCIReader
209 changes: 208 additions & 1 deletion haystack/nodes/reader/table.py
@@ -6,7 +6,8 @@
import numpy as np
import pandas as pd
from quantulum3 import parser
from transformers import TapasTokenizer, TapasForQuestionAnswering, BatchEncoding
from transformers import TapasTokenizer, TapasForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification, \
    BatchEncoding, AutoConfig

from haystack.schema import Document, Answer, Span
from haystack.nodes.reader.base import BaseReader
@@ -252,3 +253,209 @@ def _calculate_answer_offsets(answer_coordinates: List[Tuple[int, int]], table:

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        raise NotImplementedError("Batch prediction not yet available in TableReader.")


class RCIReader(BaseReader):
    """
    Table Reader model based on Glass et al. (2021)'s Row-Column-Intersection model.
    See the original paper for more details:
    Glass, Michael, et al. (2021): "Capturing Row and Column Semantics in Transformer Based Question Answering over Tables"
    (https://aclanthology.org/2021.naacl-main.96/)

    Each row and each column is given a score with regard to the query by two separate models. The score of each cell
    is then calculated as the sum of the corresponding row score and column score. Accordingly, the predicted answer is
    the cell with the highest score.

    Pros and Cons of RCIReader compared to TableReader:
    + Provides meaningful confidence scores
    + Allows larger tables as input
    - Does not support aggregation over table cells
    - Slower
    """

    def __init__(self,
                 row_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-row",
                 column_model_name_or_path: str = "michaelrglass/albert-base-rci-wikisql-col",
                 row_model_version: Optional[str] = None,
                 column_model_version: Optional[str] = None,
                 row_tokenizer: Optional[str] = None,
                 column_tokenizer: Optional[str] = None,
                 use_gpu: bool = True,
                 top_k: int = 10,
                 max_seq_len: int = 256,
                 ):
"""
Load an RCI model from Transformers.
Available models include:
- ``'michaelrglass/albert-base-rci-wikisql-row'`` + ``'michaelrglass/albert-base-rci-wikisql-col'``
- ``'michaelrglass/albert-base-rci-wtq-row'`` + ``'michaelrglass/albert-base-rci-wtq-col'``
:param row_model_name_or_path: Directory of a saved row scoring model or the name of a public model
:param column_model_name_or_path: Directory of a saved column scoring model or the name of a public model
:param row_model_version: The version of row model to use from the HuggingFace model hub.
Can be tag name, branch name, or commit hash.
:param column_model_version: The version of column model to use from the HuggingFace model hub.
Can be tag name, branch name, or commit hash.
:param row_tokenizer: Name of the tokenizer for the row model (usually the same as model)
:param column_tokenizer: Name of the tokenizer for the column model (usually the same as model)
:param use_gpu: Whether to use GPU or CPU. Falls back on CPU if no GPU is available.
:param top_k: The maximum number of answers to return
:param max_seq_len: Max sequence length of one input table for the model. If the number of tokens of
query + table exceed max_seq_len, the table will be truncated by removing rows until the
input size fits the model.
"""
        # Save init parameters to enable export of component config as YAML
        self.set_config(row_model_name_or_path=row_model_name_or_path,
                        column_model_name_or_path=column_model_name_or_path, row_model_version=row_model_version,
                        column_model_version=column_model_version, row_tokenizer=row_tokenizer,
                        column_tokenizer=column_tokenizer, use_gpu=use_gpu, top_k=top_k, max_seq_len=max_seq_len)

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        self.row_model = AutoModelForSequenceClassification.from_pretrained(row_model_name_or_path,
                                                                            revision=row_model_version)
        self.column_model = AutoModelForSequenceClassification.from_pretrained(column_model_name_or_path,
                                                                               revision=column_model_version)
        self.row_model.to(str(self.devices[0]))
        self.column_model.to(str(self.devices[0]))

        if row_tokenizer is None:
            try:
                self.row_tokenizer = AutoTokenizer.from_pretrained(row_model_name_or_path)
            # The existing RCI models on the model hub don't come with tokenizer vocab files.
            except TypeError:
                self.row_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        else:
            self.row_tokenizer = AutoTokenizer.from_pretrained(row_tokenizer)

        if column_tokenizer is None:
            try:
                self.column_tokenizer = AutoTokenizer.from_pretrained(column_model_name_or_path)
            # The existing RCI models on the model hub don't come with tokenizer vocab files.
            except TypeError:
                self.column_tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
        else:
            self.column_tokenizer = AutoTokenizer.from_pretrained(column_tokenizer)

        self.top_k = top_k
        self.max_seq_len = max_seq_len
        self.return_no_answers = False

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
        """
        Use the loaded RCI models to find answers for a query in the supplied list of Documents
        of content_type ``'table'``.

        Returns a dictionary containing the query and a list of Answer objects sorted by (descending) score.
        The existing RCI models on the HF model hub don't allow aggregation; therefore, the answer will always be
        composed of a single cell.

        :param query: Query string
        :param documents: List of Documents in which to search for the answer. Documents should be
                          of content_type ``'table'``.
        :param top_k: The maximum number of answers to return
        :return: Dict containing query and answers
        """
        if top_k is None:
            top_k = self.top_k

        answers = []
        for document in documents:
            if document.content_type != "table":
                logger.warning(f"Skipping document with id {document.id} in RCIReader, as it is not of type table.")
                continue

            table: pd.DataFrame = document.content
            table = table.astype(str)
            # Create row and column representations
            row_reps, column_reps = self._create_row_column_representations(table)

            # Get row logits
            row_inputs = self.row_tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[(query, row_rep) for row_rep in row_reps],
                max_length=self.max_seq_len,
                return_tensors="pt",
                add_special_tokens=True,
                truncation=True,
                padding=True
            )
            row_inputs.to(self.devices[0])
            row_logits = self.row_model(**row_inputs)[0].detach().cpu().numpy()[:, 1]

            # Get column logits
            column_inputs = self.column_tokenizer.batch_encode_plus(
                batch_text_or_text_pairs=[(query, column_rep) for column_rep in column_reps],
                max_length=self.max_seq_len,
                return_tensors="pt",
                add_special_tokens=True,
                truncation=True,
                padding=True
            )
            column_inputs.to(self.devices[0])
            column_logits = self.column_model(**column_inputs)[0].detach().cpu().numpy()[:, 1]

            # Calculate cell scores
            current_answers: List[Answer] = []
            cell_scores_table: List[List[float]] = []
            for row_idx, row_score in enumerate(row_logits):
                cell_scores_table.append([])
                for col_idx, col_score in enumerate(column_logits):
                    current_cell_score = float(row_score + col_score)
                    cell_scores_table[-1].append(current_cell_score)

                    answer_str = table.iloc[row_idx, col_idx]
                    answer_offsets = self._calculate_answer_offsets(row_idx, col_idx, table)
                    current_answers.append(
                        Answer(
                            answer=answer_str,
                            type="extractive",
                            score=current_cell_score,
                            context=table,
                            offsets_in_document=[answer_offsets],
                            offsets_in_context=[answer_offsets],
                            document_id=document.id,
                        )
                    )

            # Add cell scores to Answers' meta to be able to use as heatmap
            for answer in current_answers:
                answer.meta = {"table_scores": cell_scores_table}
            answers.extend(current_answers)

        # Sort answers by score and select top-k answers
        answers = sorted(answers, reverse=True)
        answers = answers[:top_k]

        results = {"query": query,
                   "answers": answers}

        return results

    @staticmethod
    def _create_row_column_representations(table: pd.DataFrame) -> Tuple[List[str], List[str]]:
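        # Serialize each row as "header : cell * header : cell * ..." and each
        # column as "header * cell * cell * ..." so that the row and column
        # sequence-pair classifiers can score them against the query. For example,
        # a row of a table with columns ["Actor", "Age"] might become
        # "Actor : Brad Pitt * Age : 57".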
        row_reps = []
        column_reps = []
        columns = table.columns

        for idx, row in table.iterrows():
            current_row_rep = " * ".join([header + " : " + cell for header, cell in zip(columns, row)])
            row_reps.append(current_row_rep)

        for col_name in columns:
            current_column_rep = f"{col_name} * "
            current_column_rep += " * ".join(table[col_name])
            column_reps.append(current_column_rep)

        return row_reps, column_reps

    @staticmethod
    def _calculate_answer_offsets(row_idx: int, column_index: int, table: pd.DataFrame) -> Span:
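        # Answers are addressed by cell-level offsets: the table is linearized
        # row by row, so cell (row_idx, column_index) maps to the flat position
        # row_idx * n_columns + column_index.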
        n_rows, n_columns = table.shape
        answer_cell_offset = (row_idx * n_columns) + column_index

        return Span(start=answer_cell_offset, end=answer_cell_offset + 1)

    def predict_batch(self, query_doc_list: List[dict], top_k: Optional[int] = None, batch_size: Optional[int] = None):
        raise NotImplementedError("Batch prediction not yet available in RCIReader.")
12 changes: 8 additions & 4 deletions test/conftest.py
@@ -34,7 +34,7 @@
from haystack.document_stores.sql import SQLDocumentStore
from haystack.nodes.reader.farm import FARMReader
from haystack.nodes.reader.transformers import TransformersReader
from haystack.nodes.reader.table import TableReader
from haystack.nodes.reader.table import TableReader, RCIReader
from haystack.nodes.summarizer.transformers import TransformersSummarizer
from haystack.nodes.translator import TransformersTranslator
from haystack.nodes.question_generator import QuestionGenerator
@@ -338,9 +338,13 @@ def reader(request):
)


@pytest.fixture(scope="function")
def table_reader():
    return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
@pytest.fixture(params=["tapas", "rci"], scope="function")
def table_reader(request):
    if request.param == "tapas":
        return TableReader(model_name_or_path="google/tapas-base-finetuned-wtq")
    elif request.param == "rci":
        return RCIReader(row_model_name_or_path="michaelrglass/albert-base-rci-wikisql-row",
                         column_model_name_or_path="michaelrglass/albert-base-rci-wikisql-col")


@pytest.fixture(scope="function")
