Skip to content

Commit

Permalink
Updated files to handle new schema
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle1668 committed Sep 14, 2023
1 parent 96053f7 commit a22fa4c
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def __init__(self, memories, tokenizer):
self.memories = memories

def __getitem__(self, index):
tokens = self.memories.iloc[index]["tokens"][:64]
tokens = self.memories.iloc[index]["Tokens"][:64]
decoded_text = self.tokenizer.decode(tokens)
return self.memories.iloc[index]["index"], decoded_text
return self.memories.iloc[index]["Index"], decoded_text

def __len__(self):
return len(self.memories["index"])
return len(self.memories["Index"])


def load_tokenizer(split_name: str) -> AutoTokenizer:
Expand Down Expand Up @@ -157,7 +157,7 @@ def get_dataset(dataset_name: str, split_name: str, sample: int = None) -> pd.Da
if dataset_name.split("-")[0] == "pile":
scheme = split_name.split(".")[0]
pile_path = f"EleutherAI/pile-{scheme}-pythia-random-sampled"
dataset = load_dataset(pile_path, split="train").to_pandas()[["index", "tokens"]]
dataset = load_dataset(pile_path, split="train").to_pandas()[["Index", "Tokens"]]
else:
dataset = load_dataset("EleutherAI/pythia-memorized-evals")[split_name].to_pandas()

Expand All @@ -174,7 +174,6 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
dataset (str): The dataset to run inference on
sample_size (int, optional): The maximum number of random samples run inference on. Defaults to None.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = load_tokenizer(split_name)
pythia_model = load_model(split_name)
pile_sequences = get_dataset(dataset, split_name, sample=sample_size)
Expand All @@ -192,7 +191,7 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
truncation=True,
padding=True,
)
tokenized_batch.to(device)
tokenized_batch.to(pythia_model.device)
labels = tokenized_batch["input_ids"]

outputs = pythia_model(
Expand Down

0 comments on commit a22fa4c

Please sign in to comment.