merge from magma #20

Merged
merged 10 commits on Jun 13, 2023
3 changes: 2 additions & 1 deletion configs/summit-70m-openclipH.yml
@@ -95,5 +95,6 @@
"steps_per_print": 10,
"wall_clock_breakdown": true,

# "tokenizer-type": "HFTokenizer"
"tokenizer-type": "HFTokenizer"

}
6 changes: 2 additions & 4 deletions configs/summit_setup.yml
@@ -4,11 +4,9 @@
"valid-data-paths": "/gpfs/alpine/csc499/proj-shared/LAION-400m-webdataset/data/{40001..41000}.tar",
"test-data-paths": "/gpfs/alpine/csc499/proj-shared/LAION-400m-webdataset/data/{41000..41455}.tar",

# we use the tokenizer from huggingface, so we don't need a vocab or merge file
#"vocab-file": "/home/lfsm/code/gpt-neox/data/gpt2-vocab.json",
#"merge-file": "/home/lfsm/code/gpt-neox/data/gpt2-merges.txt",

"tokenizer_type": "HFGPT2Tokenizer",
"vocab-file": "./data/20B_tokenizer.json",


"save": "/gpfs/alpine/scratch/lfsm/csc499/checkpoints",
"load": "/gpfs/alpine/scratch/lfsm/csc499/checkpoints",
2 changes: 2 additions & 0 deletions megatron/checkpointing.py
@@ -240,7 +240,9 @@ def load_checkpoint(
neox_args.load,
load_optimizer_states=load_optim_and_scheduler,
load_lr_scheduler_states=load_optim_and_scheduler,
load_module_only=not load_optim_and_scheduler,
tag=tag,
load_module_strict=False,
)

if checkpoint_name is None:
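For context, a minimal sketch of how DeepSpeed's `load_checkpoint` accepts these flags; the engine, directory, and tag below are hypothetical placeholders, not values from this PR. `load_module_strict=False` simply tells DeepSpeed to tolerate missing or unexpected keys when restoring module weights.

```python
# Hedged sketch: relaxed, module-only checkpoint loading with DeepSpeed.
# `engine` is an already-initialized deepspeed engine; the directory and
# tag are placeholders.
load_path, client_state = engine.load_checkpoint(
    "/path/to/checkpoints",          # hypothetical checkpoint directory
    tag="global_step1000",           # hypothetical checkpoint tag
    load_optimizer_states=False,     # skip optimizer state
    load_lr_scheduler_states=False,  # skip LR-scheduler state
    load_module_only=True,           # restore only module weights
    load_module_strict=False,        # do not fail on missing/unexpected keys
)
if load_path is None:
    print("no checkpoint found for the given directory/tag")
```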
7 changes: 3 additions & 4 deletions megatron/data/webdataset.py
@@ -330,17 +330,16 @@ def get_wds_data(args, is_train, epoch=0, floor=False):
preprocess_img = get_clip_transforms(image_size=args.image_size)

assert (
args.tokenizer.name in ['HFGPT2Tokenizer','HFGPT2TokenizerFast']
), f"Webdataset only support HFGPT2Tokenizer or HFGPT2TokenizerFast"
args.tokenizer.name in ['HFGPT2Tokenizer','HFGPT2TokenizerFast','HFTokenizer']
), f"Webdataset only support HFTokenizer, HFGPT2Tokenizer or HFGPT2TokenizerFast"

tokenize = args.tokenizer.tokenize
seq_length = args.seq_length

pipeline.extend([
wds.select(filter_no_caption_or_no_image),
wds.decode("pilrgb", handler=log_and_continue),
wds.rename(image="jpg;png;jpeg;webp", text="txt"),
wds.map_dict(image=preprocess_img, text=lambda text: tokenize(text,seq_length)[0]),
wds.map_dict(image=preprocess_img, text=lambda text: tokenize(text)[0]),
wds.to_tuple("image", "text"),
wds.batched(args.batch_size, collation_fn=image_text_dict_collation_fn, partial=not is_train)
])
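A hedged illustration of what the updated mapping assumes (the dummy tokenizer below is purely illustrative, not the project's): the tokenizer now pads and truncates to `seq_length` internally and returns a 2-D int64 tensor, so `tokenize(text)[0]` yields one fixed-length row per caption and no length argument is needed in the pipeline.

```python
import torch

# Illustrative stand-in for a tokenizer that pads/truncates internally,
# mirroring the behaviour the updated tokenizers in this PR provide.
def tokenize(texts, seq_length=2048, eod_id=0, pad_id=1):
    if isinstance(texts, str):
        texts = [texts]
    rows = []
    for text in texts:
        ids = [len(tok) for tok in text.split()][: seq_length - 1]  # fake token ids
        ids = ids + [eod_id] + [pad_id] * (seq_length - len(ids) - 1)
        rows.append(ids)
    return torch.tensor(rows, dtype=torch.int64)

caption_ids = tokenize("a photo of a cat")[0]
assert caption_ids.shape == (2048,)  # fixed-length row, ready for wds.map_dict
```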
42 changes: 32 additions & 10 deletions megatron/tokenizer/tokenizer.py
@@ -19,6 +19,7 @@

from abc import ABC
from abc import abstractmethod
import torch

from tokenizers import Tokenizer
from transformers import GPT2Tokenizer, GPT2TokenizerFast
@@ -43,9 +43,9 @@ def build_tokenizer(args):
tokenizer = SentencePieceTokenizer(args.vocab_file)
elif args.tokenizer_type.lower() == "HFTokenizer".lower():
assert args.vocab_file is not None
tokenizer = HFTokenizer(args.vocab_file)
tokenizer = HFTokenizer(args.vocab_file,seq_length=args.seq_length)
elif args.tokenizer_type.lower() == "HFGPT2Tokenizer".lower():
tokenizer = HFGPT2Tokenizer()
tokenizer = HFGPT2Tokenizer(seq_length=args.seq_length)
elif args.tokenizer_type.lower() == "CharLevelTokenizer".lower():
tokenizer = CharLevelTokenizer(vocab_size=512)
elif args.tokenizer_type.lower() == "TiktokenTokenizer".lower():
@@ -220,13 +220,15 @@ def eod(self):
class HFTokenizer(AbstractTokenizer):
"""Designed to Integrate HF's Tokenizer library."""

def __init__(self, vocab_file):
def __init__(self, vocab_file,seq_length):
name = "HFTokenizer"
super().__init__(name)

self.tokenizer = Tokenizer.from_file(vocab_file)
self.eod_id = self.tokenizer.token_to_id("<|endoftext|>")
self.pad_id = self.tokenizer.token_to_id("<|padding|>")
self.seq_length = seq_length
self.tokenizer.enable_truncation(max_length=seq_length)

@property
def vocab_size(self):
@@ -240,8 +243,20 @@ def vocab(self):
def inv_vocab(self):
return self.tokenizer.decoder

def tokenize(self, text: str):
return self.tokenizer.encode(text).ids
def tokenize(self, texts: Union[str, List[str]], context_length=2048):
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = [encoding.ids for encoding in self.tokenizer.encode_batch(texts)]
# add eod_id and pad with pad_id
for idx,ids in enumerate(input_ids):
if len(ids) < self.seq_length:
ids = ids+[self.eod_id]+[self.pad_id]*(self.seq_length-len(ids)-1)
else: # truncated
ids = ids[:-1]+[self.eod_id]
input_ids[idx]=ids
input_ids = torch.tensor(input_ids,dtype=torch.int64)
return input_ids

def tokenize_batch(self, text_batch: Union[List[str], str]):
return self.tokenizer.encode_batch(text_batch)
@@ -271,7 +286,7 @@ def whitespace_clean(text):
class HFGPT2Tokenizer(AbstractTokenizer):
"""Designed to Integrate the pretrained OpenAI GPT2 Tokenizers from HF"""

def __init__(self, vocab_file=None, fast=True):
def __init__(self, seq_length, fast=True):
name = "HFGPT2Tokenizer"
if fast:
name += "Fast"
@@ -285,6 +300,7 @@ def __init__(self, vocab_file=None, fast=True):
self.tokenizer.add_special_tokens({"pad_token": "<|padding|>"})
self.eod_id = self.tokenizer.eos_token_id
self.pad_id = self.tokenizer.pad_token_id
self.seq_length = seq_length

@property
def vocab_size(self):
@@ -298,17 +314,23 @@ def vocab(self):
def inv_vocab(self):
return self.tokenizer._tokenizer.decoder

def tokenize(self, texts: Union[str, List[str]], context_length=2048):
def tokenize(self, texts: Union[str, List[str]]):
if isinstance(texts, str):
texts = [texts]
texts = [whitespace_clean(basic_clean(text)) for text in texts]
input_ids = self.tokenizer(
texts,
return_tensors='pt',
max_length=context_length,
padding='max_length',
max_length=self.seq_length,
truncation=True,
).input_ids
# add eod_id and pad with pad_id
for idx,ids in enumerate(input_ids):
if len(ids) < self.seq_length:
ids = ids+[self.eod_id]+[self.pad_id]*(self.seq_length-len(ids)-1)
else: # truncated
ids = ids[:-1]+[self.eod_id]
input_ids[idx]=ids
input_ids = torch.tensor(input_ids,dtype=torch.int64)
return input_ids

# def tokenize_batch(self, text_batch: Union[List[str], str]):
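As a usage sketch (the vocab path and captions are placeholders, and the import assumes the module path shown in the diff), the reworked `HFTokenizer` now takes `seq_length` at construction, enables truncation, and returns a fixed-shape int64 tensor whose rows end with `<|endoftext|>` and are filled with `<|padding|>`:

```python
from megatron.tokenizer.tokenizer import HFTokenizer

# Placeholder vocab file; any HF `tokenizers` JSON that defines the
# <|endoftext|> and <|padding|> special tokens would work here.
tok = HFTokenizer("./data/20B_tokenizer.json", seq_length=2048)

ids = tok.tokenize(["a photo of a cat", "a painting of a dog"])
print(ids.shape)   # torch.Size([2, 2048])
print(ids.dtype)   # torch.int64
# Short captions are terminated with eod_id and padded with pad_id;
# over-length captions are truncated and still end with eod_id.
```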
2 changes: 2 additions & 0 deletions megatron/training.py
@@ -284,6 +284,7 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype):
# Get the masks and position ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=tokens,
pad_token=neox_args.tokenizer.pad_id,
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
)
@@ -360,6 +361,7 @@ def get_batch_pipe_image_text(input,neox_args):

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
data=captions,
pad_token=neox_args.tokenizer.pad_id,
eod_token=neox_args.tokenizer.eod,
eod_mask_loss=neox_args.eod_mask_loss,
)
2 changes: 2 additions & 0 deletions megatron/utils.py
@@ -78,6 +78,7 @@ def get_attn_mask(seq_length, device):

def get_ltor_masks_and_position_ids(
data,
pad_token,
eod_token,
eod_mask_loss=False,
):
@@ -94,6 +95,7 @@

# Loss mask.
loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
loss_mask[data == pad_token] = 0.0
if eod_mask_loss:
loss_mask[data == eod_token] = 0.0

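A small hedged example (tiny hand-made tensors, not real training data) of the effect of the new `pad_token` argument: positions holding the pad id are now zeroed out of the loss mask, in addition to the optional eod masking.

```python
import torch

# Toy batch of token ids, with 0 acting as the pad id in this example.
tokens = torch.tensor([[5, 7, 9, 0, 0],
                       [3, 4, 6, 8, 2]])
pad_token = 0

loss_mask = torch.ones(tokens.size(), dtype=torch.float)
loss_mask[tokens == pad_token] = 0.0  # padded positions contribute no loss

print(loss_mask)
# tensor([[1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 1.]])
```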