add TiktokenTokenizer
haileyschoelkopf committed Jan 31, 2023
1 parent d36f623 · commit b4784bd
Showing 5 changed files with 52 additions and 1 deletion.
3 changes: 2 additions & 1 deletion megatron/neox_arguments/neox_args.py
@@ -658,9 +658,10 @@ class NeoXArgsTokenizer(NeoXArgsTemplate):
"HFGPT2Tokenizer",
"SPMTokenizer",
"CharLevelTokenizer",
"TiktokenTokenizer",
] = "GPT2BPETokenizer"
"""
Type of tokenizer to use - should be one of ["GPT2BPETokenizer", "HFTokenizer", "HFGPT2Tokenizer", "SPMTokenizer", "CharLevelTokenizer"]
Type of tokenizer to use - should be one of ["GPT2BPETokenizer", "HFTokenizer", "HFGPT2Tokenizer", "SPMTokenizer", "CharLevelTokenizer", "TiktokenTokenizer"]
"""

padded_vocab_size: int = None
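
Note: assembled from this hunk and the docstring, the full field would read roughly as follows. This is a reconstruction, not part of the commit; the field name tokenizer_type is inferred from build_tokenizer's use of args.tokenizer_type below.

tokenizer_type: Literal[
    "GPT2BPETokenizer",
    "HFTokenizer",
    "HFGPT2Tokenizer",
    "SPMTokenizer",
    "CharLevelTokenizer",
    "TiktokenTokenizer",
] = "GPT2BPETokenizer"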
47 changes: 47 additions & 0 deletions megatron/tokenizer/tokenizer.py
@@ -24,6 +24,7 @@
 from transformers import GPT2Tokenizer, GPT2TokenizerFast
 import numpy as np
 import sentencepiece as spm
+import tiktoken
 from typing import List, Union
 from .gpt2_tokenization import GPT2Tokenizer

@@ -52,6 +53,9 @@ def build_tokenizer(args):
         tokenizer = HFGPT2Tokenizer(args.vocab_file)
     elif args.tokenizer_type.lower() == "CharLevelTokenizer".lower():
         tokenizer = CharLevelTokenizer(vocab_size=512)
+    elif args.tokenizer_type.lower() == "TiktokenTokenizer".lower():
+        assert args.vocab_file is not None
+        tokenizer = TiktokenTokenizer(args.vocab_file)
     else:
         raise NotImplementedError(
             "{} tokenizer is not " "implemented.".format(args.tokenizer_type)
@@ -345,3 +349,46 @@ def detokenize(self, token_ids):
     @property
     def eod(self):
         return self.eod_id
+
+
+class TiktokenTokenizer(AbstractTokenizer):
+    """Tokenizer from OpenAI's tiktoken implementation"""
+
+    def __init__(self, vocab_file):
+        name = "TiktokenTokenizer"
+        super().__init__(name)
+
+        self.tokenizer = tiktoken.get_encoding(vocab_file)
+        self.eod_id = self.tokenizer.eot_token
+        self.pad_id = None
+
+    @property
+    def vocab_size(self):
+        return self.tokenizer.n_vocab
+
+    @property
+    def vocab(self):
+        raise NotImplementedError
+
+    @property
+    def inv_vocab(self):
+        raise NotImplementedError
+
+    def tokenize(self, text: str):
+        return self.tokenizer.encode(text)  # unlike tokenize_batch, allowed_special="all" is not passed here
+
+    def tokenize_batch(self, text_batch: List[str]):
+        return self.tokenizer.encode_batch(text_batch, allowed_special="all")
+
+    def detokenize(self, token_ids):
+        return self.tokenizer.decode(tokens=token_ids, errors="strict")
+
+    @property
+    def eod(self):
+        return self.eod_id
+
+    @property
+    def pad(self):
+        raise NotImplementedError
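
Note: a brief usage sketch of the class as added (hypothetical, not part of the commit), again assuming tiktoken's built-in "gpt2" encoding; the constructor argument is forwarded straight to tiktoken.get_encoding:

tokenizer = TiktokenTokenizer("gpt2")

ids = tokenizer.tokenize("Hello world")        # Encoding.encode under the hood
print(tokenizer.detokenize(ids))               # -> "Hello world"
batch = tokenizer.tokenize_batch(["a", "b"])   # encode_batch(..., allowed_special="all")
print(tokenizer.vocab_size, tokenizer.eod)     # n_vocab and the end-of-text token id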


1 change: 1 addition & 0 deletions prepare_data.py
@@ -20,6 +20,7 @@
"HFTokenizer",
"GPT2BPETokenizer",
"CharLevelTokenizer",
"TiktokenTokenizer",
]
DATASET_CHOICES = [i for i in DATA_DOWNLOADERS.keys() if i != "pass"]

1 change: 1 addition & 0 deletions requirements/requirements.txt
@@ -10,6 +10,7 @@ pybind11==2.6.2
 regex
 sentencepiece
 six
+tiktoken==0.1.2
 tokenizers==0.12.1
 transformers~=4.24.0
 wandb==0.10.28
1 change: 1 addition & 0 deletions tools/preprocess_data.py
@@ -93,6 +93,7 @@ def get_args():
"HFTokenizer",
"GPT2BPETokenizer",
"CharLevelTokenizer",
"TiktokenTokenizer",
],
help="What type of tokenizer to use.",
)
