Skip to content

Commit

Permalink
Update Readme on memorization replication
Browse files Browse the repository at this point in the history
  • Loading branch information
uSaiPrashanth committed Nov 2, 2023
1 parent 34e5c6b commit ec436e5
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
8 changes: 6 additions & 2 deletions predictable-memorization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
This folder documents our work using Pythia to study memorization of particular sequences in the training dataset, and includes instructions to reproduce our analyses where possible.

## Reproducing Memorization Results
The memorization evaluation script `memorization/eval_memorization.py` assumes that you are running the script in a distributed process, ideally in slurm. If you want to reproduce the evaluation, consider the following steps.
The memorization evaluation script `memorization/eval_memorization.py` assumes that you are running the script in a distributed process, ideally in slurm. It also assumes that you are using s3 to load and save pile's preshuffled datasets (refer [here](https://github.com/EleutherAI/pythia/blob/main/README.md#dataset-viewer) for more details on how to download them).

1. Change `prefix` and `idx_path` local variables of `generate_function()` to point to the right document and index path.
If you want to reproduce the evaluation, consider the following steps.

1. Change `prefix` local variable of `generate_function()` to point to the right document path.

2. If you are not using [Slurm](https://slurm.schedmd.com/documentation.html), You need to change global variables inside the script, like `RANK` and `NUM_PROCS` (world size) to point to the right environment variables.

Expand All @@ -23,9 +25,11 @@ The memorization evaluation script `memorization/eval_memorization.py` assumes t

## Reproducing Figures

Refer to `memorization/eda.ipynb` for details on replication

## Reproducing Scaling Laws Plots

Refer to `memorization/eda.ipynb` for details on replication

## Citation Details

Expand Down
54 changes: 35 additions & 19 deletions predictable-memorization/eval_memorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils", "gpt-neox"))

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "utils"))
from mmap_dataset import MMapIndexedDataset
import logging
import time
import datetime
Expand All @@ -19,7 +20,10 @@
import time
from tqdm import trange

def generate_dataset(batch_size, start_seq_idx, end_seq_idx, mp_queue, prefetch_max = 128):
def generate_dataset(batch_size, start_seq_idx, end_seq_idx, mp_queue,
using_s3 = False,
prefetch_max = 128
):
"""Wrapper function to prefetch pile sequences
Intended to run in a saperate `multiprocessing.Process`, this function will continuously prefetch
Expand All @@ -30,6 +34,7 @@ def generate_dataset(batch_size, start_seq_idx, end_seq_idx, mp_queue, prefetch_
start_seq_idx (int): Sequence index of first sequence to be evaluated by current rank
end_seq_idx (int): Sequence index of last sequence to be evalauted by current rank
mp_queue (multiprocessing.Queue): Instance of multiprocessing Queue, to add sequences into
using_s3 (bool): If your datasets are located in s3, set this to true
prefetch_max (int): Maximum number of sequences that can be pre-fetched into the queue
Env Vars:
Expand All @@ -38,27 +43,32 @@ def generate_dataset(batch_size, start_seq_idx, end_seq_idx, mp_queue, prefetch_
"""

# Load Pile dataset
prefix = 'orz/pile/standard/document.bin'
prefix = '/scratch/pile/standard/document.bin'
if "deduped" in os.environ['MODEL']:
prefix = 'orz/pile/deduped/document.bin'
s3 = boto3.client('s3')
buff_size = 2049*1024*2
buff_size = 2049*batch_size*2
if using_s3 == False:
mmap_ds = MMapIndexedDataset(prefix, skip_warmup=True)

