Got an empty gpt2-tokenizer while pretraining with THE-PILE dataset #876

Closed · Label: bug (Something isn't working)
LostSpirit1307 opened this issue Apr 8, 2023 · 4 comments

@LostSpirit1307

Environment:

  • Ubuntu 20.04
  • CUDA Version: 11.6
  • conda version: 4.9.2
  • python version: 3.8.5
  • 4x RTX 3090 GPUs

Description:

I tried to use part of The Pile dataset to pretrain a toy model, following the README.md.

The problem occurred when I ran train.py.

The whole process was as follows:

  1. I ran the preprocessing script (tools/preprocess_data.py):
python tools/preprocess_data.py \
    --input ./nlpdata/00.jsonl.zst \
    --output-prefix ./nlpdata_processed \
    --vocab ./20B_checkpoints/20B_tokenizer.json \
    --merge-file gpt2-merges.txt \
    --dataset-impl mmap \
    --tokenizer-type HFTokenizer \
    --append-eod

The input data file was 00.jsonl.zst from The Pile.

This generated four files:

nlpdata_processed_text_document.bin
nlpdata_processed_text_document.idx
gpt2-merges.txt (empty)
gpt2-vocab.json (empty)

  2. I then ran train.py, and the two empty files led to the following error:

Output:

building GPT2BPETokenizer tokenizer ...
Traceback (most recent call last):
  File "train.py", line 25, in <module>
    neox_args.build_tokenizer()  # tokenizer needs to be build in training in order to set the padding vocab
  File "/opt/data/private/tmp/gpt-neox/megatron/neox_arguments/arguments.py", line 138, in build_tokenizer
    self.tokenizer = build_tokenizer(self)
  File "/opt/data/private/tmp/gpt-neox/megatron/tokenizer/tokenizer.py", line 41, in build_tokenizer
    tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
  File "/opt/data/private/tmp/gpt-neox/megatron/tokenizer/tokenizer.py", line 158, in __init__
    self.tokenizer = GPT2Tokenizer(
  File "/opt/data/private/tmp/gpt-neox/megatron/tokenizer/gpt2_tokenization.py", line 188, in __init__
    self.encoder = json.load(open(vocab_file))
  File "/opt/conda/lib/python3.8/json/__init__.py", line 293, in load
    return loads(fp.read(),
  File "/opt/conda/lib/python3.8/json/__init__.py", line 357, in loads
    return _default_decoder.decode(s)
  File "/opt/conda/lib/python3.8/json/decoder.py", line 337, in decode
    obj, end = self.raw_decode(s, idx=_w(s, 0).end())
  File "/opt/conda/lib/python3.8/json/decoder.py", line 355, in raw_decode
    raise JSONDecodeError("Expecting value", s, err.value) from None
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)

(The same traceback is printed by each of the four worker processes.)

Killing subprocess 109450
Killing subprocess 109451
Killing subprocess 109452
Killing subprocess 109453

I wonder why the preprocessing step gave me two empty files that led to this error.

Could you tell me how to solve the problem?

If you need more information, I will provide it as soon as possible.
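
For reference, here is roughly the check that would have caught the broken files before training started (a minimal sketch; the file names are the ones train.py is pointed at, assumed to sit in the working directory):

import json
import os

# File names as used in my setup; adjust paths as needed.
for path in ["gpt2-vocab.json", "gpt2-merges.txt"]:
    if not os.path.exists(path) or os.path.getsize(path) == 0:
        raise RuntimeError(f"{path} is missing or empty -- re-download it")

# The vocab file must be valid JSON (a token -> id mapping); an empty or
# truncated download fails here with the same JSONDecodeError train.py printed.
with open("gpt2-vocab.json") as f:
    vocab = json.load(f)
print(f"gpt2-vocab.json OK: {len(vocab)} entries")

# The merges file is plain text with one BPE merge rule per line.
with open("gpt2-merges.txt") as f:
    merges = f.read().splitlines()
print(f"gpt2-merges.txt OK: {len(merges)} lines")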

LostSpirit1307 added the bug label on Apr 8, 2023
@StellaAthena (Member)

It seems like you're using the tokenizer incorrectly. You're specifying a vocab file that corresponds to the GPT-NeoX tokenizer but a merges file that corresponds to the GPT-2 tokenizer. Which one are you trying to use?
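
Concretely, the two tokenizer types expect different inputs. A rough sketch of what each one loads (based on the file paths in this thread and the traceback above, not the repo's exact code; it assumes the Hugging Face `tokenizers` package is installed):

import json

from tokenizers import Tokenizer  # Hugging Face "tokenizers" package

# HFTokenizer route: a single JSON file bundles vocab, merges, and special tokens.
hf_tok = Tokenizer.from_file("./20B_checkpoints/20B_tokenizer.json")
print("HFTokenizer vocab size:", hf_tok.get_vocab_size())

# GPT2BPETokenizer route: two separate files, loaded as in gpt2_tokenization.py above.
with open("gpt2-vocab.json") as f:
    encoder = json.load(f)           # token -> id mapping
with open("gpt2-merges.txt") as f:
    merges = f.read().splitlines()   # BPE merge rules (first line is a version header)
print("GPT-2 BPE vocab size:", len(encoder))

Passing the 20B tokenizer JSON as --vocab together with gpt2-merges.txt, as in the preprocessing command above, mixes these two routes.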

@LostSpirit1307 (Author)

I used HFTokenizer to produce the .bin and .idx files. But I remember that the vocab file and merge file were downloaded from
GPT2_VOCAB_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json"
GPT2_MERGE_URL = "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt"
respectively.

I have some new progress on this issue.
First, it was a network problem that led to the empty files. Once I downloaded them correctly, the error mentioned above was resolved. But as you mentioned, is there a vocab and merge file for HFTokenizer?
Furthermore, after fixing the empty files, I got a new error, shown below. It seems that some parameters were not set correctly:

time (ms) | model and optimizer: 1164.72 | train/valid/test data iterators: 555.56                                                                                              
training ...
Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "/opt/data/private/tmp/gpt-neox/megatron/training.py", line 116, in pretrain
    iteration = train(
  File "/opt/data/private/tmp/gpt-neox/megatron/training.py", line 580, in train
    loss_dict, skipped_iter = train_step(
  File "/opt/data/private/tmp/gpt-neox/megatron/training.py", line 487, in train_step
    reduced_loss = train_step_pipe(
  File "/opt/data/private/tmp/gpt-neox/megatron/training.py", line 536, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 305, in train_batch
    self._exec_schedule(sched)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 1308, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 787, in _exec_load_micro_batch
    batch = self._next_batch()
  File "/opt/conda/lib/python3.8/site-packages/deepspeed/runtime/pipe/engine.py", line 646, in _next_batch
    batch = next(self.data_iterator)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 517, in __next__
    data = self._next_data()
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1199, in _next_data
    return self._process_data(data)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1225, in _process_data
    data.reraise()
  File "/opt/conda/lib/python3.8/site-packages/torch/_utils.py", line 429, in reraise
    raise self.exc_type(msg)
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/opt/data/private/tmp/gpt-neox/megatron/data/gpt2_dataset.py", line 92, in __getitem__
    self.indexed_dataset.get(self.doc_idx[doc_index_f], offset=offset_f)
  File "/opt/data/private/tmp/gpt-neox/megatron/data/indexed_dataset.py", line 532, in get
    np_array = np.frombuffer(
ValueError: offset must be non-negative and no greater than buffer length (214820)
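
To narrow this down I can try reading the preprocessed data back outside the training loop (a sketch, run from the gpt-neox repo root; I'm assuming the class name and constructor exposed by megatron/data/indexed_dataset.py look like this, so treat it as approximate):

# Reads the mmap-indexed dataset produced by the preprocessing step.
from megatron.data.indexed_dataset import MMapIndexedDataset

ds = MMapIndexedDataset("nlpdata_processed_text_document")
print("number of documents:", len(ds))

# Peek at the first and last documents; a mismatch between the .idx and .bin
# files (or a truncated .bin) tends to show up here rather than deep inside
# a DataLoader worker.
print("first doc, first 20 token ids:", ds[0][:20])
print("last doc length:", len(ds[len(ds) - 1]))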

@LostSpirit1307 (Author)

Here is the content of the configuration file:

# GPT-3 pretraining setup
{
   # parallelism settings
   "pipe-parallel-size": 2,
   "model-parallel-size": 2,

   "data-path": "nlpdata_processed_text_document",
   "vocab-file": "./nlpdata/gpt2-vocab.json",
   "merge-file": "./nlpdata/gpt2-merges.txt",

  # model settings
   "num-layers": 12,
   "hidden-size": 768,
   "num-attention-heads": 12,
   "seq-length": 2048,
   "max-position-embeddings": 2048,
   "norm": "rmsnorm",
   "pos-emb": "none",
   "no-weight-tying": true,
    # this should provide some speedup but takes a while to build, set to true if desired
   "scaled-upper-triang-masked-softmax-fusion": false,
   "train-iters": 320000,

   # optimizer settings
   "optimizer": {
     "type": "Adam",
     "params": {
       "lr": 0.0006,
       #"max_grad_norm": 1.0,
       "betas": [0.9, 0.95]
     }
   },
   # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
   "zero_optimization": {
    "stage": 0,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

   # batch / data settings
   "train_micro_batch_size_per_gpu": 4,
   "gradient_accumulation_steps": 1,
   "data-impl": "mmap",
   "split": "949,50,1",

   # activation checkpointing
   "checkpoint-activations": true,
   "checkpoint-num-layers": 1,
   "partition-activations": true,
   "synchronize-each-layer": true,

   # regularization
   "gradient_clipping": 1.0,
   "weight-decay": 0,
   "hidden-dropout": 0,
   "attention-dropout": 0,

   # precision settings
   "fp16": {
     "enabled": true,
     "loss_scale": 0,
     "loss_scale_window": 1000,
     "hysteresis": 2,
     "min_loss_scale": 1
   },

   # lr decay settings
   "lr-decay-iters": 320000,
   "lr-decay-style": "cosine",
   "warmup": 0.01,

   # misc. training settings
   "distributed-backend": "nccl",
   "checkpoint-factor": 10000,
   "eval-interval": 1000,
   "eval-iters": 10,

   # logging
   "log-interval": 100,
   "steps_per_print": 10,
   "keep-last-n-checkpoints": 4,
   "wall_clock_breakdown": true
}

@Quentin-Anthony (Member)

As Stella mentioned, you should choose either HFTokenizer or the gpt2 tokenizer. Your tokenized data and config should all use the same tokenizer/vocab.

If you want to use gpt2, use those vocab/merges files to tokenize with prepare_data.py, then pass them to the neox config.

If you want to use HFTokenizer like we did for NeoX-20B, use it to tokenize your data with prepare_data.py to produce the .bin and .idx files, then grab the vocab from https://the-eye.eu/public/AI/models/GPT-NeoX-20B/slim_weights/20B_tokenizer.json and pass it in your config like we did for Pythia (https://github.com/EleutherAI/pythia/blob/7c6ffc41b3374e4b2b1eadbf2ad62e5b2f297c80/models/12B/pythia-12b.yml#L89-L90).
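
To make the "same tokenizer everywhere" point concrete, the two tokenizers do not share a vocabulary, so data tokenized with one cannot be interpreted with the other's vocab/merges. A quick check sketch (assuming the Hugging Face `tokenizers` package is installed and the files are the ones referenced in this thread):

import json

from tokenizers import Tokenizer  # Hugging Face "tokenizers" package

# Vocabulary used when the data was tokenized with HFTokenizer.
neox_tok = Tokenizer.from_file("./20B_checkpoints/20B_tokenizer.json")

# Vocabulary the GPT2BPETokenizer in the training config would load.
with open("./nlpdata/gpt2-vocab.json") as f:
    gpt2_vocab = json.load(f)

# If these sizes differ, token ids in the .bin/.idx files and the
# training-time tokenizer are out of sync.
print("HFTokenizer (20B_tokenizer.json) vocab size:", neox_tok.get_vocab_size())
print("GPT2BPETokenizer (gpt2-vocab.json) vocab size:", len(gpt2_vocab))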
