Skip to content

Commit

Permalink
Add error handling for when there is an exception in the perplexity code
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle1668 committed Jul 15, 2023
1 parent acbb67f commit 85b45e8
Showing 1 changed file with 52 additions and 49 deletions.
101 changes: 52 additions & 49 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,46 +69,49 @@ def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.fl
"""
# Store the probabilities for each token. These will be summed later, but having the
# individual probabilities is helpful for debugging.
token_probs = []

# Don't include the final token logits. There are no labels for
# these since the sequence has ended.
num_special_tokens = len(labels[labels == 0])
num_normal_tokens = len(labels) - num_special_tokens

for token_index in range(num_normal_tokens - 1):
# Map the logits to probabilities.
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float16)
# Get the probability of the correct label.
label_prob = predicted_probs[labels[token_index + 1]]

# Check if the label probability is 0. This is likely due to a rounding error. Recalculate
# the probability using double precision.
if label_prob == 0:
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
try:
token_probs = []

# Don't include the final token logits. There are no labels for
# these since the sequence has ended.
num_special_tokens = len(labels[labels == 0])
num_normal_tokens = len(labels) - num_special_tokens

for token_index in range(num_normal_tokens - 1):
# Map the logits to probabilities.
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float16)
# Get the probability of the correct label.
label_prob = predicted_probs[labels[token_index + 1]]

# Store the probability for this token.
token_probs.append(label_prob.detach())
# Check if the label probability is 0. This is likely due to a rounding error. Recalculate
# the probability using double precision.
if label_prob == 0:
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
label_prob = predicted_probs[labels[token_index + 1]]

mid_index = len(token_probs) // 2
prompt_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
cross_entropy = -log_likelihood / len(token_probs)
prompt_ppl = torch.exp(cross_entropy).item()
# Store the probability for this token.
token_probs.append(label_prob.detach())

generation_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
cross_entropy = -log_likelihood / len(token_probs)
generation_ppl = torch.exp(cross_entropy).item()
mid_index = len(token_probs) // 2
prompt_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
cross_entropy = -log_likelihood / len(token_probs)
prompt_ppl = torch.exp(cross_entropy).item()

sequence_ppl = None
log_likelihood = torch.log(torch.stack(token_probs)).sum()
cross_entropy = -log_likelihood / len(token_probs)
sequence_ppl = torch.exp(cross_entropy).item()
generation_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
cross_entropy = -log_likelihood / len(token_probs)
generation_ppl = torch.exp(cross_entropy).item()

# assert perplexity != float("inf"), "Perplexity is infinite. This is probably due to a token that has a probability of 0."
return prompt_ppl, generation_ppl, sequence_ppl
sequence_ppl = None
log_likelihood = torch.log(torch.stack(token_probs)).sum()
cross_entropy = -log_likelihood / len(token_probs)
sequence_ppl = torch.exp(cross_entropy).item()

return prompt_ppl, generation_ppl, sequence_ppl
except Exception as e:
print(f"Failed to calulcate perplexity: {e}")
return -1, -1, -1


def get_batch_size(model_name: str) -> int:
Expand Down Expand Up @@ -178,7 +181,7 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
pile_dataset = PileDataset(pile_sequences, tokenizer)
batch_size = get_batch_size(split_name)
data_loader = DataLoader(pile_dataset, batch_size=batch_size)

with torch.no_grad():
desc = f"Collecting {dataset} inference responses for {split_name}"
for batch in tqdm(data_loader, desc=desc):
Expand Down Expand Up @@ -208,11 +211,11 @@ def gini(array):
# from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
array = array.flatten()
if np.amin(array) < 0:
array -= np.amin(array)
array = np.sort(array)
index = np.arange(1,array.shape[0]+1)
n = array.shape[0]
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))
array -= np.amin(array)
array = np.sort(array)
index = np.arange(1,array.shape[0]+1)
n = array.shape[0]
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))


def accumilate_inference_log(
Expand All @@ -231,7 +234,7 @@ def accumilate_inference_log(
perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
inference_logs = []
e=1e-8

for index, id_tensor in enumerate(batch_sequence_ids):
total_entropy = []
total_gini = []
Expand All @@ -243,21 +246,21 @@ def accumilate_inference_log(
inference_log["generation_perplexity"] = perplexities[index][1]
inference_log["sequence_perplexity"] = perplexities[index][2]
if "attn" in features:
for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach()
for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach()
head_e = []
gini_head = []

for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
attention_head += e #adding 'e' to attention weights that are 0 to avoid log zero error while calculating entropy. Entropy = - ∑(w * log(w))
gini_coefficient = gini(attention_head)
gini_head.append(gini_coefficient)
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_e.append(head_entropy)
inference_log[f"gini_head{head_index+1}_layer{layer_index+1}"] = gini_coefficient
inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy

avg_head = np.mean(head_e)
avg_head_gini = np.mean(gini_head)
total_entropy.append(avg_head)
Expand All @@ -266,8 +269,8 @@ def accumilate_inference_log(
average_entropy = np.mean(total_entropy)
average_gini = np.mean(total_gini)
inference_log[f"avg entropy"] = average_entropy
inference_log[f"avg gini"] = average_gini
inference_log[f"avg gini"] = average_gini

inference_logs.append(inference_log)

return inference_logs
Expand Down

0 comments on commit 85b45e8

Please sign in to comment.