Skip to content

Commit

Permalink
Only create task heads on last pipeline stage.
Browse files Browse the repository at this point in the history
  • Loading branch information
jaredcasper committed Jan 5, 2021
1 parent 6fa3684 commit f772fbc
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 31 deletions.
35 changes: 19 additions & 16 deletions megatron/model/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import torch

from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
Expand Down Expand Up @@ -45,11 +45,12 @@ def __init__(self, num_classes, num_tokentypes=2):
args.num_layers))

# Classification head.
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes,
init_method)
self._classification_head_key = 'classification_head'
if mpu.is_pipeline_last_stage():
self.classification_dropout = torch.nn.Dropout(args.hidden_dropout)
self.classification_head = get_linear_layer(args.hidden_size,
self.num_classes,
init_method)
self._classification_head_key = 'classification_head'

def forward(self, model_input, attention_mask, tokentype_ids=None):

Expand Down Expand Up @@ -85,23 +86,25 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
state_dict_[self._classification_head_key] \
= self.classification_head.state_dict(
destination, prefix, keep_vars)
return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
else:
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))
if mpu.is_pipeline_last_stage():
if self._classification_head_key in state_dict:
self.classification_head.load_state_dict(
state_dict[self._classification_head_key], strict=strict)
else:
print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._classification_head_key))


class Classification(ClassificationBase):
Expand Down
33 changes: 18 additions & 15 deletions megatron/model/multiple_choice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import torch

from megatron import get_args, print_rank_0
from megatron import get_args, print_rank_last
from megatron import mpu
from megatron.model.bert_model import bert_attention_mask_func, bert_extended_attention_mask, bert_position_ids
from megatron.model.language_model import get_language_model
Expand All @@ -44,10 +44,11 @@ def __init__(self, num_tokentypes=2):
args.num_layers))

# Multi-choice head.
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method)
self._multichoice_head_key = 'multichoice_head'
if mpu.is_pipeline_last_stage():
self.multichoice_dropout = torch.nn.Dropout(args.hidden_dropout)
self.multichoice_head = get_linear_layer(args.hidden_size, 1,
init_method)
self._multichoice_head_key = 'multichoice_head'

def forward(self, model_input, attention_mask, tokentype_ids=None):

Expand Down Expand Up @@ -97,23 +98,25 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='',
state_dict_[self._language_model_key] \
= self.language_model.state_dict_for_save_checkpoint(
destination, prefix, keep_vars)
state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict(
destination, prefix, keep_vars)
if mpu.is_pipeline_last_stage():
state_dict_[self._multichoice_head_key] \
= self.multichoice_head.state_dict(
destination, prefix, keep_vars)
return state_dict_

def load_state_dict(self, state_dict, strict=True):
"""Customized load."""

self.language_model.load_state_dict(
state_dict[self._language_model_key], strict=strict)
if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict)
else:
print_rank_0('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._multichoice_head_key))
if mpu.is_pipeline_last_stage():
if self._multichoice_head_key in state_dict:
self.multichoice_head.load_state_dict(
state_dict[self._multichoice_head_key], strict=strict)
else:
print_rank_last('***WARNING*** could not find {} in the checkpoint, '
'initializing to random'.format(
self._multichoice_head_key))

class MultipleChoice(MultipleChoiceBase):

Expand Down

0 comments on commit f772fbc

Please sign in to comment.