
Commit 1b30b82

Added features arg

Kyle1668 committed May 24, 2023
1 parent 213d9cc commit 1b30b82
Showing 2 changed files with 31 additions and 34 deletions.

inference.py: 63 changes (30 additions & 33 deletions)

@@ -147,32 +147,14 @@ def get_dataset(dataset_name: str, split_name: str, sample: int = None) -> pd.DataFrame:
     if dataset_name.split("-")[0] == "pile":
         scheme = split_name.split(".")[0]
         pile_path = f"EleutherAI/pile-{scheme}-pythia-random-sampled"
-        lower_index = 0 if dataset_name == "pile-1" else 50
-        upper_index = 50 if dataset_name == "pile-1" else 100
-
-        print(f"Loading {pile_path} {lower_index}-{upper_index}%")
-        pile_tokens = load_dataset(
-            pile_path,
-            split=ReadInstruction(
-                "train",
-                from_=lower_index,
-                to=upper_index,
-                unit="%",
-                rounding="pct1_dropremainder",
-            ),
-        ).to_pandas()[["index", "tokens"]]
-
-        if dataset is None:
-            dataset = pile_tokens
-        else:
-            dataset = pd.concat([dataset, pile_tokens])
+        dataset = load_dataset(pile_path, split="train").to_pandas()[["index", "tokens"]]
     else:
         dataset = load_dataset("EleutherAI/pythia-memorized-evals")[split_name].to_pandas()
 
     return dataset if sample is None else dataset.sample(sample).reset_index(drop=True)
 
 
-def run_model_inferences(split_name: str, run_id: str, dataset: str, sample_size: int = None):
+def run_model_inferences(split_name: str, run_id: str, dataset: str, features: list, sample_size: int = None):
     """
     Run inference for the given model and dataset. Save the results to a CSV file.
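
Side note on the first hunk: the removed branch loaded either the first or the second half of the sampled Pile set (0-50% or 50-100%, keyed on `pile-1` vs `pile-2`) via `ReadInstruction`, while the new line loads the full `train` split, so both dataset names now resolve to the same data. A minimal sketch of the old slicing call, kept for reference (the `duped` scheme in the path is an assumed example value):

    from datasets import ReadInstruction, load_dataset

    # Load the first half of the split, rounded to whole percents.
    first_half = load_dataset(
        "EleutherAI/pile-duped-pythia-random-sampled",  # assumed scheme value
        split=ReadInstruction("train", from_=0, to=50, unit="%", rounding="pct1_dropremainder"),
    ).to_pandas()[["index", "tokens"]]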
@@ -209,10 +191,12 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, sample_size: int = None):
             labels=tokenized_batch["input_ids"],
             output_attentions=True,
         )
-        save_inference_log(split_name, run_id, dataset, batch, labels, outputs)
+        save_inference_log(split_name, run_id, dataset, batch, labels, outputs, features)
 
 
-def save_inference_log(split_name: str, run_id: str, dataset: pd.DataFrame, batch: tuple, labels: torch.Tensor, outputs: CausalLMOutputWithPast):
+def save_inference_log(
+    split_name: str, run_id: str, dataset: pd.DataFrame, batch: tuple, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
+):
     """
     Extract the desired data from the model response and save it to a CSV file.
@@ -225,19 +209,21 @@ def save_inference_log(split_name: str, run_id: str, dataset: pd.DataFrame, batch: tuple, labels: torch.Tensor, outputs: CausalLMOutputWithPast):
         outputs (CausalLMOutputWithPast): The response from the Pythia model
     """
     logits = outputs.logits.detach()
-    perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))]
+    perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
     inference_logs = []
 
     batch_sequence_ids = batch[0]
     for index, id_tensor in enumerate(batch_sequence_ids):
-        inference_log = {
-            "index": id_tensor.detach().item(),
-            "perplexity": perplexities[index],
-            "mean_loss": outputs.loss.detach().item() / len(labels[index]),
-        }
-        for layer_index, attention_layer in enumerate(outputs.attentions):
-            sequence_attention = attention_layer[index].detach().tolist()
-            inference_log[f"attn_{layer_index}"] = sequence_attention
+        inference_log = {"index": id_tensor.detach().item()}
+        if "loss" in features:
+            inference_log["loss"] = outputs.loss.detach().item() / len(labels[index])
+        if "ppl" in features:
+            inference_log["perplexity"] = perplexities[index]
+        if "attn" in features:
+            for layer_index, attention_layer in enumerate(outputs.attentions):
+                sequence_attention = attention_layer[index].detach().tolist()
+                inference_log[f"attn_{layer_index}"] = sequence_attention
+
         inference_logs.append(inference_log)
 
     file_name = split_name.replace(".", "_")
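
Worth spelling out the new gating in this hunk: `perplexities` is computed only when "ppl" is requested and is otherwise `None`, but it is also only indexed inside the `if "ppl" in features` branch, so the two guards stay consistent (note, too, that the logged key changes from "mean_loss" to "loss"). A condensed, runnable sketch of the pattern, with toy values standing in for real model outputs:

    features = ["loss", "attn"]

    # Computed only when requested; otherwise left as None and never indexed.
    perplexities = [12.7, 9.3] if "ppl" in features else None

    inference_log = {"index": 0}
    if "loss" in features:
        inference_log["loss"] = 2.31 / 64  # toy per-token mean of a summed loss
    if "ppl" in features:
        inference_log["perplexity"] = perplexities[0]

    print(inference_log)  # {'index': 0, 'loss': 0.0360...}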
@@ -277,13 +263,24 @@ def parse_cli_args():
         default=datasets_args_default,
     )
 
+    features_arg_help = "The features to extract from the model response. Valid options are: attn, loss, perplexity"
+    features_arg_default = ["attn", "loss", "ppl"]
+    parser.add_argument(
+        "--features",
+        type=str,
+        help=features_arg_help,
+        choices=features_arg_default,
+        default=features_arg_default,
+    )
+
     sample_size_arg_help = "The number of samples to take from the dataset. Defaults to None."
     parser.add_argument(
-        "--sample-size",
+        "--sample_size",
         type=int,
         help=sample_size_arg_help,
         default=None,
     )
 
     return parser.parse_args()
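
One behavioral quirk of the new flag as declared: with `type=str` and no `nargs`, argparse accepts a single value per parse, so the default stays a Python list while an explicit `--features` yields a bare string (the `in` checks above still behave sensibly for a single feature, since substring matching on one choice cannot match another). A runnable sketch of just that argument:

    import argparse

    parser = argparse.ArgumentParser()
    features_arg_default = ["attn", "loss", "ppl"]
    parser.add_argument("--features", type=str, choices=features_arg_default, default=features_arg_default)

    print(parser.parse_args([]).features)                     # ['attn', 'loss', 'ppl']
    print(parser.parse_args(["--features", "ppl"]).features)  # 'ppl'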


@@ -307,7 +304,7 @@ def main():
     for dataset in args.datasets if isinstance(args.datasets, list) else args.datasets.split(","):
         split_name = f"{data_scheme}.{model_size}"
         print(f"Collecting inferences for {split_name} on {dataset} dataset")
-        run_model_inferences(split_name, experiment_timestamp, dataset, args.sample_size)
+        run_model_inferences(split_name, experiment_timestamp, dataset, args.features, args.sample_size)
 
 
 if __name__ == "__main__":
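
Putting it together, a hypothetical driver mirroring the updated call in `main()` (the split name, dataset label, and sample size are placeholder values, and the timestamp format is assumed):

    from datetime import datetime

    experiment_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    features = ["loss", "ppl"]  # skip "attn" to keep the logs small

    run_model_inferences("duped.70m", experiment_timestamp, "memories", features, sample_size=100)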
pyproject.toml: 2 changes (1 addition & 1 deletion)

@@ -2,4 +2,4 @@
 line-length = 150
 
 [tool.pylint.format]
-max-line-length = 120
+max-line-length = 150
