Get perplexity for prompt, generation, and whole sequence
Kyle1668 committed May 25, 2023
1 parent 47a4265 commit 4128486
Showing 1 changed file with 19 additions and 9 deletions.
inference.py: 19 additions & 9 deletions
@@ -91,17 +91,24 @@ def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.float:
         # Store the probability for this token.
         token_probs.append(label_prob.detach())
 
-    # Caluclate the log-likelyhood of the sequence by summing the probabilities
-    # of each token and then taking the log.
-    log_likelihood = torch.log(torch.stack(token_probs)).sum()
-    # Caluclate the cross entropy by dividing the negative log-likelihood by the number of tokens.
-    cross_entropy = -log_likelihood / len(token_probs)
-    # Calculate the perplexity by taking the exponential of the cross entropy.
-    perplexity = torch.exp(cross_entropy).item()
-    # assert perplexity != float("inf"), "Perplexity is infinite. This is probably due to a token that has a probability of 0."
-    return perplexity
+    # Treat the first half of the tokens as the prompt and the second half
+    # as the generation.
+    mid_index = len(token_probs) // 2
+
+    # Prompt perplexity: exponentiated mean negative log-likelihood over the
+    # prompt tokens.
+    log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
+    cross_entropy = -log_likelihood / mid_index
+    prompt_ppl = torch.exp(cross_entropy).item()
+
+    # Generation perplexity: the same calculation over the generated tokens.
+    log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
+    cross_entropy = -log_likelihood / (len(token_probs) - mid_index)
+    generation_ppl = torch.exp(cross_entropy).item()
+
+    # Whole-sequence perplexity over all tokens.
+    log_likelihood = torch.log(torch.stack(token_probs)).sum()
+    cross_entropy = -log_likelihood / len(token_probs)
+    sequence_ppl = torch.exp(cross_entropy).item()
+
+    return prompt_ppl, generation_ppl, sequence_ppl
 
 
 def get_batch_size(model_name: str) -> int:
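
As a sanity check on the halfway split, here is a minimal sketch (not part of inference.py; `split_perplexities` and the toy probabilities are invented for illustration) that recomputes the three values from a list of per-token probabilities and confirms that, for an even split, the whole-sequence perplexity is the geometric mean of the prompt and generation perplexities.

```python
# Minimal sketch, not from this repository: recompute the three perplexities
# from per-token probabilities. split_perplexities and its inputs are invented.
import torch

def split_perplexities(token_probs):
    probs = torch.stack(token_probs)
    mid = len(probs) // 2

    def ppl(span):
        # Perplexity = exp of the mean negative log-likelihood over the span.
        return torch.exp(-torch.log(span).sum() / len(span)).item()

    return ppl(probs[:mid]), ppl(probs[mid:]), ppl(probs)

# Four tokens: the first two stand in for the prompt, the last two for the generation.
token_probs = [torch.tensor(p, dtype=torch.float64) for p in (0.5, 0.25, 0.8, 0.1)]
prompt_ppl, generation_ppl, sequence_ppl = split_perplexities(token_probs)

# With an even split, sequence perplexity is the geometric mean of the halves.
assert abs(sequence_ppl - (prompt_ppl * generation_ppl) ** 0.5) < 1e-9
```

Note that the `len(token_probs) // 2` split assumes the prompt and the generation are the same length; if the true prompt length were known, it would take the place of `mid_index`.
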
@@ -211,14 +218,16 @@ def save_inference_log(
     logits = outputs.logits.detach()
     perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
     inference_logs = []
 
     batch_sequence_ids = batch[0]
 
     for index, id_tensor in enumerate(batch_sequence_ids):
         inference_log = {"index": id_tensor.detach().item()}
         if "loss" in features:
             inference_log["loss"] = outputs.loss.detach().item() / len(labels[index])
         if "ppl" in features:
-            inference_log["perplexity"] = perplexities[index]
+            inference_log["prompt_perplexity"] = perplexities[index][0]
+            inference_log["generation_perplexity"] = perplexities[index][1]
+            inference_log["sequence_perplexity"] = perplexities[index][2]
         if "attn" in features:
             for layer_index, attention_layer in enumerate(outputs.attentions):
                 sequence_attention = attention_layer[index].detach().tolist()
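
Downstream consumers of the log now see three perplexity fields per sequence instead of one. A hypothetical record (the numeric values are invented for illustration) looks like:

```python
# Hypothetical inference_log entry after this change; the numbers are invented.
{
    "index": 42,
    "loss": 2.31,
    "prompt_perplexity": 14.2,
    "generation_perplexity": 9.7,
    "sequence_perplexity": 11.7,
}
```
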
@@ -295,6 +304,7 @@ def main():
     print(f"Models: {args.models}")
     print(f"Schemes: {args.schemes}")
     print(f"Datasets: {args.datasets}")
+    print(f"Features: {args.features}")
     if args.sample_size is not None:
         print(f"Sample size: {args.sample_size}")
     print("---------------------------------------------------------------------------")
