Skip to content

Commit

Permalink
Created configurable batch sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
Kyle1668 committed Aug 9, 2023
1 parent f4b02c0 commit 83f2c2a
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 3 deletions.
8 changes: 5 additions & 3 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_dataset(dataset_name: str, split_name: str, sample: int = None) -> pd.Da
return dataset if sample is None else dataset.sample(sample).reset_index(drop=True)


def run_model_inferences(split_name: str, run_id: str, dataset: str, features: list, sample_size: int = None):
def run_model_inferences(split_name: str, run_id: str, dataset: str, features: list, batch_size: int, sample_size: int = None):
"""
Run inference for the given model and dataset. Save the results to a CSV file.
Expand All @@ -180,7 +180,6 @@ def run_model_inferences(split_name: str, run_id: str, dataset: str, features: l
pythia_model = load_model(split_name)
pile_sequences = get_dataset(dataset, split_name, sample=sample_size)
pile_dataset = PileDataset(pile_sequences, tokenizer)
batch_size = get_batch_size(split_name)
data_loader = DataLoader(pile_dataset, batch_size=batch_size)

with torch.no_grad():
Expand Down Expand Up @@ -345,6 +344,8 @@ def parse_cli_args():
default=None,
)

parser.add_argument("--batch_size", type=int, default=None, help="Batch size for inference")

return parser.parse_args()


Expand All @@ -369,7 +370,8 @@ def main():
for dataset in args.datasets if isinstance(args.datasets, list) else args.datasets.split(","):
split_name = f"{data_scheme}.{model_size}"
print(f"Collecting inferences for {split_name} on {dataset} dataset")
run_model_inferences(split_name, experiment_timestamp, dataset, args.features, args.sample_size)
batch_size = args.batch_size if args.batch_size is not None else get_batch_size(model_size)
run_model_inferences(split_name, experiment_timestamp, dataset, args.features, batch_size, args.sample_size)


if __name__ == "__main__":
Expand Down
108 changes: 108 additions & 0 deletions working_dirs/kyle/upload-ppl/upload.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"from tqdm import tqdm\n",
"from huggingface_hub import HfApi\n",
"from datasets import load_dataset, Dataset"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"RepoUrl('https://huggingface.co/datasets/Kyle1668/pythia-semantic-memorization-perplexities', endpoint='https://huggingface.co', repo_type='dataset', repo_id='Kyle1668/pythia-semantic-memorization-perplexities')"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"api = HfApi()\n",
"api.create_repo(\"pythia-semantic-memorization-perplexities\", repo_type=\"dataset\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Uploading pile_duped_12b.csv...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"pile_duped_12b.csv: 100%|██████████| 207M/207M [00:03<00:00, 52.6MB/s]\n",
"100%|██████████| 1/1 [00:04<00:00, 4.88s/it]\n"
]
}
],
"source": [
"deduped_path = \"/home/kyle/repos/semantic-memorization/datasets/2023-08-07_22-32-14\"\n",
"for file_name in tqdm(os.listdir(deduped_path)):\n",
" print(f\"Uploading {file_name}...\")\n",
" \n",
" api.upload_file(\n",
" path_or_fileobj=f\"{deduped_path}/{file_name}\",\n",
" path_in_repo=f\"{file_name}\",\n",
" repo_id=\"Kyle1668/pythia-semantic-memorization-perplexities\",\n",
" repo_type=\"dataset\",\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "memorization",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 83f2c2a

Please sign in to comment.