# Iterate over pile and add sequences to mp_queue
context_tokens = []
true_continuation = []
i = 0
for i in range(start_seq_idx, end_seq_idx + 1, buff_size // (2049*2)):
dataset = s3.get_object(
Bucket = 's-eai-neox-west',
Key = prefix,
Range = f'bytes={i*2049*2}-{i*2049*2 + buff_size}'
)
data = dataset['Body'].read(buff_size)
data = np.frombuffer(data, dtype = np.uint16).reshape(-1, 2049)
for i in range(start_seq_idx, end_seq_idx + 1, batch_size):
if using_s3:
dataset = s3.get_object(
Bucket = os.environ['BUCKET'],
Key = prefix,
Range = f'bytes={i*2049*2}-{i*2049*2 + buff_size}'
)
data = dataset['Body'].read(buff_size)
data = np.frombuffer(data, dtype = np.uint16).reshape(-1, 2049)
else:
data = mmap_ds[i:i+batch_size]
context_tokens.extend(data[:, :32].tolist())
true_continuation.extend(data[:,32:64].tolist())
i += buff_size // (2049*2)
i += len(context_tokens)

if len(context_tokens) == batch_size:
# (start index of batch, context tokens, true continuation)
Expand Down Expand Up @@ -95,8 +105,8 @@ def score(model, context_tokens, true_continuation):
accuracies (torch.Tensor): Accuracies of shape (batch_size,)
"""
with torch.no_grad():
context_tokens = torch.tensor(context_tokens)
true_continuation = torch.tensor(true_continuation)
context_tokens = torch.tensor(context_tokens).to('cuda')
true_continuation = torch.tensor(true_continuation).to('cuda')

generations = model.generate(context_tokens, temperature = 0.0, top_k = 0, top_p = 0, max_length = 64, min_length = 64)

Expand All @@ -114,23 +124,29 @@ def main():
LOCAL_RANK = int(os.environ['SLURM_LOCALID'])
NUM_PROCS = int(os.environ['SLURM_NPROCS'])

RANK = int(os.environ['RANK'])
LOCAL_RANK = RANK
NUM_PROCS = int(os.environ['WORLD_SIZE'])

# Eval configuration variables
MODEL = os.environ['MODEL']
CHECKPOINT = int(os.environ['CHECKPOINT'])

# Distributed initializations
os.environ['MASTER_ADDR'] = os.environ['SLURM_LAUNCH_NODE_IPADDR']
os.environ['MASTER_PORT'] = '12128'
# os.environ['MASTER_ADDR'] = os.environ['SLURM_LAUNCH_NODE_IPADDR']
# os.environ['MASTER_PORT'] = '12128'
logging.basicConfig(format = f'rank-{RANK}:' + '%(levelname)s:%(message)s', level = logging.INFO)
logging.info(f"Initializing torch distributed with gpus {torch.cuda.device_count()}")

# Initialize torch distributed
torch.cuda.set_device(RANK)
dist.init_process_group(
"nccl",
world_size = NUM_PROCS,
rank = RANK
)
store = dist.TCPStore(os.environ['MASTER_ADDR'], 12125, world_size = NUM_PROCS, is_master = RANK == 0, timeout = datetime.timedelta(hours=3))
store = dist.TCPStore(os.environ['MASTER_ADDR'], port = 12125,
world_size = NUM_PROCS, is_master = RANK == 0, timeout = datetime.timedelta(hours=3))

dist.barrier()

Expand Down Expand Up @@ -192,7 +208,7 @@ def main():
s3 = boto3.client('s3')
s3.put_object(
Body = '\n'.join(memorization_evals).encode(),
Bucket = 's-eai-neox-west',
Bucket = os.environ['Bucket'],
Key = f'memorization-evals/evals-running/memorization_{MODEL}_{CHECKPOINT}/rank-{RANK}.csv'
)
dist.barrier()
Expand Down
3 changes: 3 additions & 0 deletions utils/mmap_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ def __init__(self, path, skip_warmup=False):
self._index = None
self._bin_buffer = None

if path.endswith(".bin") or path.endswith(".idx"):
path = path[:-4]

self._do_init(path, skip_warmup)

def __getstate__(self):
Expand Down

0 comments on commit ec436e5

Please sign in to comment.