diff --git a/configs/1-3B.yml b/configs/1-3B.yml index 3e80ae7fc..f5523c6ba 100644 --- a/configs/1-3B.yml +++ b/configs/1-3B.yml @@ -88,4 +88,7 @@ "steps_per_print": 10, "keep_last_n_checkpoints": 4, "wall_clock_breakdown": true, + + ## tokenizer type + "tokenizer_type": "SPMTokenizer", } diff --git a/configs/125M.yml b/configs/125M.yml index 15a4b3b01..504879123 100644 --- a/configs/125M.yml +++ b/configs/125M.yml @@ -90,5 +90,8 @@ "wall_clock_breakdown": true, # networking - "hostfile": "/mock_path" + "hostfile": "/mock_path", + + ## tokenizer type + "tokenizer_type": "SPMTokenizer" } diff --git a/configs/19M.yml b/configs/19M.yml index 83e5c594a..94648b3a8 100644 --- a/configs/19M.yml +++ b/configs/19M.yml @@ -76,20 +76,24 @@ "warmup": 0.01, "checkpoint_factor": 1000, "eval_interval": 100000, - "eval_iters": 10, + "eval_iters": 1000, + "keep_last_n_checkpoints": 4, + "save_iters": 1000, - "log_interval": 10, - "steps_per_print": 10, + "log_interval": 1000, + "steps_per_print": 1000, "wall_clock_breakdown": true, # additional deepspeed args not specified above "deepspeed_extra_args": { "comms_logger": { - "enabled": true, - "verbose": true, - "prof_all": true, + "enabled": false, + "verbose": false, + "prof_all": false, "debug": false }, - } - + }, + + ## tokenizer type + "tokenizer_type": "SPMTokenizer" } diff --git a/configs/20B.yml b/configs/20B.yml index 243f794d0..46b44c04b 100644 --- a/configs/20B.yml +++ b/configs/20B.yml @@ -104,7 +104,8 @@ "wall_clock_breakdown": false, ### NEW DATA: #### - "tokenizer_type": "HFTokenizer", + # "tokenizer_type": "HFTokenizer", + "tokenizer_type": "SPMTokenizer" "tensorboard-dir": "./tensorboard", "log_dir": "./logs", diff --git a/configs/49M.yml b/configs/49M.yml index 9852320b0..c688a54a3 100644 --- a/configs/49M.yml +++ b/configs/49M.yml @@ -9,15 +9,42 @@ "num_attention_heads": 10, "seq_length": 2048, "max_position_embeddings": 2048, - "pos_emb": "rotary", + # "pos_emb": "rotary", + "pos_emb": "xpos", "rotary_pct": 0.25, "no_weight_tying": true, "gpt_j_residual": true, "output_layer_parallelism": "column", + # "activation": "gelu", + "activation": "swiglu", + "norm": "rmsnorm", + # "use_bnb_optimizer": true, # these should provide some speedup but takes a while to build, set to true if desired - "scaled_upper_triang_masked_softmax_fusion": false, - "bias_gelu_fusion": false, + #"scaled_upper_triang_masked_softmax_fusion": false, + #"bias_gelu_fusion": false, + "scaled_upper_triang_masked_softmax_fusion": true, + "bias-gelu-fusion": true, + # "attention-config": [ + # [ + # [ + # "flash" + # ], + # 10 + # ] + # ], + "curriculum_learning": { + "enabled": true, + "curriculum_type": "seqlen", + "min_difficulty": 64, + "max_difficulty": 2048, + "schedule_type": "fixed_linear", + "schedule_config": { + "total_curriculum_step": 20000, + "difficulty_step": 8 + } + }, + # init methods "init_method": "small_init", @@ -42,11 +69,13 @@ "overlap_comm": True, "reduce_scatter": True, "reduce_bucket_size": 500000000, - "contiguous_gradients": True, + "contiguous_gradients": True, }, + # "zero_allow_untested_optimizer": true, # batch / data settings - "train_micro_batch_size_per_gpu": 32, + "train_micro_batch_size_per_gpu": 8, + # "train_micro_batch_size_per_gpu": 32, "gas": 1, "data_impl": "mmap", "num_workers": 1, @@ -80,12 +109,27 @@ "distributed_backend": "nccl", "lr_decay_style": "cosine", "warmup": 0.01, - "checkpoint_factor": 1000, + "checkpoint_factor": 5000, "eval_interval": 100000, - "eval_iters": 10, + "eval_iters": 1000, + "keep_last_n_checkpoints": 
4, + "save_iters": 5000, # logging - "log_interval": 10, - "steps_per_print": 10, + "log_interval": 1000, + "steps_per_print": 1000, + "keep_last_n_checkpoints": 4, "wall_clock_breakdown": true, + + ## tokenizer type + "tokenizer_type": "SPMTokenizer", + + "deepspeed_extra_args": { + "comms_logger": { + "enabled": false, + "verbose": false, + "prof_all": false, + "debug": false + } + } } diff --git a/configs/convert_19M_settings.yml b/configs/convert_19M_settings.yml new file mode 100644 index 000000000..baf797385 --- /dev/null +++ b/configs/convert_19M_settings.yml @@ -0,0 +1,31 @@ +{ + "tokenizer_type": "SPMTokenizer", + "vocab-file": "./novelAI/tokenizer.model", + + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 6, + "hidden_size": 512, + "num_attention_heads": 8, + "seq_length": 2048, + "max_position_embeddings": 2048, + "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.0001 +} diff --git a/configs/convert_49M_settings.yml b/configs/convert_49M_settings.yml new file mode 100644 index 000000000..9dd0301bd --- /dev/null +++ b/configs/convert_49M_settings.yml @@ -0,0 +1,37 @@ +{ + "tokenizer_type": "SPMTokenizer", + "vocab-file": "./novelAI/tokenizer.model", + + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 10, + "hidden_size": 640, + "num_attention_heads": 10, + "seq_length": 2048, + "max_position_embeddings": 2048, + + "activation": "swiglu", + "norm": "rmsnorm", + "pos_emb": "xpos", + + ## ------------------- + # "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.0001 +} diff --git a/configs/convert_settings.yml b/configs/convert_settings.yml new file mode 100644 index 000000000..baf797385 --- /dev/null +++ b/configs/convert_settings.yml @@ -0,0 +1,31 @@ +{ + "tokenizer_type": "SPMTokenizer", + "vocab-file": "./novelAI/tokenizer.model", + + "pipe_parallel_size": 1, + "model_parallel_size": 1, + + # model settings + "num_layers": 6, + "hidden_size": 512, + "num_attention_heads": 8, + "seq_length": 2048, + "max_position_embeddings": 2048, + "pos_emb": "rotary", + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + + "scaled_upper_triang_masked_softmax_fusion": false, + "bias_gelu_fusion": false, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.001, + "betas": [0.9, 0.95], + "eps": 1.0e-8, + } + }, + "min_lr": 0.0001 +} diff --git a/configs/local_setup_ja.yml b/configs/local_setup_ja.yml new file mode 100644 index 000000000..69ab5effb --- /dev/null +++ b/configs/local_setup_ja.yml @@ -0,0 +1,36 @@ +# Suggested data paths when using GPT-NeoX locally +{ + # "data_path": "data/wiki_ja", + # "data_path": "data/wiki_ja/wiki_ja_text_document", + + # or for weighted datasets: + "train-data-paths": ["data/wiki_ja_novelAI_bin/wiki_ja_text_document", "data/oscar_ja_novelAI_bin/oscar_ja_text_document","data/wiki_en_novelAI_bin/wiki_en_text_document" ,"data/aozora_ja_novelAI_bin/aozora_ja_text_document"], + 
"test-data-paths": ["data/wiki_ja_novelAI_bin/wiki_ja_text_document", "data/oscar_ja_novelAI_bin/oscar_ja_text_document","data/wiki_en_novelAI_bin/wiki_en_text_document" ,"data/aozora_ja_novelAI_bin/aozora_ja_text_document"], + "valid-data-paths": ["data/wiki_ja_novelAI_bin/wiki_ja_text_document", "data/oscar_ja_novelAI_bin/oscar_ja_text_document","data/wiki_en_novelAI_bin/wiki_en_text_document" ,"data/aozora_ja_novelAI_bin/aozora_ja_text_document"], + "train-data-weights": [0.9, 0.9, 0.9, 0.9], + "test-data-weights": [0.1, 0.1, 0.1, 0.1], + "valid-data-weights": [0.1, 0.1, 0.1, 0.1], + + # If weight_by_num_documents is True, Builds dataset weights from a multinomial distribution over groups of data according to the number of documents in each group. + # WARNING: setting this to True will override any user provided weights + # "weight_by_num_documents": false, + # "weighted_sampler_alpha": 0.3, + + "vocab_file": "./novelAI/tokenizer.model", + + #"save": "checkpoints", + # "load": "checkpoints", + "save": "/content/drive/MyDrive/pre_trained/49M/checkpoints", + "load": "/content/drive/MyDrive/pre_trained/49M/checkpoints", + + "checkpoint_validation_with_forward_pass": False, + + ## logging + # "log_dir": "logs", + "log_dir": "/content/drive/MyDrive/pre_trained/49M/logs", + + # "tensorboard_dir": "tensorboard", + "tensorboard_dir": "/content/drive/MyDrive/pre_trained/49M/tensorboard", + # "log_dir": "logs", + "use_wandb": False +} diff --git a/configs/text_generation.yml b/configs/text_generation.yml index 5a49d61e5..637105563 100644 --- a/configs/text_generation.yml +++ b/configs/text_generation.yml @@ -2,7 +2,7 @@ # Make sure `load` is specified somewhere else { # Text gen type: `input-file`, `unconditional` or `interactive` - "text_gen_type": "unconditional", + "text_gen_type": "interactive", # Params for all "maximum_tokens": 102, @@ -13,9 +13,9 @@ "recompute": false, # `unconditional`: samples - "num_samples": 10, + # "num_samples": 10, # input/output file - "sample_input_file": "sample_input.txt", - "sample_output_file": "sample_output.txt", + #"sample_input_file": "sample_input.txt", + #"sample_output_file": "sample_output.txt", } diff --git a/eval_tasks/eval_adapter.py b/eval_tasks/eval_adapter.py index e0a32797d..cede4f93b 100644 --- a/eval_tasks/eval_adapter.py +++ b/eval_tasks/eval_adapter.py @@ -13,18 +13,18 @@ # limitations under the License. 
from megatron.utils import is_local_main, print_rank_0 -import best_download +# import best_download -# patch best_download (eval harness downloader) to only happen on the first local rank -fn = best_download.download_file +# # patch best_download (eval harness downloader) to only happen on the first local rank +# fn = best_download.download_file -def _download_file(*args, **kwargs): - if is_local_main(): - fn(*args, **kwargs) +# def _download_file(*args, **kwargs): +# if is_local_main(): +# fn(*args, **kwargs) -best_download.download_file = _download_file +# best_download.download_file = _download_file import os import sys diff --git a/megatron/__init__.py b/megatron/__init__.py index 4a9f98a31..a2cf4df99 100644 --- a/megatron/__init__.py +++ b/megatron/__init__.py @@ -15,13 +15,12 @@ def print_rank_0(*message): - """If distributed is initialized print only on rank 0.""" + """If distributed is initialized print only on rank 0.""" if torch.distributed.is_initialized(): if torch.distributed.get_rank() == 0: print(*message, flush=True) else: - print(*message, flush=True) - + print(*message, flush=True) from .initialize import initialize_megatron from .neox_arguments import NeoXArgs diff --git a/megatron/model/activations.py b/megatron/model/activations.py index 5c4ba1d5a..ea7935d8a 100644 --- a/megatron/model/activations.py +++ b/megatron/model/activations.py @@ -46,7 +46,9 @@ def get_activation(neox_args): elif neox_args.activation == "mish": activation_func = mish elif neox_args.activation == "silu": - activation_func = F.silu + activation_func = F.silu + elif neox_args.activation == "swiglu": + activation_func = swiglu else: raise ValueError(f"Activation function {neox_args.activation} not recognized") return activation_func @@ -120,6 +122,11 @@ def swish(x, beta: float = 1.0): def mish(x): return x * torch.tanh(F.softplus(x)) +@torch.jit.script +def swiglu(x): + return F.silu(x) * x + # x = torch.chunk(x, 2, dim=-1) + # return F.silu(x[0]) * x[1] class GEGLU(torch.nn.Module): def __init__(self, neox_args): diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py index 68815075a..59f1bb02e 100644 --- a/megatron/model/positional_embeddings.py +++ b/megatron/model/positional_embeddings.py @@ -221,3 +221,110 @@ def forward(self, x): ) # seq_len_k - 1 points to the last token index in the current inference batch. return x + a + + +# Original implementation adjusted from https://github.com/sunyt32/torchscale + +def fixed_pos_embedding(x, base): + seq_len, dim = x.shape + inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + ) + return torch.cos(sinusoid_inp), torch.sin(sinusoid_inp) + + +class XPosEmbedding(torch.nn.Module): + """ + xPos positional embeddings from https://arxiv.org/abs/2212.10554. 
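+ Extends plain rotary embeddings with an exponentially decaying per-position scale, applied to queries and inverted for keys, so attention depends on relative distance and extrapolates better to longer sequences.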
+ """ + + def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half): + super().__init__() + self.scale_base = scale_base + self.register_buffer( + "scale", + ( + (torch.arange(0, head_dim, 2) + gamma * head_dim) + / ((1.0 + gamma) * head_dim) + ), + ) + self.max_seq_len_cached = None + self.precision = precision + self.freq_base = freq_base + + def forward(self, x, seq_dim=1, seq_len=None): + if seq_len is None: + seq_len = x.shape[seq_dim] + scale = ( + self.scale + ** ( + torch.arange(0, seq_len, 1) - seq_len // 2 + ).to(self.scale).div(self.scale_base)[:, None] + ) + + if ( + self.max_seq_len_cached is None + or (seq_len > self.max_seq_len_cached) + ): + self.max_seq_len_cached = seq_len + cos, sin = fixed_pos_embedding(scale, self.freq_base) + self.cos_cached = cos + self.sin_cached = sin + if self.precision == torch.bfloat16: + self.cos_cached = self.cos_cached.bfloat16() + self.sin_cached = self.sin_cached.bfloat16() + return ( + self.cos_cached[:seq_len], + self.sin_cached[:seq_len], + scale, + ) + + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m.unsqueeze(1) + + +def _apply_xpos_emb(x, cos, sin, scale): + # x is assumed to be (seq_len, batch_size, dim) here. + cos = duplicate_interleave(cos * scale) + sin = duplicate_interleave(sin * scale) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +@torch.jit.script +def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0): + # q/k are assumed to be (seq_len, batch_size, dim) here. + cos = cos[offset:q.shape[0] + offset] + sin = sin[offset:q.shape[0] + offset] + scale = scale[offset:q.shape[0] + offset] + return ( + _apply_xpos_emb(q, cos, sin, scale), + _apply_xpos_emb(k, cos, sin, 1.0 / scale), + ) + + +def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0): + # q/k are assumed to be (seq_len, batch_size, dim) here. 
+ cos = cos[offset:q.shape[0] + offset] + sin = sin[offset:q.shape[0] + offset] + scale = scale[offset:q.shape[0] + offset] + return ( + _apply_xpos_emb(q, cos, sin, scale), + _apply_xpos_emb(k, cos, sin, 1.0 / scale), + ) \ No newline at end of file diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4e81b70b6..82817d5c2 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -32,6 +32,9 @@ apply_rotary_pos_emb_torch, apply_rotary_pos_emb, AliBi, + XPosEmbedding, + apply_xpos_emb_torch, + apply_xpos_emb ) from megatron.model.fused_bias_dropout import ( get_bias_dropout_add, @@ -332,6 +335,11 @@ def __init__( else: self.rotary_emb = None + ## xpos + if neox_args.pos_emb == "xpos": + self.xpos_emb = XPosEmbedding(self.hidden_size_per_attention_head, precision=neox_args.params_dtype) + else: + self.xpos_emb = None self.attention_type = neox_args.attention_config[layer_number] self.use_flash_attention = self.attention_type == "flash" self.sparse = self.attention_type not in ("global", "flash") @@ -570,7 +578,6 @@ def flash_attention(self, query_layer, key_layer, value_layer): ) # [b, sq, np, hn] -> [b, np, sq, hn] matmul_result = matmul_result.transpose(1, 2) - else: # [sq, b, np, hn] -> [b, sq, np, hn] sq = query_layer.size(0) @@ -656,6 +663,7 @@ def forward(self, hidden_states, attention_mask, layer_past=None): if exists(layer_past) and layer_past.numel() > 0: offset = layer_past[0].shape[0] seq_len += offset + cos, sin = self.rotary_emb(value_layer, seq_len=seq_len) query_layer, key_layer = apply_rotary_fn( query_rot, key_rot, cos, sin, offset=offset @@ -664,7 +672,53 @@ def forward(self, hidden_states, attention_mask, layer_past=None): if exists(self.rotary_ndims): query_layer = torch.cat((query_layer, query_pass), dim=-1) key_layer = torch.cat((key_layer, key_pass), dim=-1) + + # print('query_layer', query_layer.size()) #torch.Size([64, 8, 10, 64]) + # print('key_layer', key_layer.size()) #torch.Size([64, 8, 10, 64]) + # print('value_layer', value_layer.size()) #torch.Size([64, 8, 10, 64]) + + ## xpos + if exists(self.xpos_emb): + # =================================== + # Raw attention scores. 
[b, np, s, s] + # =================================== + _sq, _b, _np, _hn = query_layer.size() + + # [b, np, sq, sk] + output_size = (query_layer.size(1), + query_layer.size(2), + query_layer.size(0), + key_layer.size(0)) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view(output_size[2], + output_size[0] * output_size[1], -1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view(output_size[3], + output_size[0] * output_size[1], -1) + seq_len = key_layer.shape[0] + offset = 0 + if exists(layer_past): + past_key, past_value = layer_past + key_layer = torch.cat((past_key.type_as(key_layer), + key_layer), dim=0) + value_layer = torch.cat((past_value.type_as(value_layer), + value_layer), dim=0) + + if exists(layer_past) and layer_past.numel() > 0: + offset = layer_past[0].shape[0] + seq_len += offset + apply_xpos_fn = apply_xpos_emb_torch if self.bf16 else apply_xpos_emb + cos, sin, scale = self.xpos_emb(value_layer, seq_len=seq_len) + query_layer, key_layer = apply_xpos_fn( + query_layer, key_layer, cos, sin, scale, offset=offset) + + ## [sq, b * np, hn] -> [sq, b, np, hn] + query_layer = query_layer.view(_sq, _b, _np, _hn) + ## [sq, b * np, hn] -> [sk, b, np, hn] + key_layer = key_layer.view(_sq, _b, _np, _hn) + # ================================== # Cache key and value for inference # ================================== diff --git a/megatron/mpu/layers.py b/megatron/mpu/layers.py index 92edbd6eb..f335df98c 100644 --- a/megatron/mpu/layers.py +++ b/megatron/mpu/layers.py @@ -742,6 +742,7 @@ def forward(self, input_): else: input_parallel = scatter_to_model_parallel_region(input_) # Matrix multiply. + output_parallel = F.linear(input_parallel, self.weight) # All-reduce across all the partitions. if not self.parallel_output: diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index e1ea16a16..98707e267 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -141,7 +141,7 @@ class NeoXArgsModel(NeoXArgsTemplate): """ pos_emb: Literal[ - "learned", "rotary", "sinusoidal", "rpe", "alibi", "none" + "learned", "rotary", "sinusoidal", "rpe", "alibi", "none", "xpos" ] = "learned" """ Type of positional embedding to use - choose from 'learned', 'rotary', 'sinusoidal', 'rpe', 'none' @@ -229,10 +229,10 @@ class NeoXArgsModel(NeoXArgsTemplate): """ activation: Literal[ - "gelu", "geglu", "relu", "softsign", "swish", "mish", "silu" + "gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "swiglu" ] = "gelu" """ - Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"] + Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "swiglu"] """ scaled_upper_triang_masked_softmax_fusion: bool = False diff --git a/megatron/training.py b/megatron/training.py index 96a94a1d0..1f2231acd 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -57,6 +57,10 @@ from megatron.model.gpt2_model import cross_entropy from eval_tasks import run_eval_harness +import logging +from deepspeed.utils import logger as ds_logger +ds_logger.setLevel(logging.WARNING) + def mup_weights_reinit(neox_args, model): def has_method(o, name): @@ -722,7 +726,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) optimizer=optimizer, model=model, loss=loss, - ) + ) timers("backward").stop() # Update parameters. 
timers("optimizer").start() @@ -734,7 +738,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) reduced_loss = { "lm_loss": reduce_losses(losses).mean() } # reduces losses across machines for logging - + if neox_args.precision == "fp16" and model.optimizer.overflow: skipped_iter = 1 else: @@ -747,7 +751,7 @@ def train_step_pipe(neox_args, timers, model, data_iterator): """Single training step with DeepSpeed's pipeline parallel engine.""" assert neox_args.deepspeed - loss = model.train_batch(data_iter=data_iterator) + loss = model.train_batch(data_iter=data_iterator) loss_dict = {"lm_loss": loss} # Don't break Megatron's timers because we changed code paths. for t in [ @@ -790,7 +794,7 @@ def train( # to monitor if we've skipped many iterations in a row and trigger an early exit overflow_monitor = OverflowMonitor(optimizer) - while iteration < neox_args.train_iters: + while iteration < neox_args.train_iters: loss_dict, skipped_iter = train_step( neox_args=neox_args, timers=timers, @@ -804,8 +808,7 @@ def train( if neox_args.precision == "fp16": overflow_monitor.check(skipped_iter) # check for repeated overflow if neox_args.log_gradient_noise_scale: # log noise scale if applicable - noise_scale_logger.update() - + noise_scale_logger.update() # get learning rate (if present) - if doing soft prompt tuning + pipe parallel, you # may have no tunable parameters on a specific rank if optimizer.param_groups: @@ -827,8 +830,7 @@ def train( model=model, optimizer=optimizer, noise_scale_logger=noise_scale_logger, - ) - + ) # Checkpointing if neox_args.save and iteration in neox_args.save_iters: save_checkpoint( diff --git a/megatron/utils.py b/megatron/utils.py index 0071ef872..ab51f0667 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -301,6 +301,7 @@ def log(self, names, normalizer=1.0, reset=True): for name in names: elapsed_time = self.timers[name].elapsed(reset=reset) * 1000.0 / normalizer string += " | {}: {:.2f}".format(name, elapsed_time) + print("-"*10, string) if torch.distributed.is_initialized(): if torch.distributed.get_rank() == 0: print(string, flush=True) diff --git a/prepare_data_ja.sh b/prepare_data_ja.sh new file mode 100644 index 000000000..cb15ffb82 --- /dev/null +++ b/prepare_data_ja.sh @@ -0,0 +1,5 @@ +#!/bin/sh +python prepare_data.py -d ./data \ +-t SPMTokenizer \ +--vocab-file ./novelAI/tokenizer.model \ +wiki_ja_en diff --git a/preprocess_ja.sh b/preprocess_ja.sh new file mode 100644 index 000000000..7de047c6d --- /dev/null +++ b/preprocess_ja.sh @@ -0,0 +1,8 @@ +#!/bin/sh +python tools/preprocess_data.py \ + --input ./data/mydataset.jsonl.zst \ + --output-prefix ./data/wiki_ja_en \ + --vocab-file ./novelAI/tokenizer.model \ + --dataset-impl mmap \ + --tokenizer-type SPMTokenizer \ + --append-eod diff --git a/tools/convert_module_to_hf.py b/tools/convert_module_to_hf.py index 905bdfa16..138bef1ac 100644 --- a/tools/convert_module_to_hf.py +++ b/tools/convert_module_to_hf.py @@ -23,6 +23,7 @@ import torch from transformers import GPTNeoXConfig, GPTNeoXForCausalLM +from hf_gptneox import GPTNeoX2ForCausalLM sys.path.append( os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) @@ -140,12 +141,14 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): should perform model-parallel merging correctly but only supports features allowed by HF GPT-NeoX implementation (e.g. 
rotary embeddings) """ - + print('debug: ', loaded_config) hf_config = GPTNeoXConfig() hf_config = create_config(loaded_config) - hf_model = GPTNeoXForCausalLM(hf_config) + # hf_model = GPTNeoXForCausalLM(hf_config) + ## for swiglu + hf_model = GPTNeoX2ForCausalLM(hf_config) # save model in fp16/bf16 if Deepspeed fp16 or bf16 mixed precision was used in config, else 32 bit weights fp16 = get_key(loaded_config, "fp16") @@ -183,12 +186,18 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): # get layer from hf model hf_layer = hf_model.gpt_neox.layers[layer_i] + for v in hf_layer.state_dict(): + print('debug state_dict: ', v) + print('-'*200) # + 2 bc of embed layer and a dummy _pre_transformer_block loaded_tp_ranks = load_partitions( input_checkpoint_path, mp_partitions, layer_i + 2 ) + for t in loaded_tp_ranks: + print('debug loaded_tp_ranks: ', t.keys()) + state_dict = {} for key in [ "attention.dense.weight", @@ -197,12 +206,14 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=1) # average layernorm stats over mp ranks - for key in [ + keysForOriginGPTNeoX=[ "input_layernorm.weight", "input_layernorm.bias", "post_attention_layernorm.weight", "post_attention_layernorm.bias", - ]: + ] + keysForSwiglu = [] + for key in keysForSwiglu: state_dict[key] = (sum([t[key] for t in loaded_tp_ranks])) / len( loaded_tp_ranks ) @@ -224,13 +235,21 @@ def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): state_dict[key] = sum([t[key] for t in loaded_tp_ranks]) # Just take one - state_dict["attention.rotary_emb.inv_freq"] = loaded_tp_ranks[0][ - "attention.rotary_emb.inv_freq" - ] - state_dict["attention.bias"] = hf_layer.state_dict()["attention.bias"] - state_dict["attention.masked_bias"] = hf_layer.state_dict()[ - "attention.masked_bias" - ] + if loaded_config['pos_emb'] == 'rotary': + state_dict["attention.rotary_emb.inv_freq"] = loaded_tp_ranks[0][ + "attention.rotary_emb.inv_freq" + ] + + + state_dict["attention.dense.bias"] = hf_layer.state_dict()["attention.dense.bias"] + + if "attention.bias" in hf_layer.state_dict(): + state_dict["attention.bias"] = hf_layer.state_dict()["attention.bias"] + + if "attention.masked_bias" in hf_layer.state_dict(): + state_dict["attention.masked_bias"] = hf_layer.state_dict()[ + "attention.masked_bias" + ] # load state_dict into layer hf_layer.load_state_dict(state_dict) diff --git a/tools/convert_module_to_hf_gptneox2.py b/tools/convert_module_to_hf_gptneox2.py new file mode 100644 index 000000000..8b099f04a --- /dev/null +++ b/tools/convert_module_to_hf_gptneox2.py @@ -0,0 +1,359 @@ +# Copyright (c) 2023, EleutherAI +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys + +import yaml +import argparse +from tqdm import tqdm +from typing import List + +import torch +from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, AutoModelForCausalLM + +from hf_gptneox import GPTNeoX2ForCausalLM + +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)) +) +from megatron.tokenizer import build_tokenizer + + +""" +A script for converting saved NeoX Checkpoints to Huggingface (HF) compatible GPT-NeoX type models. + +Note that this script does not support all NeoX features. +Please investigate carefully whether your model is compatible with all architectures supported by the GPTNeoXForCausalLM class in HF. + +(e.g. position embeddings such as AliBi may not be supported by Huggingface's GPT-NeoX architecture. +""" + + +def load_partitions( + input_checkpoint_path, mp_partitions, layer_idx +) -> List[torch.Tensor]: + """Returns a list containing all weights in a given layer from a model (across MP partitions)""" + + loaded_tp_ranks = [ + torch.load( + os.path.join( + input_checkpoint_path, + f"layer_{layer_idx:02}-model_{i:02}-model_states.pt", + ), + map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ) + for i in range(mp_partitions) + ] + + return loaded_tp_ranks + + +def get_key(loaded_config, key, default=None): + """ + Search for a given key in a NeoX yaml. normalizes underscores -> hyphens + """ + key = key.replace("_", "-") + try: + return loaded_config[key] + except KeyError: + key = key.replace("-", "_") + try: + return loaded_config[key] + except KeyError: + return default + + +def create_config(neox_config): + """take in a loaded yaml from NeoX and assign relevant values to HF config. + Returns: GPTNeoXConfig() object + """ + + class TokenizerArgs: + # kinda hacky. + # this is to get something with the same interface as is used in build_tokenizer() + # without diving into loading a neox_args object or using argparse etc. + def __init__(self, neox_config): + self.make_vocab_size_divisible_by = get_key( + neox_config, "make-vocab-size-divisible-by", default=128 + ) + self.model_parallel_size = get_key(neox_config, "model-parallel-size") + self.vocab_file = get_key(neox_config, "vocab-file") + self.merge_file = get_key(neox_config, "merge-file") + self.tokenizer_type = get_key(neox_config, "tokenizer-type") + + self.rank = 0 + + args = TokenizerArgs(neox_config) + tokenizer = build_tokenizer(args) + try: # GPT2TokenizerFast raises NotImplementedError + pad_token = tokenizer.pad + except: + pad_token = ( + 1 # pad defaulting to 1. follows convention from GPT-NeoX-20b tokenizer + ) + + # TODO: change the default value here based on discussion regarding `gpt_j_tied` config parameter's default + use_tied_lns = get_key(neox_config, "gpt-j-tied", False) + + if use_tied_lns: + raise NotImplementedError( + """ERROR: Huggingface Transformers does not yet support a single shared layernorm + per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals. + See https://github.com/EleutherAI/gpt-neox/pull/481 for further details.""" + ) + + # set all config values. 
+ hf_config = GPTNeoXConfig( + vocab_size=args.padded_vocab_size, + hidden_size=get_key(neox_config, "hidden-size"), + num_hidden_layers=get_key(neox_config, "num-layers"), + num_attention_heads=get_key(neox_config, "num-attention-heads"), + intermediate_size=(get_key(neox_config, "hidden-size") * 4), + hidden_act=get_key(neox_config, "activation", default="gelu"), + rotary_pct=get_key(neox_config, "rotary-pct", default=1.0), + rotary_emb_base=get_key(neox_config, "rotary-emb-base", default=10000), + max_position_embeddings=get_key(neox_config, "max-position-embeddings"), + initializer_range=get_key(neox_config, "init-method-std", 0.02), + layer_norm_eps=get_key(neox_config, "layernorm-epsilon", 1e-5), + use_cache=True, + bos_token_id=tokenizer.eod, + eos_token_id=tokenizer.eod, + tie_word_embeddings=(not get_key(neox_config, "no-weight-tying", False)), + use_parallel_residual=get_key(neox_config, "gpt-j-residual", False), + ) + return hf_config + + +def convert(input_checkpoint_path, loaded_config, output_checkpoint_path): + """convert a NeoX checkpoint to a HF model format. + should perform model-parallel merging correctly + but only supports features allowed by HF GPT-NeoX implementation (e.g. rotary embeddings) + """ + print('debug: ', loaded_config) + hf_config = GPTNeoXConfig() + + hf_config = create_config(loaded_config) + + # hf_model = GPTNeoXForCausalLM(hf_config) + ## for swiglu + hf_model = GPTNeoX2ForCausalLM(hf_config) + + # save model in fp16/bf16 if Deepspeed fp16 or bf16 mixed precision was used in config, else 32 bit weights + fp16 = get_key(loaded_config, "fp16") + if fp16: + try: + # this conditional is quite messy because there were a number of ways to specify bf16 or fp16 training + # in DeeperSpeed v1.0 . + if (fp16.get("fp16", None) or fp16["enabled"]) and not (fp16.get("type", None) == "bfloat16"): + hf_model.half() + print("Saving weights in fp16 precision...") + elif fp16.get("type", None) == "bfloat16": + hf_model.to(dtype=torch.bfloat16) + print("Saving weights in bf16 precision...") + except: + print("Model not trained in fp16 / bf16 mixed precision, saving weights in fp32...") + + mp_partitions = get_key(loaded_config, "model-parallel-size") + + ### Embedding layer ### + loaded_tp_ranks = load_partitions(input_checkpoint_path, mp_partitions, 0) + hf_model.gpt_neox.embed_in.load_state_dict( + { + "weight": torch.cat( + [t["word_embeddings.weight"] for t in loaded_tp_ranks], dim=0 + ) + } + ) + + assert ( + hf_config.vocab_size == hf_model.gpt_neox.embed_in.weight.shape[0] + ), f"ERROR: calculated vocab size {hf_config.vocab_size} != embed param size {hf_model.gpt_neox.embed_in.shape[0]}" + ### End Embedding Layer ### + + for layer_i in tqdm(range(get_key(loaded_config, "num-layers"))): + + # get layer from hf model + hf_layer = hf_model.gpt_neox.layers[layer_i] + for v in hf_layer.state_dict(): + print('debug state_dict: ', v) + print('-'*200) + + # + 2 bc of embed layer and a dummy _pre_transformer_block + loaded_tp_ranks = load_partitions( + input_checkpoint_path, mp_partitions, layer_i + 2 + ) + + for t in loaded_tp_ranks: + print('debug loaded_tp_ranks: ', t.keys()) + + state_dict = {} + + + + for key in [ + "attention.dense.weight", + "mlp.dense_4h_to_h.weight", + ]: + state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=1) + + # average layernorm stats over mp ranks + keysForOriginGPTNeoX=[ + "input_layernorm.weight", + "input_layernorm.bias", + "post_attention_layernorm.weight", + "post_attention_layernorm.bias", + ] + keysForSwiglu = [ + 
'input_layernorm.scale', + 'post_attention_layernorm.scale' + ] + for key in keysForSwiglu: + state_dict[key] = (sum([t[key] for t in loaded_tp_ranks])) / len( + loaded_tp_ranks + ) + + # LinearWithTPMerge + for key in [ + "mlp.dense_h_to_4h.weight", + "mlp.dense_h_to_4h.bias", + "attention.query_key_value.weight", + "attention.query_key_value.bias", + ]: + state_dict[key] = torch.cat([t[key] for t in loaded_tp_ranks], dim=0) + + # LinearWithTPSplitBias + for key in [ + "mlp.dense_4h_to_h.bias", + # "attention.dense.bias", + ]: + state_dict[key] = sum([t[key] for t in loaded_tp_ranks]) + + # Just take one + if loaded_config['pos_emb'] == 'rotary': + state_dict["attention.rotary_emb.inv_freq"] = loaded_tp_ranks[0][ + "attention.rotary_emb.inv_freq" + ] + + + # state_dict["attention.dense.bias"] = hf_layer.state_dict()["attention.dense.bias"] + + if "attention.bias" in hf_layer.state_dict(): + state_dict["attention.bias"] = hf_layer.state_dict()["attention.bias"] + + if "attention.masked_bias" in hf_layer.state_dict(): + state_dict["attention.masked_bias"] = hf_layer.state_dict()[ + "attention.masked_bias" + ] + + # load state_dict into layer + hf_layer.load_state_dict(state_dict) + + # Load final layer norm + loaded_tp_ranks = load_partitions( + input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 3 + ) + + hf_model.gpt_neox.final_layer_norm.load_state_dict( + { + "weight": (sum([t["norm.weight"] for t in loaded_tp_ranks])) + / len(loaded_tp_ranks), + "bias": (sum([t["norm.bias"] for t in loaded_tp_ranks])) + / len(loaded_tp_ranks), + } + ) + del loaded_tp_ranks + + # Load output embedding + loaded_tp_ranks = load_partitions( + input_checkpoint_path, mp_partitions, get_key(loaded_config, "num-layers") + 4 + ) + + hf_model.embed_out.load_state_dict( + { + "weight": torch.cat( + [t["final_linear.weight"] for t in loaded_tp_ranks], dim=0 + ), + } + ) + + del loaded_tp_ranks + + return hf_model + + +if __name__ == "__main__": + + # before running script: + # `pip install --upgrade transformers` + # `huggingface-cli login` + # + from huggingface_hub import create_repo, HfApi + + parser = argparse.ArgumentParser( + description="Merge MP partitions and convert to HF Model." + ) + parser.add_argument( + "--input_dir", + type=str, + help="Path to NeoX checkpoint, e.g. 
/path/to/model/global_step143000", + ) + parser.add_argument( + "--config_file", + type=str, + help="Path to config file for the input NeoX checkpoint.", + ) + parser.add_argument( + "--output_dir", + type=str, + help="Output dir, where to save the HF Model, tokenizer, and configs", + ) + parser.add_argument( + "--upload", + action="store_true", + help="Set to true in order to upload to the HF Hub directly.", + ) + args = parser.parse_args() + + with open(args.config_file) as f: + loaded_config = yaml.full_load(f) + + hf_model = convert(args.input_dir, loaded_config, args.output_dir) + + hf_model.save_pretrained(args.output_dir) + + # save tokenizer to directory as well, for easy loading of model as a HF model + tokenizer_type = get_key(loaded_config, "tokenizer-type") + + if tokenizer_type == "HFTokenizer": + print(f"saving tokenizer from file {get_key(loaded_config, 'vocab-file')}") + from transformers import PreTrainedTokenizerFast + + tokenizer = PreTrainedTokenizerFast( + tokenizer_file=get_key(loaded_config, "vocab-file") + ) + print("loaded tokenizer: ", tokenizer) + tokenizer.save_pretrained(args.output_dir) + print("tokenizer saved!") + + if args.upload: + repo_name = input("Provide a repository name for the HF Hub: ") + create_repo(repo_name, repo_type="model", private=False, use_auth_token=True) + + api = HfApi() + api.upload_folder( + folder_path=args.output_dir, + repo_id=repo_name, + repo_type="model", + ) diff --git a/tools/corpora.py b/tools/corpora.py index b9e846454..daeae07ed 100644 --- a/tools/corpora.py +++ b/tools/corpora.py @@ -293,6 +293,122 @@ class Enwik8(DataDownloader): urls = ["https://data.deepai.org/enwik8.zip"] +class WikiEn(DataDownloader): + name = "wiki_en" + urls = [ + "https://dumps.wikimedia.org/other/cirrussearch/20230807/enwiki-20230807-cirrussearch-content.json.gz" + ] + + +class WikiJa(DataDownloader): + name = "wiki_ja" + urls = [ + "https://dumps.wikimedia.org/other/cirrussearch/20230807/jawiki-20230807-cirrussearch-content.json.gz", + ] + +class DataDownloaderWithHF(DataDownloader): + def __init__(self, hf_repo_ids = [], *args, **kwargs): + super().__init__(*args, **kwargs) + self.hf_repo_ids = hf_repo_ids + + def download(self): + super().download() + from huggingface_hub import snapshot_download + save_dir = os.path.join(self.base_dir, self.name) + for repo_id in self.hf_repo_ids: + snapshot_download(repo_id=repo_id, revision="main", allow_patterns="*.jsonl", local_dir=save_dir, repo_type='dataset') + +class WikiOSCARJa(DataDownloaderWithHF): + name = "wiki_oscar_ja" + urls = [ + "https://dumps.wikimedia.org/other/cirrussearch/20230807/jawiki-20230807-cirrussearch-content.json.gz", + ] + hf_repo_ids = [ + 'if001/oscar_2023_filtered' + ] + +class HFSnapshotDownloader(DataDownloader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + @abstractmethod + def hf_repo_ids(self): + pass + + def download(self): + from huggingface_hub import snapshot_download + save_dir = os.path.join(self.base_dir, self.name) + for repo_id in self.hf_repo_ids: + print('save to', save_dir) + allow_patterns = None + if 'if001/oscar_2023_filtered' == repo_id: + allow_patterns="*.jsonl.zst" + if 'if001/aozorabunko-clean-sin' == repo_id: + allow_patterns="*.jsonl.gz" + snapshot_download(repo_id=repo_id, allow_patterns=allow_patterns, local_dir=save_dir, repo_type="dataset") + + +class OSCARJa(HFSnapshotDownloader): + name = "oscar_ja" + urls = [""] + hf_repo_ids = ['if001/oscar_2023_filtered'] + +class AozoraJa(HFSnapshotDownloader): 
+ name = "aozora_ja" + urls = [""] + hf_repo_ids = ['if001/aozorabunko-clean-sin'] + +class HFDataDownloader(DataDownloader): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @property + @abstractmethod + def hf_repo_ids(self): + pass + + def download(self): + from datasets import load_dataset + save_dir = os.path.join(self.base_dir, self.name) + for repo_id in self.hf_repo_ids: + ds = load_dataset(repo_id) + name = repo_id.split('/')[0] + save_path = f'{save_dir}/{name}.jsonl' + print('save to', save_path) + ds['train'].to_json(save_path, force_ascii=False) + + +class IzumiFullDataset(HFDataDownloader): + name = "izumi_full_dataset" + urls = [""] + hf_repo_ids = [ + "izumi-lab/wikipedia-ja-20230720", + "izumi-lab/wikipedia-en-20230720", + "izumi-lab/wikinews-ja-20230728" + ] + +class IzumiWikiJaDataset(HFDataDownloader): + name = "izumi_wiki_ja_dataset" + urls = [""] + hf_repo_ids = [ + "izumi-lab/wikipedia-ja-20230720", + ] + +class IzumiWikiEnDataset(HFDataDownloader): + name = "izumi_wiki_en_dataset" + urls = [""] + hf_repo_ids = [ + "izumi-lab/wikipedia-en-20230720", + ] + +class IzumiWikiNewsJaDataset(HFDataDownloader): + name = "izumi_wiki_news_dataset" + urls = [""] + hf_repo_ids = [ + "izumi-lab/wikinews-ja-20230728" + ] + def maybe_download_gpt2_tokenizer_data(tokenizer_type, data_dir): if tokenizer_type is None or tokenizer_type == "GPT2BPETokenizer": GPT2_VOCAB_FP = f"{data_dir}//gpt2-vocab.json" @@ -324,6 +440,15 @@ def maybe_download_gpt2_tokenizer_data(tokenizer_type, data_dir): "c4": C4, "c4_openwebtext": C4OpenWebText, "enwik8": Enwik8, + 'wiki_en': WikiEn, + 'wiki_ja': WikiJa, + 'oscar_ja': OSCARJa, + 'wiki_oscar_ja': WikiOSCARJa, + 'aozora_ja': AozoraJa, + 'izumi_dataset': IzumiFullDataset, + 'izumi_wiki_ja_dataset': IzumiWikiJaDataset, + 'izumi_wiki_en_dataset': IzumiWikiEnDataset, + 'izumi_wiki_news_dataset': IzumiWikiNewsJaDataset } diff --git a/tools/hf_gptneox.py b/tools/hf_gptneox.py new file mode 100644 index 000000000..5d6767207 --- /dev/null +++ b/tools/hf_gptneox.py @@ -0,0 +1,416 @@ +from transformers.models.gpt_neox import GPTNeoXPreTrainedModel, GPTNeoXModel, GPTNeoXLayer +from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXMLP, GPTNeoXAttention +from transformers.activations import ClassInstantier, ACT2CLS +from torch import Tensor, nn +import torch + +from typing import Callable, Optional, Tuple +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + +class SwiGLU(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x: Tensor) -> Tensor: + return F.silu(x) * x + +# ACT2CLS['swiglu'] = SwiGLUFFN +ACT2CLS['swiglu'] = SwiGLU +ACT2FN = ClassInstantier(ACT2CLS) + +class GPTNeoX2MLP(GPTNeoXMLP): + def __init__(self, config): + _copy_hidden_act = config.hidden_act + config.hidden_act = "gelu" + super().__init__(config) + + config.hidden_act = _copy_hidden_act + 
self.act = ACT2FN[config.hidden_act] + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, cos_k=None, sin_k=None): + """ + q, k: [bs, num_heads, seq_len, rot_dim] + cos, sin: [seq_len, rot_dim / 2] + position_ids: [bs, seq_len] + """ + # print(f"q: {q.shape}, k: {k.shape}, cos: {cos.shape}, sin: {sin.shape}, position_ids: {position_ids.shape}") + import einops + cos = einops.repeat(cos, 's r -> s (2 r)') + sin = einops.repeat(sin, 's r -> s (2 r)') + cos_k = einops.repeat(cos_k, 's r -> s (2 r)') + sin_k = einops.repeat(sin_k, 's r -> s (2 r)') + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim] + cos_k = cos_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim] + sin_k = sin_k[position_ids].unsqueeze(1) # [bs, 1, seq_len, rot_dim] + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos_k) + (rotate_half(k) * sin_k) + return q_embed, k_embed + +class RotaryEmbedding(torch.nn.Module): + """Based on Tri Dao's XPos: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/layers/rotary.py""" + def __init__( + self, + dim: int, + max_position_embeddings: int, + base: int = 10_000, + scale_base: int = 512, + device: str = None + ): + super().__init__() + self.dim = dim + self.seq_len_cached = max_position_embeddings + + # Set up `inv_freq` term + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)) + self.register_buffer("inv_freq", inv_freq) + + # Set up `scale` term + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None else None + ) + self.register_buffer("scale", scale) + + # Seet up `cos..` and `sin...` cache terms + t = torch.arange(self.seq_len_cached, device=device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq) + # freqs = torch.cat((freqs, freqs), dim=-1) + seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device) + power = (seq_range - self.seq_len_cached // 2) / self.scale_base + scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1) + # scale_cached = torch.cat((scale_cached, scale_cached), dim=-1) + self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False) + self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False) + self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False) + self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False) + + def forward(self, x, seq_len=None): + if seq_len > self.seq_len_cached: + self.seq_len_cached = seq_len + t = torch.arange(seq_len, device=x.device, dtype=torch.float32) + freqs = torch.outer(t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) + seq_range = torch.arange(self.seq_len_cached, dtype=self.scale.dtype, device=self.scale.device) + power = (seq_range - self.seq_len_cached // 2) / self.scale_base + scale_cached = self.scale.to(device=power.device) ** power.unsqueeze(-1) + scale_cached = torch.cat((scale_cached, scale_cached), dim=-1) + self.register_buffer("cos_cached", torch.cos(freqs) * scale_cached, persistent=False) + self.register_buffer("sin_cached", torch.sin(freqs) * scale_cached, persistent=False) + self.register_buffer("cos_k_cached", torch.cos(freqs) / scale_cached, persistent=False) + 
self.register_buffer("sin_k_cached", torch.sin(freqs) / scale_cached, persistent=False) + return ( + self.cos_cached[:seq_len, ...], + self.sin_cached[:seq_len, ...], + self.cos_k_cached[:seq_len, ...], + self.sin_k_cached[:seq_len, ...], + ) + +class GPTNeoX2Attention(nn.Module): + def __init__(self, config): + super().__init__() + self.num_attention_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + if self.hidden_size % self.num_attention_heads != 0: + raise ValueError( + "The hidden size is not divisble by the number of attention heads! Make sure to update them" + ) + self.head_size = self.hidden_size // self.num_attention_heads + + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False) + + self.rotary_ndims = int(self.head_size * config.rotary_pct) + + self.rotary_emb = RotaryEmbedding( + self.rotary_ndims, + max_position_embeddings=config.max_position_embeddings, + base=config.rotary_emb_base + ) + + self.register_buffer( + "norm_factor", + torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()), + persistent=False, + ) + + self.query_key_value = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + self.dense = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Tuple[torch.Tensor]] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ): + has_layer_past = layer_past is not None + + # Compute QKV + # Attention heads [batch, seq_len, hidden_size] + # --> [batch, seq_len, (np * 3 * head_size)] + qkv = self.query_key_value(hidden_states) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + query = qkv[..., : self.head_size].permute(0, 2, 1, 3) + key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) + value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) + + # Compute rotary embeddings on rotary_ndims + query_rot = query[..., : self.rotary_ndims] + query_pass = query[..., self.rotary_ndims :] + key_rot = key[..., : self.rotary_ndims] + key_pass = key[..., self.rotary_ndims :] + + # Compute token offset for rotary embeddings (when decoding) + kv_seq_len = key.shape[-2] + if has_layer_past: + kv_seq_len += layer_past[0].shape[-2] + + # Add rotary embeddings to query and key + # TODO: Check if using xpos + cos, sin, cos_k, sin_k = self.rotary_emb(value, seq_len=kv_seq_len) + query, key = apply_rotary_pos_emb( + query_rot, key_rot, cos, sin, position_ids, cos_k=cos_k, sin_k=sin_k) + + query = torch.cat((query, query_pass), dim=-1) + key = torch.cat((key, key_pass), dim=-1) + + # Cache QKV values + if has_layer_past: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + present = (key, value) if use_cache else None + + # Compute attention 
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + # Merge attn_head_size dim and num_attn_heads dim into hidden dim + # [bs, seq_len, num_attention_heads, attn_head_size] + attn_output = attn_output.permute(0, 2, 1, 3).contiguous() + attn_output = attn_output.view(attn_output.size(0), attn_output.size(1), self.num_attention_heads * self.head_size) + + attn_output = self.dense(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + # q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] + # compute causal mask from causal mask buffer + + batch_size, num_attention_heads, query_length, attn_head_size = query.size() + key_length = key.size(-2) + + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) + key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) + attn_scores = torch.zeros( + batch_size * num_attention_heads, + query_length, + key_length, + dtype=query.dtype, + device=key.device, + ) + attn_scores = torch.baddbmm( + attn_scores, + query, + key.transpose(1, 2), + beta=1.0, + alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), + ) + attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) + + mask_value = torch.finfo(attn_scores.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype, device=attn_scores.device) + attn_scores = torch.where(causal_mask, attn_scores, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_scores = attn_scores + attention_mask + + # NOTE: Upcast to float32 + attn_weights = nn.functional.softmax(attn_scores, dim=-1, dtype=torch.float32).type_as(value) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + return attn_output, attn_weights + + +def attention_mask_func(attention_scores, ltor_mask): + attention_scores.masked_fill_(~ltor_mask, torch.finfo(attention_scores.dtype).min) + return attention_scores + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim, p=-1.0, eps=1e-8, bias=False): + """ + Root Mean Square Layer Normalization + :param dim: model size + :param p: partial RMSNorm, valid value [0, 1], default -1.0 (disabled) + :param eps: epsilon value, default 1e-8 + :param bias: whether use bias term for RMSNorm, disabled by + default because RMSNorm doesn't enforce re-centering invariance. 
+ """ + super(RMSNorm, self).__init__() + + self.eps = eps + self.d = dim + self.p = p + self.bias = bias + + self.scale = torch.nn.Parameter(torch.ones(dim)) + self.register_parameter("scale", self.scale) + + if self.bias: + self.offset = torch.nn.Parameter(torch.zeros(dim)) + self.register_parameter("offset", self.offset) + + def forward(self, x): + if self.p < 0.0 or self.p > 1.0: + norm_x = x.norm(2, dim=-1, keepdim=True) + d_x = self.d + else: + partial_size = int(self.d * self.p) + partial_x, _ = torch.split(x, [partial_size, self.d - partial_size], dim=-1) + + norm_x = partial_x.norm(2, dim=-1, keepdim=True) + d_x = partial_size + + rms_x = norm_x * d_x ** (-1.0 / 2) + x_normed = x / (rms_x + self.eps) + + if self.bias: + return self.scale * x_normed + self.offset + + return self.scale * x_normed + +class GPTNeoX2Layer(GPTNeoXLayer): + def __init__(self, config): + _copy_hidden_act = config.hidden_act + config.hidden_act = "gelu" + super().__init__(config) + + config.hidden_act = _copy_hidden_act + self.use_parallel_residual = config.use_parallel_residual + # self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + + # self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + # self.attention = GPTNeoXAttention(config) + self.attention = GPTNeoX2Attention(config) + self.mlp = GPTNeoX2MLP(config) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + layer_past: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + ): + attention_layer_outputs = self.attention( + self.input_layernorm(hidden_states), + attention_mask=attention_mask, + position_ids=position_ids, + layer_past=layer_past, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) + outputs = attention_layer_outputs[1:] + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) + hidden_states = mlp_output + attn_output + + if use_cache: + outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) + else: + outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) + + return outputs + +class GPTNeoX2Model(GPTNeoXModel): + def __init__(self, config): + _copy_hidden_act = config.hidden_act + config.hidden_act = "gelu" + super().__init__(config) + + config.hidden_act = _copy_hidden_act + self.layers = nn.ModuleList([GPTNeoX2Layer(config) for _ in range(config.num_hidden_layers)]) + +class GPTNeoX2ForCausalLM(GPTNeoXPreTrainedModel): + _tied_weights_keys = ["embed_out.weight"] + + def __init__(self, config): + super().__init__(config) + self.gpt_neox = GPTNeoX2Model(config) \ No newline at end of file diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 862620eb8..f0c8822e3 100644 --- a/tools/preprocess_data.py +++ 
b/tools/preprocess_data.py @@ -32,6 +32,7 @@ import tqdm import torch import ftfy +import json from megatron.tokenizer import build_tokenizer from megatron.data import indexed_dataset @@ -49,6 +50,7 @@ def initializer(self): def encode(self, text): if self.args.ftfy: text = ftfy.fix_text(text) + # print('text,', text) ids = {} for key in self.args.jsonl_keys: doc_ids = [] @@ -56,7 +58,11 @@ def encode(self, text): if len(text_ids) > 0: doc_ids.append(text_ids) if self.args.append_eod: - doc_ids[-1].append(Encoder.tokenizer.eod) + try: + doc_ids[-1].append(Encoder.tokenizer.eod) + except Exception as e: + print('text', text) + print('doc_ids', doc_ids) ids[key] = doc_ids return ids, len(text) @@ -157,16 +163,33 @@ def yield_from_files(fnames: list, semaphore): :param fnames: list of filenames """ - def yielder(fname, semaphore): + def yielder(fname, semaphore): for f in filter(lambda x: x, lmd.Reader(fname).stream_data()): semaphore.acquire() yield f - for fname in fnames: - semaphore.acquire() + def wiki_yielder(fname, semaphore): + stream = filter(lambda x: x, lmd.Reader(fname).stream_data()) + for f in filter(lambda x: 'text' in x and len(x['text']) != 0, stream): + semaphore.acquire() + yield f['text'] - yield from yielder(fname, semaphore) + def aozora_yielder(fname, semaphore): + for f in filter(lambda x: x, lmd.Reader(fname).stream_data()): + semaphore.acquire() + yield json.loads(f)['text'] + for fname in fnames: + semaphore.acquire() + print('fname', fname) + if 'izumi' in fname: + yield from yielder(fname, semaphore) + elif 'wiki' in fname: + yield from wiki_yielder(fname, semaphore) + elif 'aozora' in fname: + yield from aozora_yielder(fname, semaphore) + else: + yield from yielder(fname, semaphore) def main(): args = get_args() @@ -181,13 +204,14 @@ def main(): # use multiprocessing to iterate over input documents fin = yield_from_files(args.input.split(","), semaphore) - if args.workers > 1: pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer) + # encoded_docs = pool.imap(encoder.encode, fin, chunksize=25) encoded_docs = pool.imap(encoder.encode, fin, chunksize=25) else: - encoder.initializer() + encoder.initializer() encoded_docs = (encoder.encode(doc) for doc in fin) + # encoded_docs = (encoder.encode(doc) for doc in fin) # make a dataset builder for each key in args.jsonl_keys # each key will output to a different file beginning with args.output_prefix
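Usage note (not part of the patch): a minimal invocation sketch for the new tools/convert_module_to_hf_gptneox2.py, in the spirit of the prepare_data_ja.sh / preprocess_ja.sh helpers added above. The checkpoint and output paths are placeholders; the flags mirror the script's argparse options, and the command assumes it is run from the repository root so the relative vocab-file path in configs/convert_49M_settings.yml resolves.

#!/bin/sh
# Placeholder paths; point --input_dir at the NeoX checkpoint step directory.
python tools/convert_module_to_hf_gptneox2.py \
  --input_dir ./checkpoints/global_step20000 \
  --config_file ./configs/convert_49M_settings.yml \
  --output_dir ./hf_49M
# add --upload to push the converted model to the Hugging Face Hub

Since the script only auto-saves a tokenizer when tokenizer-type is HFTokenizer, the SentencePiece model (./novelAI/tokenizer.model) has to be copied into the output directory by hand when converting SPMTokenizer-based checkpoints.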