Load LLMs in FP16 for Faster Inference #10

Open
Wants to merge 10 commits into base: master
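For reference, loading a Pythia checkpoint in half precision with transformers typically looks like the sketch below. This is a minimal illustration of the technique named in the PR title, not the PR's exact code; the checkpoint name and device handling are placeholders.

import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

# Loading weights directly in FP16 roughly halves GPU memory use and speeds up inference.
model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-70m-deduped",  # placeholder checkpoint
    torch_dtype=torch.float16,
)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m-deduped")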
2 changes: 1 addition & 1 deletion .gitignore
@@ -3,7 +3,7 @@


!datasets/eval/Pythia_70m_Deduped_Low_Perplexity_Labeling_Formatted.csv

*.pt
*.zip
.vscode
### Data ###
265 changes: 147 additions & 118 deletions inference.py
@@ -10,11 +10,13 @@
from transformers import AutoTokenizer, GPTNeoXForCausalLM
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, ReadInstruction
from multiprocessing import Process
from argparse import ArgumentParser
from tqdm import tqdm
from datetime import datetime
import pandas as pd
import numpy as np
import multiprocessing
import torch
import os

@@ -27,15 +29,15 @@ class PileDataset(Dataset):

def __init__(self, memories, tokenizer):
self.tokenizer = tokenizer
self.memories = memories
self.memories = memories.rename(columns={"index": "Index", "tokens": "Tokens"}) if "index" in memories.columns else memories

def __getitem__(self, index):
tokens = self.memories.iloc[index]["tokens"][:64]
tokens = self.memories.iloc[index]["Tokens"][:64]
decoded_text = self.tokenizer.decode(tokens)
return self.memories.iloc[index]["index"], decoded_text
return self.memories.iloc[index]["Index"], decoded_text

def __len__(self):
return len(self.memories["index"])
return len(self.memories["Index"])
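A hedged usage sketch for the class above, showing how the renamed Index/Tokens columns flow into a DataLoader. The token IDs are made up, and the tokenizer is assumed to come from load_tokenizer(split_name).

import pandas as pd
from torch.utils.data import DataLoader

# Toy frame mimicking the sampled Pile rows (values are illustrative only).
frame = pd.DataFrame({"Index": [0, 1], "Tokens": [[31373, 995], [464, 2068]]})
dataset = PileDataset(frame, tokenizer)  # tokenizer = load_tokenizer(split_name)
loader = DataLoader(dataset, batch_size=2)
for sequence_ids, decoded_texts in loader:
    print(sequence_ids, decoded_texts)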


def load_tokenizer(split_name: str) -> AutoTokenizer:
@@ -69,46 +71,50 @@ def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.fl
"""
# Store the probabilities for each token. These will be summed later, but having the
# individual probabilities is helpful for debugging.
token_probs = []

# Don't include the final token logits. There are no labels for
# these since the sequence has ended.
num_special_tokens = len(labels[labels == 0])
num_normal_tokens = len(labels) - num_special_tokens

for token_index in range(num_normal_tokens - 1):
# Map the logits to probabilities.
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float16)
# Get the probability of the correct label.
label_prob = predicted_probs[labels[token_index + 1]]

# Check if the label probability is 0. This is likely due to a rounding error. Recalculate
# the probability using double precision.
if label_prob == 0:
try:
token_probs = []

# Don't include the final token logits. There are no labels for
# these since the sequence has ended.
num_special_tokens = len(labels[labels == 0])
num_normal_tokens = len(labels) - num_special_tokens

for token_index in range(num_normal_tokens - 1):
# Map the logits to probabilities.
# predicted_probs = torch.softmax(logits[token_index].view(torch.float64), dim=0, dtype=torch.float16)
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
# Get the probability of the correct label.
label_prob = predicted_probs[labels[token_index + 1]]

# Store the probability for this token.
token_probs.append(label_prob.detach())
# Check if the label probability is 0. This is likely due to a rounding error. Recalculate
# the probability using double precision.
# if label_prob == 0:
# predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
# label_prob = predicted_probs[labels[token_index + 1]]

# Store the probability for this token.
token_probs.append(label_prob.detach())

mid_index = len(token_probs) // 2
prompt_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
cross_entropy = -log_likelihood / len(token_probs)
prompt_ppl = torch.exp(cross_entropy).item()
mid_index = len(token_probs) // 2
prompt_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
cross_entropy = -log_likelihood / len(token_probs)
prompt_ppl = torch.exp(cross_entropy).item()

generation_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
cross_entropy = -log_likelihood / len(token_probs)
generation_ppl = torch.exp(cross_entropy).item()
generation_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
cross_entropy = -log_likelihood / len(token_probs)
generation_ppl = torch.exp(cross_entropy).item()

sequence_ppl = None
log_likelihood = torch.log(torch.stack(token_probs)).sum()
cross_entropy = -log_likelihood / len(token_probs)
sequence_ppl = torch.exp(cross_entropy).item()
sequence_ppl = None
log_likelihood = torch.log(torch.stack(token_probs)).sum()
cross_entropy = -log_likelihood / len(token_probs)
sequence_ppl = torch.exp(cross_entropy).item()

# assert perplexity != float("inf"), "Perplexity is infinite. This is probably due to a token that has a probability of 0."
return prompt_ppl, generation_ppl, sequence_ppl
return prompt_ppl, generation_ppl, sequence_ppl
except Exception as e:
print(f"Failed to calulcate perplexity: {e}")
return -1, -1, -1
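As a sanity check on the arithmetic above, here is a self-contained sketch of the same perplexity computation, the exponential of the mean negative log-probability, on made-up token probabilities:

import torch

# Hypothetical probabilities of the correct label at each position.
token_probs = torch.tensor([0.25, 0.10, 0.40, 0.05], dtype=torch.float64)

log_likelihood = torch.log(token_probs).sum()
cross_entropy = -log_likelihood / len(token_probs)
perplexity = torch.exp(cross_entropy).item()
print(perplexity)  # roughly 6.7 for these values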


def get_batch_size(model_name: str) -> int:
@@ -125,16 +131,15 @@ def get_batch_size(model_name: str) -> int:
"""
size_batch_map = {
"70m": 512,
"160m": 512,
"410m": 512,
"1b": 256,
"1.4b": 256,
"2.8b": 128,
"160m": 256,
"410m": 256,
"1b": 128,
"1.4b": 128,
"2.8b": 64,
"6.9b": 64,
"12b": 64,
"12b": 16,
}
model_size = ".".join(model_name.split(".")[1:])
return size_batch_map[model_size]
return size_batch_map[model_name]


def get_dataset(dataset_name: str, split_name: str, sample: int = None) -> pd.DataFrame:
@@ -154,14 +159,14 @@ def get_dataset(dataset_name: str, split_name: str, sample: int = None) -> pd.Da
if dataset_name.split("-")[0] == "pile":
scheme = split_name.split(".")[0]
pile_path = f"EleutherAI/pile-{scheme}-pythia-random-sampled"
dataset = load_dataset(pile_path, split="train").to_pandas()[["index", "tokens"]]
dataset = load_dataset(pile_path, split="train").to_pandas()[["Index", "Tokens"]]
else:
dataset = load_dataset("EleutherAI/pythia-memorized-evals")[split_name].to_pandas()

return dataset if sample is None else dataset.sample(sample).reset_index(drop=True)


def run_model_inferences(split_name: str, run_id: str, dataset: str, features: list, sample_size: int = None):
def run_model_inferences(split_name: str, run_id: str, dataset: str, features: list, batch_size: int, sample_size: int = None):
"""
Run inference for the given model and dataset. Save the results to a CSV file.

@@ -171,35 +176,42 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
dataset (str): The dataset to run inference on
sample_size (int, optional): The maximum number of random samples run inference on. Defaults to None.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = load_tokenizer(split_name)
pythia_model = load_model(split_name)
pile_sequences = get_dataset(dataset, split_name, sample=sample_size)
pile_dataset = PileDataset(pile_sequences, tokenizer)
batch_size = get_batch_size(split_name)
data_loader = DataLoader(pile_dataset, batch_size=batch_size)

with torch.no_grad():
desc = f"Collecting {dataset} inference responses for {split_name}"
for batch in tqdm(data_loader, desc=desc):
batch_sequences = batch[1]
tokenized_batch = tokenizer(
batch_sequences,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True,
)
tokenized_batch.to(device)
labels = tokenized_batch["input_ids"]

outputs = pythia_model(
**tokenized_batch,
labels=tokenized_batch["input_ids"],
output_attentions=True,
)
inference_logs = accumilate_inference_log(batch[0], labels, outputs, features)
save_inference_log(split_name, run_id, dataset, inference_logs)

num_processes = multiprocessing.cpu_count() // 2
# num_processes = 6
with multiprocessing.Pool(num_processes) as pool:
with torch.no_grad():
desc = f"Collecting {dataset} inference responses for {split_name}"
for batch in tqdm(data_loader, desc=desc):
batch_sequences = batch[1]
tokenized_batch = tokenizer(
batch_sequences,
return_tensors="pt",
max_length=256,
truncation=True,
padding=True,
)
tokenized_batch.to(pythia_model.device)
labels = tokenized_batch["input_ids"]

outputs = pythia_model(
**tokenized_batch,
labels=tokenized_batch["input_ids"],
output_attentions=True,
)
logits = outputs.logits.detach().cpu()
labels = labels.detach().cpu()
loss = outputs.loss.detach().cpu()
attentions = [attn_tensor.detach().cpu() for attn_tensor in outputs.attentions]

inference_logs = accumilate_inference_log(batch[0], labels, logits, loss, attentions, features, pool)
save_inference_log(split_name, run_id, dataset, inference_logs)
torch.cuda.empty_cache()
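The pattern the revised loop relies on, detaching model outputs and moving them to the CPU before handing them to a worker pool, is sketched below in isolation. The worker function, tensor shapes, and pool size are illustrative assumptions, not the PR's exact code.

import torch
import multiprocessing

def summarize(logits_row: torch.Tensor) -> float:
    # Runs in a worker process; the tensor must already be a detached CPU tensor
    # so it can be pickled and sent across the process boundary.
    return torch.softmax(logits_row, dim=-1).max().item()

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")  # fork is unsafe once CUDA has been initialized
    logits = torch.randn(8, 50304)             # stand-in for a batch of model logits
    cpu_logits = logits.detach().cpu()
    with multiprocessing.Pool(processes=4) as pool:
        scores = pool.starmap(summarize, [(row,) for row in cpu_logits])
    print(scores)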


def gini(array):
@@ -208,15 +220,15 @@ def gini(array):
# from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
array = array.flatten()
if np.amin(array) < 0:
array -= np.amin(array)
array = np.sort(array)
index = np.arange(1,array.shape[0]+1)
n = array.shape[0]
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))
array -= np.amin(array)
array = np.sort(array)
index = np.arange(1,array.shape[0]+1)
n = array.shape[0]
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))
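For intuition about what this Gini coefficient reports on an attention row, a small worked example with made-up weights (a uniform distribution scores 0, while a highly peaked one approaches 1):

import numpy as np

def gini(array):
    # Same formula as above, repeated here so the snippet runs on its own.
    array = array.flatten()
    if np.amin(array) < 0:
        array -= np.amin(array)
    array = np.sort(array)
    index = np.arange(1, array.shape[0] + 1)
    n = array.shape[0]
    return np.sum((2 * index - n - 1) * array) / (n * np.sum(array))

uniform_attention = np.full(8, 1 / 8)              # every position attended equally
peaked_attention = np.array([0.93] + [0.01] * 7)   # almost all mass on one position
print(gini(uniform_attention))  # 0.0
print(gini(peaked_attention))   # about 0.8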


def accumilate_inference_log(
batch_sequence_ids: list, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
batch_sequence_ids: list, labels: torch.Tensor, logits: torch.Tensor, loss: torch.Tensor, attentions: list[torch.Tensor], features: list, pool: multiprocessing.Pool
):
"""
Extract the desired data from the model response and save it to a CSV file.
@@ -227,51 +239,65 @@ def accumilate_inference_log(
outputs (CausalLMOutputWithPast): The response from the Pythia model
features (list): The list of features to calculate. A subset of [loss, ppl, attn]
"""
logits = outputs.logits.detach()
perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
inference_logs = []
perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
# perplexities = pool.starmap(calculate_perplexity, zip(logits, labels))
e=1e-8


method_args = []
for index, id_tensor in enumerate(batch_sequence_ids):
total_entropy = []
total_gini = []
inference_log = {"index": id_tensor.detach().item()}
if "loss" in features:
inference_log["loss"] = outputs.loss.detach().item() / len(labels[index])
if "ppl" in features:
inference_log["prompt_perplexity"] = perplexities[index][0]
inference_log["generation_perplexity"] = perplexities[index][1]
inference_log["sequence_perplexity"] = perplexities[index][2]
if "attn" in features:
for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach()
head_e = []
gini_head = []

for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
attention_head += e  # add 'e' to attention weights that are 0 to avoid a log-of-zero error when calculating entropy. Entropy = -∑(w * log(w))
gini_coefficient = gini(attention_head)
gini_head.append(gini_coefficient)
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_e.append(head_entropy)
inference_log[f"gini_head{head_index+1}_layer{layer_index+1}"] = gini_coefficient
inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy

avg_head = np.mean(head_e)
avg_head_gini = np.mean(gini_head)
total_entropy.append(avg_head)
total_gini.append(avg_head_gini)

average_entropy = np.mean(total_entropy)
average_gini = np.mean(total_gini)
inference_log[f"avg entropy"] = average_entropy
inference_log[f"avg gini"] = average_gini

inference_logs.append(inference_log)
method_args.append((labels, loss, attentions, features, perplexities, e, index, id_tensor))
# inference_log = get_inference_log(labels, outputs, features, perplexities, e, index, id_tensor)
# inference_logs.append(inference_log)

inference_logs = pool.starmap(get_inference_log, method_args)
torch.cuda.empty_cache()
del method_args
return inference_logs

def get_inference_log(labels, loss, attentions, features, perplexities, e, index, id_tensor):
total_entropy = []
total_gini = []
inference_log = {"index": id_tensor.detach().item()}
if "loss" in features:
inference_log["loss"] = loss.detach().item() / len(labels[index])
if "attn" in features:
for layer_index, attention_layer in enumerate(attentions):
get_layer_entropy(e, index, total_entropy, total_gini, inference_log, layer_index, attention_layer)

average_entropy = np.mean(total_entropy)
average_gini = np.mean(total_gini)
inference_log[f"avg entropy"] = average_entropy
inference_log[f"avg gini"] = average_gini
if "ppl" in features:
inference_log["prompt_perplexity"] = perplexities[index][0]
inference_log["generation_perplexity"] = perplexities[index][1]
inference_log["sequence_perplexity"] = perplexities[index][2]

return inference_log


def get_layer_entropy(e, index, total_entropy, total_gini, inference_log, layer_index, attention_layer):
sequence_attention = attention_layer[index].detach()
head_e = []
gini_head = []

for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
attention_head += e  # add 'e' to attention weights that are 0 to avoid a log-of-zero error when calculating entropy. Entropy = -∑(w * log(w))
gini_coefficient = gini(attention_head)
gini_head.append(gini_coefficient)
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_e.append(head_entropy)
inference_log[f"gini_head{head_index+1}_layer{layer_index+1}"] = gini_coefficient
inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy

avg_head = np.mean(head_e)
avg_head_gini = np.mean(gini_head)
total_entropy.append(avg_head)
total_gini.append(avg_head_gini)
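The per-head entropy above is the usual Shannon form, Entropy = -∑ w * log(w), summed over a head's whole attention matrix; a minimal sketch with illustrative values:

import numpy as np

e = 1e-8
attention_head = np.array([
    [0.70, 0.10, 0.10, 0.10],   # one row of attention weights per query position
    [0.25, 0.25, 0.25, 0.25],
])
attention_head = attention_head + e   # avoid log(0) for exactly-zero weights
head_entropy = -np.sum(attention_head * np.log(attention_head))
print(head_entropy)  # larger values indicate more diffuse attention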


def save_inference_log(split_name: str, run_id: str, dataset: pd.DataFrame, inference_logs: list):
"""Saves the accumilated inference log in a pandas dataframe

@@ -293,7 +319,6 @@ def parse_cli_args():
"--models",
type=str,
help=models_arg_help,
choices=models_args_default,
default=models_args_default,
)

@@ -335,6 +360,8 @@ def parse_cli_args():
default=None,
)

parser.add_argument("--batch_size", type=int, default=None, help="Batch size for inference")

return parser.parse_args()


@@ -359,8 +386,10 @@ def main():
for dataset in args.datasets if isinstance(args.datasets, list) else args.datasets.split(","):
split_name = f"{data_scheme}.{model_size}"
print(f"Collecting inferences for {split_name} on {dataset} dataset")
run_model_inferences(split_name, experiment_timestamp, dataset, args.features, args.sample_size)
batch_size = args.batch_size if args.batch_size is not None else get_batch_size(model_size)
run_model_inferences(split_name, experiment_timestamp, dataset, args.features, batch_size, args.sample_size)


if __name__ == "__main__":
multiprocessing.set_start_method("spawn")
main()