Scalable Qa aggregation #268

Merged: 6 commits merged on Feb 28, 2020
20 changes: 10 additions & 10 deletions farm/infer.py
@@ -310,24 +310,25 @@ def _get_predictions(self, dataset, tensor_names, baskets, rest_api_schema=False
data_loader = NamedDataLoader(
dataset=dataset, sampler=SequentialSampler(dataset), batch_size=self.batch_size, tensor_names=tensor_names
)
logits_all = []
unaggregated_preds_all = []
preds_all = []
aggregate_preds = hasattr(self.model.prediction_heads[0], "aggregate_preds")
for i, batch in enumerate(tqdm(data_loader, desc=f"Inferencing")):
batch = {key: batch[key].to(self.device) for key in batch}

if not aggregate_preds:
batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]

# get logits
with torch.no_grad():
logits = self.model.forward(**batch)[0]

# either just stack the logits (and convert later to readable predictions)
if aggregate_preds:
logits_all += [l for l in logits]

# or convert directly
# Aggregation works on preds, not logits. We want as much processing happening in one batch + on GPU
# So we transform logits to preds here as well
logits = self.model.forward(**batch)
preds = self.model.logits_to_preds(logits, **batch)[0]
unaggregated_preds_all += preds
else:
logits = self.model.forward(**batch)[0]
preds = self.model.formatted_preds(
logits=[logits],
samples=batch_samples,
@@ -343,9 +344,8 @@ def _get_predictions(self, dataset, tensor_names, baskets, rest_api_schema=False
# and then aggregating them here.
if aggregate_preds:
# can assume that we have only complete docs i.e. all the samples of one doc are in the current chunk
# TODO is there a better way than having to wrap logits all in list?
# TODO can QA formatted preds deal with samples?
preds_all = self.model.formatted_preds(logits=[logits_all],
preds_all = self.model.formatted_preds(logits=[None], # For QA we collected preds per batch and do not want to pass logits
preds_p=unaggregated_preds_all,
baskets=baskets,
rest_api_schema=rest_api_schema)[0]
return preds_all
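
The net effect of this change in infer.py: for QA heads that support aggregate_preds, logits are converted to span predictions per batch (while everything is still batched on the GPU), and only those lightweight prediction lists are kept until a single document-level formatted_preds call at the end, which now receives logits=[None]. A minimal, self-contained sketch of that control flow, using a dummy model rather than the actual FARM API:

# DummyQAModel is a hypothetical stand-in used only to illustrate the per-batch flow above.
class DummyQAModel:
    def forward(self, **batch):
        return [[0.1, 0.9]]                       # stand-in for logits of one sample

    def logits_to_preds(self, logits, **batch):
        return [[("span", 0, 1)]]                 # stand-in per-sample span predictions

    def formatted_preds(self, logits, preds_p, baskets):
        assert logits == [None]                   # logits are no longer passed for QA
        return [{"doc_answers": preds_p}]         # document-level aggregation happens here

model = DummyQAModel()
unaggregated_preds_all = []
for batch in [{"input_ids": [1, 2]}, {"input_ids": [3, 4]}]:
    logits = model.forward(**batch)
    # convert to predictions per batch so only small prediction lists stay in memory
    unaggregated_preds_all += model.logits_to_preds(logits, **batch)[0]

preds_all = model.formatted_preds(logits=[None], preds_p=unaggregated_preds_all, baskets=[])[0]
print(preds_all)
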
3 changes: 1 addition & 2 deletions farm/modeling/adaptive_model.py
@@ -211,14 +211,13 @@ def formatted_preds(self, logits, **kwargs):

:param logits: model logits
:type logits: torch.tensor
:param label_maps: dictionary for mapping ids to label strings
:type label_maps: dict[int:str]
:param kwargs: placeholder for passing generic parameters
:type kwargs: object
:return: predictions in the right format
"""
all_preds = []
# collect preds from all heads
# TODO add switch between single vs multiple prediction heads
for head, logits_for_head in zip(
self.prediction_heads, logits
):
71 changes: 45 additions & 26 deletions farm/modeling/prediction_head.py
@@ -1028,6 +1028,25 @@ def logits_to_preds(self, logits, padding_mask, start_of_word, seq_2_start_t, ma
end_matrix = end_logits.unsqueeze(1).expand(-1, max_seq_len, -1)
start_end_matrix = start_matrix + end_matrix

# disqualify answers where start > end
# (set the lower triangular matrix to low value, incl diagonal, excl item 0,0)
indices = torch.tril_indices(max_seq_len, max_seq_len)
start_end_matrix[:, indices[0][1:], indices[1][1:]] = -999

# disqualify answers where start=0, but end != 0
start_end_matrix[:, 0, 1:] = -999


# TODO continue vectorization of valid_answer_idxs
# # disqualify where answers < seq_2_start_t and idx != 0
# # disqualify where answer falls into padding
# # seq_2_start_t can be different when 2 different questions are handled within one batch
# # n_non_padding can be different on sample level, too
# for i in range(batch_size):
# start_end_matrix[i, 1:seq_2_start_t[i], 1:seq_2_start_t[i]] = -888
# start_end_matrix[i, n_non_padding[i]-1:, n_non_padding[i]-1:] = -777


# Sort the candidate answers by their score. Sorting happens on the flattened matrix.
# flat_sorted_indices.shape: (batch_size, max_seq_len^2, 1)
flat_scores = start_end_matrix.view(batch_size, -1)
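
The block above replaces two of the per-candidate Python checks with tensor operations applied once to the whole score matrix. A small standalone sketch of the same masking, with an assumed toy batch size and sequence length rather than FARM's real values:

import torch

max_seq_len = 5
start_logits = torch.randn(2, max_seq_len)   # (batch_size, max_seq_len)
end_logits = torch.randn(2, max_seq_len)

# score[b, s, e] = start_logits[b, s] + end_logits[b, e]
start_end_matrix = start_logits.unsqueeze(2) + end_logits.unsqueeze(1)

# disqualify answers where start > end
# (as in the diff above: lower triangle incl. diagonal, excl. item (0, 0), set to a low value)
indices = torch.tril_indices(max_seq_len, max_seq_len)
start_end_matrix[:, indices[0][1:], indices[1][1:]] = -999

# disqualify answers where start == 0 but end != 0 (index 0 is reserved for the no-answer prediction)
start_end_matrix[:, 0, 1:] = -999

print(start_end_matrix[0])
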
@@ -1042,9 +1061,11 @@ def logits_to_preds(self, logits, padding_mask, start_of_word, seq_2_start_t, ma

# Get the n_best candidate answers for each sample that are valid (via some heuristic checks)
for sample_idx in range(batch_size):
sample_top_n = self.get_top_candidates(sorted_candidates[sample_idx], start_end_matrix[sample_idx],
n_non_padding[sample_idx], max_answer_length,
seq_2_start_t[sample_idx])
sample_top_n = self.get_top_candidates(sorted_candidates[sample_idx],
start_end_matrix[sample_idx],
n_non_padding[sample_idx].item(),
max_answer_length,
seq_2_start_t[sample_idx].item())
all_top_n.append(sample_top_n)

return all_top_n
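
The functional change in this hunk is the .item() calls: n_non_padding[sample_idx] and seq_2_start_t[sample_idx] are 0-dim tensors, and converting them to Python ints once per sample presumably keeps the many comparisons inside get_top_candidates in plain Python instead of producing a new tensor per comparison. A tiny illustration (assumed example values):

import torch

n_non_padding = torch.tensor([384, 210])

as_tensor = n_non_padding[0]        # 0-dim tensor: comparisons like `end_idx >= as_tensor - 1` allocate a tensor each time
as_int = n_non_padding[0].item()    # plain Python int: the same comparison stays in pure Python

print(type(as_tensor), type(as_int))   # <class 'torch.Tensor'> <class 'int'>
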
@@ -1071,9 +1092,8 @@ def get_top_candidates(self, sorted_candidates, start_end_matrix,
if start_idx == 0 and end_idx == 0:
continue
# Check that the candidate's indices are valid and save them if they are
score = start_end_matrix[start_idx, end_idx].item()
if self.valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_2_start_t):
# score = start_end_matrix[start_idx, end_idx].item()
score = start_end_matrix[start_idx, end_idx].item()
top_candidates.append([start_idx, end_idx, score])

no_answer_score = start_end_matrix[0, 0].item()
@@ -1087,7 +1107,7 @@ def valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_
should be on sample/passage level (special tokens + question_tokens + passage_tokens)
and not document level"""

# This function can seriously slow down inferencing and eval
# This function can seriously slow down inferencing and eval. In the future this function will be completely vectorized
# Continue if start or end label points to a padding token
if start_idx < seq_2_start_t and start_idx != 0:
return False
@@ -1099,42 +1119,41 @@ def valid_answer_idxs(start_idx, end_idx, n_non_padding, max_answer_length, seq_
return False
if end_idx >= n_non_padding - 1:
return False
# Check if start comes after end
if end_idx < start_idx:
return False
# If one of the two indices is 0, the other must also be 0
if start_idx == 0 and end_idx != 0:
return False
if start_idx != 0 and end_idx == 0:
return False

# # Check if start comes after end
# # Handled on matrix level by: start_end_matrix[:, indices[0][1:], indices[1][1:]] = -999
# if end_idx < start_idx:
# return False

# # If one of the two indices is 0, the other must also be 0
# # Handled on matrix level by setting: start_end_matrix[:, 0, 1:] = -999
# if start_idx == 0 and end_idx != 0:
# return False
# if start_idx != 0 and end_idx == 0:
# return False

length = end_idx - start_idx + 1
if length > max_answer_length:
return False
return True
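
With the start > end and zero-index checks moved onto the score matrix, only the sample-dependent checks survive here: indices must not point into the question tokens or into padding, and the span must respect max_answer_length. A rough, hedged sketch of that remaining logic with made-up example values (not a verbatim copy of valid_answer_idxs):

def remaining_checks(start_idx, end_idx, n_non_padding, max_answer_length, seq_2_start_t):
    # start or end inside question/special tokens is only allowed for the no-answer index 0
    if start_idx < seq_2_start_t and start_idx != 0:
        return False
    if end_idx < seq_2_start_t and end_idx != 0:
        return False
    # end must not point into padding
    if end_idx >= n_non_padding - 1:
        return False
    # overly long spans are discarded
    return (end_idx - start_idx + 1) <= max_answer_length

# hypothetical values: passage tokens start at index 15, 200 non-padding tokens in total
print(remaining_checks(20, 25, n_non_padding=200, max_answer_length=30, seq_2_start_t=15))  # True
print(remaining_checks(5, 25, n_non_padding=200, max_answer_length=30, seq_2_start_t=15))   # False
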

def formatted_preds(self, logits, baskets, rest_api_schema=False):
""" Takes a list of logits, each corresponding to one sample, and converts them into document level predictions.
Leverages information in the SampleBaskets. Assumes that we are being passed logits from ALL samples in the one
SampleBasket i.e. all passages of a document. """
def formatted_preds(self, logits, preds_p, baskets, rest_api_schema=False):
""" Takes a list of predictions, each corresponding to one sample, and converts them into document level predictions.
Leverages information in the SampleBaskets. Assumes that we are being passed predictions from ALL samples
in the one SampleBasket i.e. all passages of a document.
Logits should be None, because we have already converted the logits to predictions before calling formatted_preds
"""

# Unpack some useful variables
# passage_start_t is the token index of the passage relative to the document (usually a multiple of doc_stride)
# seq_2_start_t is the token index of the first token in passage relative to the input sequence (i.e. number of
# special tokens and question tokens that come before the passage tokens)
assert logits is None, "Logits are not None, something is passed wrongly into formatted_preds() in infer.py"
samples = [s for b in baskets for s in b.samples]
ids = [s.id.split("-") for s in samples]
passage_start_t = [s.features[0]["passage_start_t"] for s in samples]
seq_2_start_t = [s.features[0]["seq_2_start_t"] for s in samples]

# Prepare tensors
logits = torch.stack(logits)
padding_mask = torch.tensor([s.features[0]["padding_mask"] for s in samples], dtype=torch.long)
start_of_word = torch.tensor([s.features[0]["start_of_word"] for s in samples], dtype=torch.long)

# Return n + 1 predictions per passage / sample
preds_p = self.logits_to_preds(logits, padding_mask, start_of_word, seq_2_start_t)

# Aggregate passage level predictions to create document level predictions.
# This method assumes that all passages of each document are contained in preds_p
# i.e. that there are no incomplete documents. The output of this step