Skip to content

Commit

Permalink
Modularize save_inference_log()
Browse files Browse the repository at this point in the history
  • Loading branch information
uSaiPrashanth committed Jul 3, 2023
1 parent 48ce4e1 commit ff263c4
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
labels=tokenized_batch["input_ids"],
output_attentions=True,
)
save_inference_log(split_name, run_id, dataset, batch, labels, outputs, features)
inference_logs = accumilate_inference_log(batch[0], labels, outputs, features)
save_inference_log(split_name, run_id, dataset, inference_logs)


def gini(array):
Expand All @@ -214,24 +215,21 @@ def gini(array):
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))


def save_inference_log(
split_name: str, run_id: str, dataset: pd.DataFrame, batch: tuple, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
def accumilate_inference_log(
batch_sequence_ids: list, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
):
"""
Extract the desired data from the model response and save it to a CSV file.
Args:
split_name (str): The model+scheme used to determine the tokenizer and model
run_id (str): The timestamp for this run
dataset (str): The dataset to run inference on
batch (tuple): The input batch containing the sequence ids and sequences
batch_sequence_ids (list): The list containing the sequence ids
labels (torch.Tensor): The labels for the batch. Used to calculate perplexity
outputs (CausalLMOutputWithPast): The response from the Pythia model
features (list): The list of features to calculate. A subset of [loss, ppl, attn]
"""
logits = outputs.logits.detach()
perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
inference_logs = []
batch_sequence_ids = batch[0]
e=1e-8

for index, id_tensor in enumerate(batch_sequence_ids):
Expand Down Expand Up @@ -273,13 +271,21 @@ def save_inference_log(

inference_logs.append(inference_log)

return inference_logs

def save_inference_log(split_name: str, run_id: str, dataset: pd.DataFrame, inference_logs: list):
"""Saves the accumilated inference log in a pandas dataframe
Args:
split_name (str): The model+scheme used to determine the tokenizer and model
run_id (str): The timestamp for this run
dataset (str): The dataset to run inference on
inference_logs (list): Accumilated inference logs
"""
file_name = split_name.replace(".", "_")
inference_logs_df = pd.DataFrame(inference_logs)
inference_logs_df.to_csv(f"datasets/{run_id}/{dataset}_{file_name}.csv", index=False, mode="a")

return inference_logs


def parse_cli_args():
parser = ArgumentParser()
models_arg_help = "The Pythia model to get the perplexities for. Valid options are: 70m, 160m, 410m, 1b, 1.4b, 2.8b, 6.9b, 12b"
Expand Down

0 comments on commit ff263c4

Please sign in to comment.