Merge branch 'main' into log-grads
sdtblck committed Apr 25, 2021
2 parents adb1b3f + f90e083 commit 357aed4
Showing 5 changed files with 294 additions and 63 deletions.
.gitignore — 6 changes: 5 additions & 1 deletion

@@ -132,4 +132,8 @@ dmypy.json
 wandb/
 
 # data files
-data/
+data/
+
+# ckpt files
+*.pt
+*.ckpt
megatron/checkpointing.py — 22 changes: 2 additions & 20 deletions

@@ -34,21 +34,6 @@
 from megatron import print_rank_0
 from megatron.utils import natural_sort
 
-_CHECKPOINT_VERSION = None
-
-
-def set_checkpoint_version(value):
-    global _CHECKPOINT_VERSION
-    assert _CHECKPOINT_VERSION is None, \
-        "checkpoint version already set"
-    _CHECKPOINT_VERSION = value
-
-
-def get_checkpoint_version():
-    global _CHECKPOINT_VERSION
-    return _CHECKPOINT_VERSION
-
-
 def check_checkpoint_args(checkpoint_args):
     """Ensure fixed arguments for a model are the same for the input
     arguments and the one retreived frm checkpoint."""
@@ -148,7 +133,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     if args.deepspeed:
         save_ds_checkpoint(iteration, model, args)
     else:
-        raise ValueError('Must be using DeepSpeed')
+        raise ValueError('Must be using deepspeed to use neox')
 
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
@@ -214,10 +199,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
                 print("Unable to load checkpoint.")
             return iteration
     else:
-        raise ValueError('Must be using DeepSpeed')
-
-    # set checkpoint version
-    set_checkpoint_version(state_dict.get('checkpoint_version', 0))
+        raise ValueError('Must be using deepspeed to use neox')
 
     # Set iteration.
     if args.finetune or release:
megatron/model/transformer.py — 40 changes: 0 additions & 40 deletions

@@ -26,7 +26,6 @@
 from megatron import get_args
 from megatron import mpu
 from megatron.module import MegatronModule
-from megatron.checkpointing import get_checkpoint_version
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import openai_gelu, erf_gelu, exists
@@ -256,36 +255,6 @@ def __init__(self, attention_mask_func, init_method,
             get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
             checkpoint = deepspeed.checkpointing.checkpoint
 
-    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
-        input_shape = mixed_layer.size()
-        if num_splits_first:
-            """[s, b, num_splits * np * hn]
-            -->(view) [s, b, num_splits, np, hn]
-            -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
-
-            intermediate_shape = input_shape[:-1] + \
-                (num_splits, self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head)
-
-            mixed_layer = mixed_layer.view(*intermediate_shape)
-            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
-        else:
-            """[s, b, np * hn * num_splits]
-            -->(view) [s, b, np, hn, num_splits]
-            -->(tranpose) [s, b, np, num_splits, hn]
-            -->(view) [s, b, np * num_splits * hn] """
-
-            intermediate_shape = input_shape[:-1] + \
-                (self.num_attention_heads_per_partition,
-                 self.hidden_size_per_attention_head, num_splits)
-
-            mixed_layer = mixed_layer.view(*intermediate_shape)
-            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
-        mixed_layer = mixed_layer.view(*input_shape)
-
-        return mixed_layer
-
     def forward(self, hidden_states, attention_mask, layer_past=None):
 
         # hidden_states: [sq, b, h]
@@ -297,15 +266,6 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
         # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
         mixed_x_layer, _ = self.query_key_value(hidden_states)
 
-        checkpoint_version = get_checkpoint_version()
-        if checkpoint_version is not None:
-            if checkpoint_version == 0:
-                # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
-                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
-            elif checkpoint_version == 1.0:
-                # [s, b, (np * hn * 3)] --> [s, b, (np * 3 * hn)]
-                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, False)
-
         # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
         new_tensor_shape = mixed_x_layer.size()[:-1] + \
             (self.num_attention_heads_per_partition,
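For reference, outside the diff itself: the removed _transpose_last_dim helper reordered the fused QKV projection between checkpoint layouts with a view–transpose–view sequence. A minimal standalone sketch of the num_splits_first branch, using arbitrary example dimensions (not code from the repository):

import torch

# Arbitrary example sizes: sequence, batch, heads per partition, head dim, splits (q/k/v)
s, b, np_, hn, num_splits = 4, 2, 8, 16, 3

# Fused layer stored as [s, b, num_splits * np * hn] (the "num_splits first" layout)
mixed_layer = torch.randn(s, b, num_splits * np_ * hn)

x = mixed_layer.view(s, b, num_splits, np_, hn)    # [s, b, num_splits, np, hn]
x = x.transpose(-2, -3).contiguous()               # [s, b, np, num_splits, hn]
converted = x.view(s, b, np_ * num_splits * hn)    # flatten back to the last dimension

assert converted.shape == mixed_layer.shape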
megatron/mpu/initialize.py — 12 changes: 10 additions & 2 deletions

@@ -221,8 +221,16 @@ def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to a local rank zero
     in the data parallel group."""
     global_rank = torch.distributed.get_rank()
-    local_world_size = get_data_parallel_world_size()
-    return (global_rank // local_world_size) * local_world_size
+    topo = get_topology()
+    if topo is None:
+        # we are just using model parallel
+        return global_rank % get_model_parallel_world_size()
+    else:
+        # We are using pipeline parallel
+        d = topo.get_axis_comm_lists('data')
+        for l in d:
+            if global_rank in l:
+                return l[0]
 
 
 def get_data_parallel_world_size():
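For illustration only (not the deepspeed or megatron API): with pipeline parallelism the data-parallel groups come from the process topology, so the source rank is the first global rank in whichever data-parallel communication list contains the caller. A toy sketch with a hypothetical comm-list layout:

# Hypothetical data-parallel comm lists for 8 ranks (e.g. pipe=2, model=2, data=2);
# the layout is assumed purely for illustration.
data_comm_lists = [[0, 2], [1, 3], [4, 6], [5, 7]]

def data_parallel_src_rank(global_rank, comm_lists):
    """Return the lowest rank of the data-parallel group containing global_rank."""
    for group in comm_lists:
        if global_rank in group:
            return group[0]
    raise ValueError(f"rank {global_rank} is not in any data-parallel group")

assert data_parallel_src_rank(6, data_comm_lists) == 4
assert data_parallel_src_rank(1, data_comm_lists) == 1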