forked from EleutherAI/gpt-neox
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
50 changed files
with
1,287 additions
and
2,182 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ This repository records [EleutherAI](www.eleuther.ai)'s work-in-progress for tra | |
|
||
We aim to make this repo a centralized and accessible place to gather techniques for training large scale autoregressive language models, and accelerate research into large scale training. Additionally, we hope to train and open source a 175B parameter GPT3 replication along the way. | ||
|
||
For more info on our progress, please [join our discord](https://discord.com/invite/vtRgjbM) and head to the `#gpt-neo` channel. We're working with cloud compute provider [Coreweave](https://www.coreweave.com/) for training, and hope to release the weights of smaller models as we progress up to 175B parameters. | ||
For more info on our progress, please [join our discord](https://discord.gg/zBGx3azzUn) and head to the `#gpt-neo` channel. We're working with cloud compute provider [Coreweave](https://www.coreweave.com/) for training, and hope to release the weights of smaller models as we progress up to 175B parameters. | ||
|
||
If you're looking for our TPU codebase, see [GPT-Neo](https://github.com/EleutherAI/gpt-neo). | ||
|
||
|
@@ -17,13 +17,13 @@ GPT-NeoX is under active development. | |
|
||
### 3D Parallelism | ||
|
||
- GPTNeoX offers full 3D parallelism (data, model and pipeline parallel) using deepspeed, allowing you to scale model training to hundreds of billions of parameters across multiple GPUs. | ||
- GPTNeoX offers full 3D parallelism (data, model and pipeline parallel) using DeepSpeed, allowing you to scale model training to hundreds of billions of parameters across multiple GPUs. | ||
|
||
### Model Structure | ||
|
||
- **Positional Encodings:** | ||
|
||
- Choose between T5 RPE style positional encodings, a learned encoding added to the input (GPT2-style), Sinusoidal positional encoding, and no positional encodings at all (which [recent](https://arxiv.org/abs/1905.04226) [research](https://arxiv.org/abs/2102.11174) has found to even outperform other positional encodings in autoregressive models). | ||
- Choose between T5 RPE style positional encodings, a learned encoding added to the input (GPT2-style), Sinusoidal positional encoding, [rotary positional encodings](https://arxiv.org/abs/2104.09864), and no positional encodings at all (which [recent](https://arxiv.org/abs/1905.04226) [research](https://arxiv.org/abs/2102.11174) has found to even outperform other positional encodings in autoregressive models). | ||
|
||
- **Sparsity:** | ||
|
||
|
@@ -45,15 +45,14 @@ We offer a choice of layernorm, scalenorm and RMSNorm easily configured by chang | |
|
||
### Straightforward configuration | ||
|
||
- Other libraries such as Megatron-LM require you configure them using command line arguments, which can often be difficult to work with and iterate upon. We offer straightforward configuration using .yaml files, which enables you to launch training runs across 100s of GPUs with a single line bash script. | ||
- Other libraries such as Megatron-LM require you configure them using command line arguments and global variables, which can often be difficult to work with and iterate upon. We offer straightforward configuration using .yaml files, which enables you to launch training runs across 100s of GPUs with a single line bash script. | ||
- Additionally, we hope to make data preparation easier on the user by providing scripts to automatically download and pretokenize a number of large-scale datasets. | ||
|
||
## Getting Started | ||
|
||
Our codebase relies on [DeeperSpeed](https://github.com/EleutherAI/DeeperSpeed), our fork of the [DeepSpeed](https://github.com/microsoft/DeepSpeed) library with some added changes. | ||
We strongly recommend using Anaconda, a virtual machine, or some other form of environment isolation before installing from `requirements.txt`. Failure to do so may cause other repositories that rely on DeepSpeed to break. Python 3.8 or later is required. | ||
Our codebase relies on [DeeperSpeed](https://github.com/EleutherAI/DeeperSpeed), our fork of the [DeepSpeed](https://github.com/microsoft/DeepSpeed) library with some added changes. We strongly recommend using Anaconda, a virtual machine, or some other form of environment isolation before installing from `requirements/requirements.txt`. Failure to do so may cause other repositories that rely on DeepSpeed to break. Python 3.8 or later is required. | ||
|
||
First make sure you are in an environment with `torch>=1.7.1` installed. Then run `pip install -r requirements.txt`. | ||
First make sure you are in an environment with `torch>=1.8` installed. Then run `pip install -r requirements/requirements.txt`. | ||
You may need to change the version of `cupy-cudaxxx` to match your machine's cuda version. | ||
|
||
Finally, certain features rely on apex, which you can install with the command below: | ||
|
@@ -171,17 +170,41 @@ This will deploy the `pretrain_gpt2.py` script on all nodes with one process per | |
|
||
EleutherAI is currently using [Weights & Biases to record experiments](https://wandb.ai/eleutherai/neox). If you are logged into Weights & Biases on your machine - you can do this by executing `wandb login` - your runs will automatically be recorded. Additionally, set the config parameter `wandb_team` if you would like the run to be added to an organisation/team account. | ||
|
||
We also support using Tensorboard via the `tensorboard-dir` argument. To use tensorboard, install the optional packages found at `requirements/requirements-tensorboard.txt` | ||
|
||
## Inference | ||
|
||
[WIP] | ||
|
||
## Eleuther Cluster | ||
## Evaluation | ||
|
||
[WIP] | ||
|
||
## Distilling | ||
|
||
[WIP] | ||
|
||
## Citing GPT-NeoX | ||
|
||
|
||
We run our experiments on a Kubernetes cluster generously provided by [CoreWeave](https://coreweave.com/). The `/kubernetes/` directory contains code designed to facilitate work on our server. If you are an EleutherAI member, see the [corresponding read-me](kubernetes) for information about how to use our cluster. | ||
### Citing | ||
|
||
If you have found GPT-Neo helpful in your work, you can cite this repository as | ||
|
||
``` | ||
@software{gpt-neo, | ||
author = {Andonian, Alex and Biderman, Stella and Black, Sid and Gali, Preetham and Gao, Leo and Hallahan, Eric and Levy-Kramer, Josh and Leahy, Connor and Nestler, Lucas and Parker, Kip and Pieler, Michael and Purohit, Shivanshu and Songz, Tri and Wang, Phil and Weinbach, Samuel}, | ||
title = {{GPT-NeoX}: Large Scale Autoregressive Language Modeling in PyTorch}, | ||
url = {http:https://github.com/eleutherai/gpt-neox}, | ||
year = {2021} | ||
} | ||
``` | ||
|
||
In the above bibtex entry, names are in alphabetical order, and the year corresponds to the project's open-source release. | ||
|
||
## Licensing | ||
|
||
This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright (c) 2021, EleutherAI contributors (in alphabetical order): Stella Biderman, Sid Black, Eric Hallahan, Josh Levy-Kramer, Michael Pieler, Shivanshu Purohit. Licensed under the Apache License: | ||
This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyright (c) 2021, EleutherAI contributors (in alphabetical order): Alex Andonian, Stella Biderman, Sid Black, Preetham Gali, Leo Gao, Eric Hallahan, Josh Levy-Kramer, Connor Leahy, Lucas Nestler, Kip Parker, Michael Pieler, Shivanshu Purohit, Tri Songz, Phil Wang, Samuel Weinbach. Licensed under the Apache License: | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
|
@@ -198,3 +221,7 @@ This repository hosts code that is part of EleutherAI's GPT-NeoX project. Copyri | |
This repository is based off code written by NVIDIA that is licensed under the Apache License, Version 2.0. In accordance with the Apache License, all files that are modifications of code originally written by NVIDIA maintain a NVIDIA copyright header. All files that do not contain such a header are original to EleutherAI contributors. When the NVIDIA code has been modified from its original version, that fact is noted in the copyright header. All derivative works of this repository must preserve these headers under the terms of the Apache License. | ||
|
||
For full terms, see the `LICENSE` file. If you have any questions, comments, or concerns about licensing please email us at [email protected]. | ||
|
||
## Acknowledgements | ||
|
||
We run our experiments on a Kubernetes cluster generously provided by [CoreWeave](https://coreweave.com/). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,193 @@ | ||
from megatron.utils import is_local_main | ||
import best_download | ||
|
||
# patch best_download (eval harness downloader) to only happen on the first rank | ||
fn = best_download.download_file | ||
|
||
def _download_file(*args, **kwargs): | ||
if is_local_main(): | ||
fn(*args, **kwargs) | ||
|
||
best_download.download_file = _download_file | ||
|
||
import os | ||
import sys | ||
from functools import partial | ||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), | ||
os.path.pardir))) | ||
from megatron.utils import print_rank_0 | ||
from tqdm import tqdm | ||
import torch | ||
from lm_eval.base import CacheHook | ||
from lm_eval.models.gpt2 import GPT2LM | ||
from lm_eval import tasks, evaluator, utils | ||
import torch.nn.functional as F | ||
from megatron.text_generation_utils import generate_samples_from_prompt | ||
import inspect | ||
from lm_eval import tasks | ||
from lm_eval.utils import chunks | ||
|
||
# TODO: add data parallel | ||
|
||
class EvalHarnessAdaptor(GPT2LM): | ||
|
||
def __init__(self, model, forward_step_fn, neox_args, batch_size=None): | ||
self.device = torch.device(f'cuda:{neox_args.local_rank}') | ||
self.VOCAB_SIZE = neox_args.padded_vocab_size | ||
self.tokenizer = neox_args.tokenizer | ||
self.EOT_TOKEN_ID = neox_args.tokenizer.eod_id | ||
self.model = model | ||
self._forward_step_fn = partial(forward_step_fn, neox_args=neox_args, timers=None, return_logits=True) | ||
self.max_length = neox_args.max_position_embeddings // 2 | ||
self.max_gen_toks = 128 | ||
self.tokenizer.encode = self.tokenizer.tokenize # patch tokenizer encode + decode methods | ||
self.tokenizer.decode = self.tokenizer.detokenize | ||
self.batch_size = batch_size or neox_args.batch_size | ||
self.neox_args = neox_args | ||
self.cache_hook = CacheHook(None) | ||
self.is_main = neox_args.rank == 0 | ||
self.is_local_main = neox_args.local_rank == 0 | ||
self.is_pipe_parallel = self.model.is_pipe_parallel | ||
self.is_data_parallel = self.model.is_data_parallel | ||
self.is_last_stage = True if not self.is_pipe_parallel else model.is_last_stage() # only the last stage of the pipeline model will receive the logits | ||
self.generate = partial(generate_samples_from_prompt, neox_args=neox_args, model=model, maximum_tokens=self.max_gen_toks) | ||
|
||
|
||
def greedy_until(self, requests, batch=False): | ||
self.model.module.inference_mode() | ||
res = [] | ||
|
||
def _collate(x): | ||
toks = self.tokenizer.encode(x[0]) | ||
return (len(toks), x[0]) | ||
|
||
reord = utils.Reorderer(requests, _collate) | ||
for context, until in tqdm(reord.get_reordered()): | ||
if isinstance(until, str): | ||
until = [until] | ||
stop_tokens = [self.tokenizer.encode(i) for i in until] | ||
cont = self.generate(text=context, | ||
stop_tokens=stop_tokens, | ||
recompute = self.neox_args.recompute) | ||
|
||
s = cont[0]['text'] or '' | ||
|
||
for term in until: | ||
s = s.split(term)[0] | ||
|
||
# partial caching | ||
self.cache_hook.add_partial("greedy_until", (context, until), s) | ||
|
||
res.append(s) | ||
|
||
self.model.module.train_mode() | ||
return reord.get_original(res) | ||
|
||
def _loglikelihood_tokens(self, requests, disable_tqdm=False): | ||
disable_tqdm = disable_tqdm if self.is_main else True | ||
res = [] | ||
res_len = 0 # storing the result length for later | ||
with torch.no_grad(): | ||
|
||
def _collate(x): | ||
toks = x[1] + x[2] | ||
return (-len(toks), tuple(toks)) | ||
|
||
reord = utils.Reorderer(requests, _collate) | ||
for chunk in utils.chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), self.batch_size): | ||
inps, contlens, inplens, padding_length = [], [], [], None | ||
for _, context_enc, continuation_enc in chunk: | ||
|
||
# when too long to fit in context, truncate from the left | ||
inp = torch.tensor( | ||
(context_enc + continuation_enc)[-(self.max_length + 1):][:-1] | ||
, dtype=torch.long).to(self.device) | ||
inplen, = inp.shape | ||
|
||
cont = continuation_enc | ||
|
||
# since in _collate we make sure length is descending, the longest is always the first one. | ||
padding_length = padding_length if padding_length is not None else inplen | ||
|
||
# pad to length | ||
inp = torch.cat([ | ||
inp, # [seq] | ||
torch.zeros(padding_length - inplen, dtype=torch.long).to(inp.device) # [padding_length - seq] | ||
], dim=0) | ||
|
||
inps.append(inp.unsqueeze(0)) | ||
contlens.append(cont) | ||
inplens.append(inplen) | ||
|
||
logits = self._model_call(torch.cat(inps, dim=0)) | ||
res_len += len(chunk) | ||
|
||
if logits is not None: | ||
multi_logits = F.log_softmax(logits, dim=-1) # [batch, seq, vocab] | ||
for (cache_key, _, _), logits, inp, inplen, cont_toks in zip(chunk, multi_logits, inps, inplens, | ||
contlens): | ||
contlen = len(cont_toks) | ||
logits = logits[inplen - contlen:inplen].unsqueeze(0) # [1, seq, vocab] | ||
greedy_tokens = logits.argmax(dim=-1) | ||
# cont_toks :: [1, seq] | ||
cont_toks = torch.tensor(cont_toks, dtype=torch.long).unsqueeze(0).to(multi_logits.device) | ||
max_equal = (greedy_tokens == cont_toks).all() | ||
logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) # [1, seq] | ||
answer = (float(logits.sum()), bool(max_equal)) | ||
|
||
# partial caching | ||
if cache_key is not None: | ||
self.cache_hook.add_partial("loglikelihood", cache_key, answer) | ||
|
||
res.append(answer) | ||
|
||
# broadcast results to all ranks | ||
if self.is_pipe_parallel: | ||
src_rank = self.model.grid.stage_to_global(self.model.num_stages - 1) | ||
if res: | ||
logits_sums, max_equals = list(zip(*res)) | ||
logits_sums = torch.FloatTensor(logits_sums).cuda() | ||
max_equals = torch.LongTensor(max_equals).cuda() | ||
else: | ||
logits_sums = torch.zeros(res_len, dtype=torch.float32).cuda() | ||
max_equals = torch.zeros(res_len, dtype=torch.int64).cuda() | ||
torch.distributed.broadcast(tensor=logits_sums, src=src_rank) | ||
torch.distributed.broadcast(tensor=max_equals, src=src_rank) | ||
max_equals = [bool(i) for i in max_equals.tolist()] | ||
logits_sums = logits_sums.tolist() | ||
res = list(zip(logits_sums, max_equals)) | ||
|
||
return reord.get_original(res) | ||
|
||
def _model_call(self, inps): | ||
data_wrapped = iter([{'text': F.pad(inps, pad=(0, 1))}]) | ||
if self.neox_args.is_pipe_parallel: | ||
# need these flags to stop deepspeed from hanging | ||
self.model.first_output_send = True | ||
self.model.pipe_recv_buf = None | ||
_, logits = self._forward_step_fn(model=self.model, data_iterator=data_wrapped) | ||
return logits | ||
|
||
def run_eval(self, eval_tasks=None): | ||
was_training = self.model.training | ||
self.model.eval() | ||
in_micro_batches = self.model.micro_batches # store input microbatches - we need to set to 1 during eval | ||
self.model.micro_batches = 1 | ||
if eval_tasks is None: | ||
eval_tasks = ["lambada", "piqa", "hellaswag", "winogrande", "mathqa", "pubmedqa"] | ||
results = evaluator.evaluate(lm=self, | ||
task_dict=tasks.get_task_dict(eval_tasks), | ||
provide_description=False, | ||
num_fewshot=0, | ||
limit=None, | ||
bootstrap_iters=2).get('results') | ||
if was_training: | ||
self.model.train() | ||
self.model.micro_batches = in_micro_batches | ||
return results | ||
|
||
|
||
def run_eval_harness(model, forward_step_fn, neox_args, batch_size=None, eval_tasks=None): | ||
print_rank_0('Running evaluation harness...') | ||
adaptor = EvalHarnessAdaptor(model, forward_step_fn, neox_args, batch_size) | ||
return adaptor.run_eval(eval_tasks=eval_tasks) |
Oops, something went wrong.