add parameter that controls how the output layer is parallelized (#417)
Co-authored-by: Samuel Weinbach <[email protected]>
sdtblck and sweinbach authored Sep 29, 2021
1 parent 2022712 commit d0e5eb1
Showing 2 changed files with 32 additions and 13 deletions.
38 changes: 25 additions & 13 deletions megatron/model/transformer.py
@@ -71,7 +71,7 @@ class ParallelMLP(nn.Module):
     """

     def __init__(self, neox_args, init_method, output_layer_init_method):
-        super(ParallelMLP, self).__init__()
+        super().__init__()

         self.activation_func = get_activation(neox_args)
         self.activation_type = neox_args.activation
@@ -120,16 +120,28 @@ class ParallelLinear(nn.Module):
     """

     def __init__(self, neox_args, parallel_output=True, init_method=nn.init.xavier_normal_):
-        super(ParallelLinear, self).__init__()
-        self.final_linear = mpu.RowParallelLinear(
-            neox_args=neox_args,
-            input_size=neox_args.hidden_size,
-            output_size=neox_args.padded_vocab_size,
-            bias=False,
-            input_is_parallel=False,
-            init_method=init_method,
-            parallel_output=parallel_output,
-            skip_bias_add=False)
+        super().__init__()
+        parallelism = neox_args.output_layer_parallelism
+        if parallelism == "column":
+            self.final_linear = mpu.ColumnParallelLinear(
+                neox_args=neox_args,
+                input_size=neox_args.hidden_size,
+                output_size=neox_args.padded_vocab_size,
+                bias=False,
+                init_method=init_method,
+                gather_output=not parallel_output,
+                skip_bias_add=False)
+        else:
+            self.final_linear = mpu.RowParallelLinear(
+                neox_args=neox_args,
+                input_size=neox_args.hidden_size,
+                output_size=neox_args.padded_vocab_size,
+                bias=False,
+                input_is_parallel=False,
+                init_method=init_method,
+                parallel_output=parallel_output,
+                skip_bias_add=False)
+

     def forward(self, hidden_states):
         return self.final_linear(hidden_states)
@@ -145,7 +157,7 @@ class ParallelSelfAttention(nn.Module):
     def __init__(self, neox_args, attention_mask_func, init_method,
                  output_layer_init_method, layer_number,
                  rpe=None, rotary=False, get_key_value=False):
-        super(ParallelSelfAttention, self).__init__()
+        super().__init__()

         self.fp16 = neox_args.precision == "fp16"
         self.bf16 = neox_args.precision == "bfloat16"
@@ -416,7 +428,7 @@ class ParallelTransformerLayer(nn.Module):
     def __init__(self, neox_args, attention_mask_func, init_method,
                  output_layer_init_method, layer_number, rpe=None, rotary=False, get_key_value=False):

-        super(ParallelTransformerLayer, self).__init__()
+        super().__init__()
        self.layer_number = layer_number

         norm, eps = get_norm(neox_args)
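
The net effect of the `ParallelLinear` change: when `neox_args.output_layer_parallelism` is `"column"`, the final projection is built as an `mpu.ColumnParallelLinear` that shards the (padded) vocabulary dimension across tensor-parallel ranks, with `gather_output=not parallel_output` controlling whether the logit shards are gathered back; otherwise the previous `mpu.RowParallelLinear` path is kept, which shards the hidden dimension and sums partial products across ranks. As a rough single-process illustration of the two sharding schemes (not the mpu implementation; shapes and names below are made up for the sketch):

```python
# Single-process sketch of how "column" vs "row" parallelism shard the
# output projection W of shape [hidden, vocab] across tensor-parallel ranks.
import torch

hidden, vocab, world_size = 8, 16, 4
x = torch.randn(2, hidden)          # [batch, hidden] activations
W = torch.randn(hidden, vocab)      # full output-layer weight

# Column parallelism: each rank holds a slice of the vocab (output) dimension.
# Each rank computes its own logit shard; concatenating the shards
# (an all-gather in the distributed case) reconstructs the full logits.
col_shards = W.chunk(world_size, dim=1)            # each [hidden, vocab // world_size]
col_logits = torch.cat([x @ w for w in col_shards], dim=1)

# Row parallelism: each rank holds a slice of the hidden (input) dimension and
# sees the matching slice of the activations; the partial products must be
# summed (an all-reduce in the distributed case) to produce the full logits.
row_shards = W.chunk(world_size, dim=0)            # each [hidden // world_size, vocab]
x_shards = x.chunk(world_size, dim=1)
row_logits = sum(xs @ ws for xs, ws in zip(x_shards, row_shards))

assert torch.allclose(col_logits, x @ W, atol=1e-5)
assert torch.allclose(row_logits, x @ W, atol=1e-5)
```

In the distributed setting the `torch.cat` corresponds to a gather over tensor-parallel ranks and the `sum` to an all-reduce, which is the main communication trade-off between the two modes.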
7 changes: 7 additions & 0 deletions megatron/neox_arguments/neox_args.py
@@ -322,6 +322,12 @@ class NeoXArgsModel(NeoXArgsTemplate):
         'init_range': float = 0.5 # if no init string is provided, initialize the soft prompt with a uniform distribution between -init_range and init_rang
     """

+    output_layer_parallelism: Literal["row", "column"] = "row"
+
+    """
+    Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column)
+    """
+

 @dataclass
 class NeoXArgsOptimizer(NeoXArgsTemplate):
@@ -481,6 +487,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     """


+
 @dataclass
 class NeoXArgsOther(NeoXArgsTemplate):
     """
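
Because the new `output_layer_parallelism` field is typed as `Literal["row", "column"]` and defaults to `"row"`, existing configurations keep the old row-parallel output layer unless they explicitly opt in. A minimal stand-alone sketch of the field's behavior (the `_ModelArgsSketch` class below is illustrative, not the real `NeoXArgsModel`):

```python
from dataclasses import dataclass
from typing import Literal

@dataclass
class _ModelArgsSketch:
    # Stand-in for the new NeoXArgsModel field; same annotation and default as the diff.
    output_layer_parallelism: Literal["row", "column"] = "row"

default_args = _ModelArgsSketch()
assert default_args.output_layer_parallelism == "row"       # old behavior preserved

column_args = _ModelArgsSketch(output_layer_parallelism="column")
assert column_args.output_layer_parallelism == "column"     # opt in to the column-parallel path
```

Note that `Literal` is only enforced by static type checkers; any runtime rejection of unexpected values is left to NeoXArgs' own validation logic.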
