Introduce parallel
Kyle1668 committed Sep 15, 2023
1 parent a22fa4c commit f56cf7a
Showing 1 changed file with 76 additions and 57 deletions.
inference.py
@@ -16,6 +16,7 @@
from datetime import datetime
import pandas as pd
import numpy as np
+import multiprocessing
import torch
import os

@@ -28,7 +29,7 @@ class PileDataset(Dataset):

    def __init__(self, memories, tokenizer):
        self.tokenizer = tokenizer
-        self.memories = memories
+        self.memories = memories.rename(columns={"index": "Index", "tokens": "Tokens"}) if "index" in memories.columns else memories

    def __getitem__(self, index):
        tokens = self.memories.iloc[index]["Tokens"][:64]
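For context on the new rename guard: it lets the dataset accept memory frames whose columns were saved in lower case. A minimal sketch of constructing the dataset with such a frame (the frame contents, tokenizer checkpoint, and import path are illustrative assumptions, not part of the commit):

import pandas as pd
from transformers import AutoTokenizer
from inference import PileDataset  # assumes inference.py is importable

# Hypothetical frame using the lower-case column variant the guard handles.
memories = pd.DataFrame({"index": [0, 1], "tokens": [[5, 6, 7], [8, 9, 10]]})
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
dataset = PileDataset(memories, tokenizer)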
@@ -54,7 +55,7 @@ def load_model(split_name):
    isDeduped = split_name.startswith("deduped")
    model = split_name.split("duped.")[-1]
    corresponding_model = f"EleutherAI/pythia-{model}{'-deduped' if isDeduped else ''}"
-    return GPTNeoXForCausalLM.from_pretrained(corresponding_model, device_map="auto", torch_dtype=torch.float16)
+    return GPTNeoXForCausalLM.from_pretrained(corresponding_model, device_map="auto")
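Note the dropped torch_dtype=torch.float16: the model now loads in its default precision, float32 for these checkpoints, roughly doubling weight memory. That is consistent with the halved batch sizes later in this diff. A usage sketch under the split-naming convention implied by the string handling above (import path assumed):

from inference import load_model  # assumes inference.py is importable

# "deduped.410m" resolves to EleutherAI/pythia-410m-deduped;
# "duped.410m" resolves to EleutherAI/pythia-410m.
model = load_model("deduped.410m")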


def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.float64:
@@ -80,15 +81,16 @@ def calculate_perplexity(logits: torch.Tensor, labels: torch.Tensor) -> torch.float64:

    for token_index in range(num_normal_tokens - 1):
        # Map the logits to probabilities.
-        predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float16)
+        # predicted_probs = torch.softmax(logits[token_index].view(torch.float64), dim=0, dtype=torch.float16)
+        predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
        # Get the probability of the correct label.
        label_prob = predicted_probs[labels[token_index + 1]]

        # Check if the label probability is 0. This is likely due to a rounding error. Recalculate
        # the probability using double precision.
-        if label_prob == 0:
-            predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
-            label_prob = predicted_probs[labels[token_index + 1]]
+        # if label_prob == 0:
+        #     predicted_probs = torch.softmax(logits[token_index], dim=0, dtype=torch.float64)
+        #     label_prob = predicted_probs[labels[token_index + 1]]

        # Store the probability for this token.
        token_probs.append(label_prob.detach())
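Each stored probability later feeds a perplexity of the form exp(-mean log p); the function appears to return prompt, generation, and sequence variants, judging by how the tuple is consumed further down. A minimal sketch of the core reduction, not the commit's exact code:

import torch

def perplexity(token_probs: torch.Tensor) -> float:
    # exp of the mean negative log-likelihood of the observed tokens
    return float(torch.exp(-torch.log(token_probs).mean()))

print(perplexity(torch.tensor([0.25, 0.5, 0.125])))  # ~4.0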
@@ -129,13 +131,13 @@ def get_batch_size(model_name: str) -> int:
"""
size_batch_map = {
"70m": 512,
"160m": 512,
"410m": 512,
"1b": 256,
"1.4b": 256,
"2.8b": 128,
"160m": 256,
"410m": 256,
"1b": 128,
"1.4b": 128,
"2.8b": 64,
"6.9b": 64,
"12b": 64,
"12b": 16,
}
return size_batch_map[model_name]
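The new map roughly halves each batch size and quarters it for 12b, which tracks the switch away from float16 weights in load_model above. A back-of-the-envelope check, assuming 4 bytes per float32 parameter:

for name, params in {"2.8b": 2.8e9, "12b": 12e9}.items():
    print(f"{name}: ~{params * 4 / 1e9:.0f} GB of weights in float32")
# 2.8b: ~11 GB, 12b: ~48 GB -- double the float16 footprint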

@@ -180,27 +182,36 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: list
    pile_dataset = PileDataset(pile_sequences, tokenizer)
    data_loader = DataLoader(pile_dataset, batch_size=batch_size)

-    with torch.no_grad():
-        desc = f"Collecting {dataset} inference responses for {split_name}"
-        for batch in tqdm(data_loader, desc=desc):
-            batch_sequences = batch[1]
-            tokenized_batch = tokenizer(
-                batch_sequences,
-                return_tensors="pt",
-                max_length=512,
-                truncation=True,
-                padding=True,
-            )
-            tokenized_batch.to(pythia_model.device)
-            labels = tokenized_batch["input_ids"]
-
-            outputs = pythia_model(
-                **tokenized_batch,
-                labels=tokenized_batch["input_ids"],
-                output_attentions=True,
-            )
-            inference_logs = accumilate_inference_log(batch[0], labels, outputs, features)
-            save_inference_log(split_name, run_id, dataset, inference_logs)
+    num_processes = multiprocessing.cpu_count()
+    # num_processes = 1
+    with multiprocessing.Pool(num_processes) as pool:
+        with torch.no_grad():
+            desc = f"Collecting {dataset} inference responses for {split_name}"
+            for batch in tqdm(data_loader, desc=desc):
+                batch_sequences = batch[1]
+                tokenized_batch = tokenizer(
+                    batch_sequences,
+                    return_tensors="pt",
+                    max_length=256,
+                    truncation=True,
+                    padding=True,
+                )
+                tokenized_batch.to(pythia_model.device)
+                labels = tokenized_batch["input_ids"]
+
+                outputs = pythia_model(
+                    **tokenized_batch,
+                    labels=tokenized_batch["input_ids"],
+                    output_attentions=True,
+                )
+                logits = outputs.logits.detach().cpu()
+                labels = labels.detach().cpu()
+                loss = outputs.loss.detach().cpu()
+                attentions = [attn_tensor.detach().cpu() for attn_tensor in outputs.attentions]
+
+                inference_logs = accumilate_inference_log(batch[0], labels, logits, loss, attentions, features, pool)
+                save_inference_log(split_name, run_id, dataset, inference_logs)
+                torch.cuda.empty_cache()
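The .detach().cpu() calls before the accumilate_inference_log hand-off are what make the new pool workable: arguments to a multiprocessing pool are pickled, plain CPU tensors serialize cleanly, and CUDA tensors would otherwise drag GPU state across process boundaries. A self-contained sketch of the pattern (shapes and worker body are illustrative):

import multiprocessing
import torch

def summarize(t: torch.Tensor) -> float:
    return t.mean().item()

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn")
    tensors = [torch.randn(8).cpu() for _ in range(4)]  # detached, host-memory tensors
    with multiprocessing.Pool(2) as pool:
        print(pool.map(summarize, tensors))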


def gini(array):
@@ -217,7 +228,7 @@ def gini(array):


def accumilate_inference_log(
-    batch_sequence_ids: list, labels: torch.Tensor, outputs: CausalLMOutputWithPast, features: list
+    batch_sequence_ids: list, labels: torch.Tensor, logits: torch.Tensor, loss: torch.Tensor, attentions: list[torch.Tensor], features: list, pool: multiprocessing.Pool
):
"""
Extract the desired data from the model response and save it to a CSV file.
@@ -228,36 +239,43 @@
        outputs (CausalLMOutputWithPast): The response from the Pythia model
        features (list): The list of features to calculate. A subset of [loss, ppl, attn]
    """
-    logits = outputs.logits.detach()
-    perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
    inference_logs = []
+    perplexities = [calculate_perplexity(logits[i], labels[i]) for i in range(len(logits))] if "ppl" in features else None
+    # perplexities = pool.starmap(calculate_perplexity, zip(logits, labels))
    e=1e-8

+    method_args = []
    for index, id_tensor in enumerate(batch_sequence_ids):
-        total_entropy = []
-        total_gini = []
-        inference_log = {"index": id_tensor.detach().item()}
-        if "loss" in features:
-            inference_log["loss"] = outputs.loss.detach().item() / len(labels[index])
-        if "ppl" in features:
-            inference_log["prompt_perplexity"] = perplexities[index][0]
-            inference_log["generation_perplexity"] = perplexities[index][1]
-            inference_log["sequence_perplexity"] = perplexities[index][2]
-        if "attn" in features:
-            # process_args = [layer_index, attention_layer for layer_index, attention_layer in enumerate(outputs.attentions)]
-            # p = Process(target=get_layer_entropy, args=(e, index, total_entropy, total_gini, inference_log, layer_index, attention_layer))
-            for layer_index, attention_layer in enumerate(outputs.attentions):
-                get_layer_entropy(e, index, total_entropy, total_gini, inference_log, layer_index, attention_layer)
-
-            average_entropy = np.mean(total_entropy)
-            average_gini = np.mean(total_gini)
-            inference_log[f"avg entropy"] = average_entropy
-            inference_log[f"avg gini"] = average_gini
-
-        inference_logs.append(inference_log)
+        method_args.append((labels, loss, attentions, features, perplexities, e, index, id_tensor))
+        # inference_log = get_inference_log(labels, outputs, features, perplexities, e, index, id_tensor)
+        # inference_logs.append(inference_log)

+    inference_logs = pool.starmap(get_inference_log, method_args)
+    torch.cuda.empty_cache()
+    del method_args
    return inference_logs

+def get_inference_log(labels, loss, attentions, features, perplexities, e, index, id_tensor):
+    total_entropy = []
+    total_gini = []
+    inference_log = {"index": id_tensor.detach().item()}
+    if "loss" in features:
+        inference_log["loss"] = loss.detach().item() / len(labels[index])
+    if "attn" in features:
+        for layer_index, attention_layer in enumerate(attentions):
+            get_layer_entropy(e, index, total_entropy, total_gini, inference_log, layer_index, attention_layer)
+
+        average_entropy = np.mean(total_entropy)
+        average_gini = np.mean(total_gini)
+        inference_log[f"avg entropy"] = average_entropy
+        inference_log[f"avg gini"] = average_gini
+    if "ppl" in features:
+        inference_log["prompt_perplexity"] = perplexities[index][0]
+        inference_log["generation_perplexity"] = perplexities[index][1]
+        inference_log["sequence_perplexity"] = perplexities[index][2]
+
+    return inference_log
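get_inference_log is a module-level function taking only picklable CPU arguments, which is exactly what pool.starmap requires: each argument tuple is pickled and unpacked positionally in a worker, with result order preserved. A toy example of the call shape:

from multiprocessing import Pool

def scale(value, factor):
    return value * factor

if __name__ == "__main__":
    with Pool(2) as pool:
        # Each tuple becomes one positional-argument call, order preserved.
        print(pool.starmap(scale, [(1, 10), (2, 10), (3, 10)]))  # [10, 20, 30]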


def get_layer_entropy(e, index, total_entropy, total_gini, inference_log, layer_index, attention_layer):
    sequence_attention = attention_layer[index].detach()
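The body of get_layer_entropy is collapsed in this view; judging from the accumulators it fills and the e=1e-8 guard, it computes a Shannon entropy (and a Gini coefficient) over each sequence's attention weights. A hypothetical reconstruction of the entropy part only, not the commit's code:

import torch

def attention_entropy(attention_row: torch.Tensor, e: float = 1e-8) -> float:
    # Shannon entropy of one attention distribution; e guards against log(0).
    probs = attention_row / (attention_row.sum() + e)
    return float(-(probs * torch.log(probs + e)).sum())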
@@ -373,4 +391,5 @@ def main():


if __name__ == "__main__":
+    multiprocessing.set_start_method("spawn")
    main()
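The "spawn" start method matters here: on Linux the default is "fork", and a forked worker inherits the parent's CUDA context, which CUDA does not support. One caveat worth knowing: set_start_method raises RuntimeError if a start method was already set, so a defensive variant (an assumption, not in the commit) is

if __name__ == "__main__":
    multiprocessing.set_start_method("spawn", force=True)  # force=True tolerates re-initialization
    main()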
