Load LLMs in FP16 for Faster Inference #10

Open
wants to merge 10 commits into base: master
104 changes: 53 additions & 51 deletions inference.py
@@ -53,7 +53,7 @@ def load_model(split_name):
isDeduped = split_name.startswith("deduped")
model = split_name.split("duped.")[-1]
corresponding_model = f"EleutherAI/pythia-{model}{'-deduped' if isDeduped else ''}"
return GPTNeoXForCausalLM.from_pretrained(corresponding_model, device_map="auto")
return GPTNeoXForCausalLM.from_pretrained(corresponding_model, device_map="auto", torch_dtype=torch.float16)


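The core change in this diff: loading the Pythia checkpoints with torch_dtype=torch.float16 halves the weight memory (about 2 bytes per parameter instead of 4) and speeds up GPU inference. Below is a minimal sketch of the same loading pattern, assuming transformers and accelerate are installed and a GPU is available; the checkpoint name and prompt are illustrative, not taken from this repository.

```python
# Minimal FP16 loading sketch (assumes transformers + accelerate installed
# and a GPU available; the checkpoint and prompt are illustrative).
import torch
from transformers import AutoTokenizer, GPTNeoXForCausalLM

model = GPTNeoXForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",
    device_map="auto",
    torch_dtype=torch.float16,  # half-precision weights: ~2 bytes per parameter
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")

inputs = tokenizer("The Pile is a large dataset", return_tensors="pt").to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits  # computed in float16
```

Because the logits now come back in float16, downstream softmax probabilities can round to zero, which is why the perplexity code below falls back to float64 in that case.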
def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.float64:
@@ -69,46 +69,49 @@ def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.float64:
"""
# Store the probabilities for each token. These will be summed later, but having the
# individual probabilities is helpful for debugging.
token_probs = []

# Don't include the final token logits. There are no labels for
# these since the sequence has ended.
num_special_tokens = len(labels[labels == 0])
num_normal_tokens = len(labels) - num_special_tokens

for token_index in range(num_normal_tokens - 1):
# Map the logits to probabilities.
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float16)
# Get the probability of the correct label.
label_prob = predicted_probs[labels[token_index + 1]]

# Check if the label probability is 0. This is likely due to a rounding error. Recalculate
# the probability using double precision.
if label_prob == 0:
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
try:
token_probs = []

# Don't include the final token logits. There are no labels for
# these since the sequence has ended.
num_special_tokens = len(labels[labels == 0])
num_normal_tokens = len(labels) - num_special_tokens

for token_index in range(num_normal_tokens - 1):
# Map the logits to probabilities.
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float16)
# Get the probability of the correct label.
label_prob = predicted_probs[labels[token_index + 1]]

# Store the probability for this token.
token_probs.append(label_prob.detach())
# Check if the label probability is 0. This is likely due to a rounding error. Recalculate
# the probability using double precision.
if label_prob == 0:
predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
label_prob = predicted_probs[labels[token_index + 1]]

mid_index = len(token_probs) // 2
prompt_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
cross_entropy = -log_likelihood / len(token_probs)
prompt_ppl = torch.exp(cross_entropy).item()
# Store the probability for this token.
token_probs.append(label_prob.detach())

generation_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
cross_entropy = -log_likelihood / len(token_probs)
generation_ppl = torch.exp(cross_entropy).item()
mid_index = len(token_probs) // 2
prompt_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[:mid_index])).sum()
cross_entropy = -log_likelihood / len(token_probs)
prompt_ppl = torch.exp(cross_entropy).item()

sequence_ppl = None
log_likelihood = torch.log(torch.stack(token_probs)).sum()
cross_entropy = -log_likelihood / len(token_probs)
sequence_ppl = torch.exp(cross_entropy).item()
generation_ppl = None
log_likelihood = torch.log(torch.stack(token_probs[mid_index:])).sum()
cross_entropy = -log_likelihood / len(token_probs)
generation_ppl = torch.exp(cross_entropy).item()

# assert perplexity != float("inf"), "Perplexity is infinite. This is probably due to a token that has a probability of 0."
return prompt_ppl, generation_ppl, sequence_ppl
sequence_ppl = None
log_likelihood = torch.log(torch.stack(token_probs)).sum()
cross_entropy = -log_likelihood / len(token_probs)
sequence_ppl = torch.exp(cross_entropy).item()

return prompt_ppl, generation_ppl, sequence_ppl
except Exception as e:
print(f"Failed to calulcate perplexity: {e}")
return -1, -1, -1


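For reference, the quantity computed above is PPL = exp(-(1/N) * sum(log p_i)) over the ground-truth token probabilities, reported separately for the first half of the sequence (prompt), the second half (generation), and the whole sequence. The following standalone sketch uses dummy probabilities; note that it normalizes each segment by its own length, whereas the code above divides every segment's log-likelihood by the full token count.

```python
# Standalone perplexity sketch: PPL = exp(-(1/N) * sum(log p_i)).
# `token_probs` stands in for the per-token label probabilities collected above;
# each segment is normalized by its own length here.
import torch

def segment_perplexity(probs: torch.Tensor) -> float:
    log_likelihood = torch.log(probs).sum()
    cross_entropy = -log_likelihood / len(probs)
    return torch.exp(cross_entropy).item()

token_probs = torch.tensor([0.25, 0.50, 0.125, 0.75], dtype=torch.float64)
mid = len(token_probs) // 2
prompt_ppl = segment_perplexity(token_probs[:mid])      # first half of the sequence
generation_ppl = segment_perplexity(token_probs[mid:])  # second half
sequence_ppl = segment_perplexity(token_probs)          # whole sequence
print(prompt_ppl, generation_ppl, sequence_ppl)
```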
def get_batch_size(model_name: str) -> int:
@@ -178,7 +181,7 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
pile_dataset = PileDataset(pile_sequences, tokenizer)
batch_size = get_batch_size(split_name)
data_loader = DataLoader(pile_dataset, batch_size=batch_size)

with torch.no_grad():
desc = f"Collecting {dataset} inference responses for {split_name}"
for batch in tqdm(data_loader, desc=desc):
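The with torch.no_grad(): wrapper added here keeps PyTorch from recording the autograd graph during the forward passes, which would otherwise retain activations and eat into the memory the FP16 weights just freed up. A rough sketch of the loop under that context manager, assuming each batch exposes an input_ids tensor (the real PileDataset batches may be shaped differently):

```python
# Sketch of inference under torch.no_grad(); `model` and `data_loader`
# are assumed to be set up as in run_model_inferences, and the batch
# layout here (a dict with "input_ids") is an assumption.
import torch

model.eval()
with torch.no_grad():
    for batch in data_loader:
        input_ids = batch["input_ids"].to(model.device)
        outputs = model(input_ids, output_attentions=True)
        # outputs.logits and outputs.attentions feed the logging below
```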
@@ -208,11 +211,11 @@ def gini(array):
# from: http://www.statsdirect.com/help/default.htm#nonparametric_methods/gini.htm
array = array.flatten()
if np.amin(array) < 0:
array -= np.amin(array)
array = np.sort(array)
index = np.arange(1,array.shape[0]+1)
n = array.shape[0]
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))
array -= np.amin(array)
array = np.sort(array)
index = np.arange(1,array.shape[0]+1)
n = array.shape[0]
return ((np.sum((2 * index - n - 1) * array)) / (n * np.sum(array)))


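The gini helper above implements the mean-difference form of the Gini coefficient, G = sum_i((2i - n - 1) * x_i) / (n * sum_i(x_i)) over the sorted values: roughly 0 for perfectly uniform attention and close to 1 when one position receives all the weight. A quick sanity check, assuming the gini function defined in this file is in scope:

```python
# Sanity check for gini(): a uniform attention row should score ~0,
# a one-hot row should score (n - 1) / n. Assumes gini() from
# inference.py is defined in this scope.
import numpy as np

uniform = np.full(8, 1.0 / 8)   # weight spread evenly over 8 positions
one_hot = np.zeros(8)
one_hot[3] = 1.0                # all weight on a single position

print(gini(uniform))   # ~0.0
print(gini(one_hot))   # 0.875 == (8 - 1) / 8
```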
def accumilate_inference_log(
@@ -231,7 +234,7 @@ def accumilate_inference_log(
perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
inference_logs = []
e=1e-8

for index, id_tensor in enumerate(batch_sequence_ids):
total_entropy = []
total_gini = []
@@ -243,21 +246,21 @@
inference_log["generation_perplexity"] = perplexities[index][1]
inference_log["sequence_perplexity"] = perplexities[index][2]
if "attn" in features:
for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach()
for layer_index, attention_layer in enumerate(outputs.attentions):
sequence_attention = attention_layer[index].detach()
head_e = []
gini_head = []

for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
for head_index, head in enumerate(sequence_attention):
attention_head = head.detach().cpu().numpy()
attention_head += e  # add 'e' to zero attention weights to avoid a log(0) error when calculating entropy: Entropy = -∑(w * log(w))
gini_coefficient = gini(attention_head)
gini_head.append(gini_coefficient)
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_entropy = -np.sum(attention_head * np.log(attention_head))
head_e.append(head_entropy)
inference_log[f"gini_head{head_index+1}_layer{layer_index+1}"] = gini_coefficient
inference_log[f"entropy_head{head_index+1}_layer{layer_index+1}"] = head_entropy

avg_head = np.mean(head_e)
avg_head_gini = np.mean(gini_head)
total_entropy.append(avg_head)
@@ -266,8 +269,8 @@
average_entropy = np.mean(total_entropy)
average_gini = np.mean(total_gini)
inference_log[f"avg entropy"] = average_entropy
inference_log[f"avg gini"] = average_gini
inference_log[f"avg gini"] = average_gini

inference_logs.append(inference_log)

return inference_logs
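Each attention head's weight matrix is scored with entropy H = -∑ w * log(w) after the small epsilon is added so zero weights do not blow up the log; lower entropy (and higher Gini) means the head attends more sharply. A small sketch of the per-layer aggregation, assuming one layer's attention for a single sequence has shape (num_heads, seq_len, seq_len) as returned with output_attentions=True; the random tensor stands in for real attention weights.

```python
# Sketch of the per-head entropy aggregation for a single layer.
# `layer_attention` is a stand-in for one element of outputs.attentions
# indexed for one sequence, shaped (num_heads, seq_len, seq_len).
import numpy as np
import torch

eps = 1e-8
layer_attention = torch.rand(4, 6, 6).softmax(dim=-1)  # dummy: 4 heads, seq_len 6

head_entropies = []
for head in layer_attention:
    weights = head.detach().cpu().numpy() + eps        # avoid log(0)
    head_entropies.append(-np.sum(weights * np.log(weights)))

print(np.mean(head_entropies))  # average entropy across heads in this layer
```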
@@ -293,7 +296,6 @@ def parse_cli_args():
"--models",
type=str,
help=models_arg_help,
choices=models_args_default,
default=models_args_default,
)
