Skip to content

Commit

Permalink
Have multiple options
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle1668 committed Sep 15, 2023
1 parent feda73f commit f6da96b
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


!datasets/eval/Pythia_70m_Deduped_Low_Perplexity_Labeling_Formatted.csv

*.pt
*.zip
.vscode
### Data ###
Expand Down
12 changes: 10 additions & 2 deletions inference_outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,16 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, batch_size:



results = p.map(parse_attn, [t.detach().cpu() for t in outputs.attentions])
print(results)
# results = p.map(parse_attn, [t.detach().cpu() for t in outputs.attentions])
# print(results)

# attentions_table = {}
for i in tqdm(range(len(batch[0]))):
current_example_id = batch[0][i]
current_example_attentions = torch.stack(outputs.attentions)[:, i, :]
# attentions_table[current_example_id] = current_example_attentions
torch.save(current_example_attentions, f"datasets/{run_id}/{dataset}_attentions_{current_example_id}.pt")
# print(current_example_attentions.shape)

# inference_logs = pd.DataFrame({
# "Loss": outputs.loss.detach().cpu().tolist(),
Expand Down
2 changes: 1 addition & 1 deletion inference_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def get_batch_size(model_name: str) -> int:
"2.8b": 128,
# Large
"6.9b": 64,
"12b": 64,
"12b": 32,
}
model_size = ".".join(model_name.split(".")[1:])
return size_batch_map[model_size]
Expand Down

0 comments on commit f6da96b

Please sign in to comment.