added reload model params for finetuning
mohammad committed Jan 5, 2021
1 parent 43529f7 commit 160ba68
Showing 2 changed files with 30 additions and 4 deletions.
31 changes: 29 additions & 2 deletions megatron/optimizer/optimizer.py
@@ -76,6 +76,10 @@ def scale_loss(self, loss):
     def step(self):
         pass
 
+    @abstractmethod
+    def reload_model_params(self):
+        pass
+
     @abstractmethod
     def state_dict(self):
         pass
@@ -243,22 +247,41 @@ def _unscale_master_grads_and_check_for_nan(self):
         return found_inf_flag
 
 
-    def _copy_master_params_to_model_params(self):
-        # Only needed for the fp16 params.
+    def _get_model_and_master_params_data_fp16(self):
         model_data = []
         master_data = []
         for model_group, master_group in zip(self.fp16_groups,
                                              self.fp32_from_fp16_groups):
             for model_param, master_param in zip(model_group, master_group):
                 model_data.append(model_param.data)
                 master_data.append(master_param.data)
+        return model_data, master_data
+
+
+    def _copy_master_params_to_model_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
         self._dummy_overflow_buf.fill_(0)
         # Scaling with factor `1.0` is equivalent to copy.
         multi_tensor_applier(amp_C.multi_tensor_scale,
                              self._dummy_overflow_buf,
                              [master_data, model_data],
                              1.0)
 
+    def _copy_model_params_to_master_params(self):
+        # Only needed for the fp16 params.
+        model_data, master_data = self._get_model_and_master_params_data_fp16()
+        self._dummy_overflow_buf.fill_(0)
+        # Scaling with factor `1.0` is equivalent to copy.
+        multi_tensor_applier(amp_C.multi_tensor_scale,
+                             self._dummy_overflow_buf,
+                             [model_data, master_data],
+                             1.0)
+
+
+    def reload_model_params(self):
+        self._copy_model_params_to_master_params()
+
 
     @torch.no_grad()
     def step(self):
@@ -388,6 +411,10 @@ def step(self):
         return True
 
 
+    def reload_model_params(self):
+        pass
+
+
     def state_dict(self):
         return self.optimizer.state_dict()
 
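The fp16 optimizer's new reload_model_params copies the (fp16) model parameters back into the fp32 master copies by reusing apex's fused multi_tensor_applier with amp_C.multi_tensor_scale and a scale factor of 1.0, which amounts to an elementwise copy. A minimal, unfused plain-PyTorch sketch of the same copy for reference (the standalone function form and its name are illustrative, not part of this commit):

def copy_model_params_to_master_params(fp16_groups, fp32_from_fp16_groups):
    # Illustrative equivalent of the multi_tensor_scale(..., 1.0) call above:
    # refresh each fp32 master parameter from its fp16 model parameter.
    for model_group, master_group in zip(fp16_groups, fp32_from_fp16_groups):
        for model_param, master_param in zip(model_group, master_group):
            # copy_() performs the fp16 -> fp32 cast during the copy.
            master_param.data.copy_(model_param.data)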
3 changes: 1 addition & 2 deletions tasks/finetune_utils.py
@@ -256,8 +256,7 @@ def finetune(train_valid_datasets_provider, model_provider,
         args.load = original_load
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
-        if args.fp16:
-            optimizer._model_params_to_master_params()
+        optimizer.reload_model_params()
     timers('pretrained checkpoint').stop()
 
     # Print setup timing.
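As the comment in the hunk notes, loading a pretrained checkpoint here overwrites only the model parameters; the optimizer's fp32 master copies still hold the values from initialization, so the next step() would write those stale masters back over the freshly loaded weights. Calling optimizer.reload_model_params() unconditionally avoids that, and the plain fp32 optimizer implements it as a no-op, so the old `if args.fp16` guard is no longer needed. A hedged sketch of the intended ordering (the wrapper name, load_fn, and checkpoint_path are hypothetical):

def load_pretrained_then_sync(model, optimizer, load_fn, checkpoint_path):
    # Hypothetical wrapper showing the ordering this change relies on:
    # 1) load pretrained weights into the (possibly fp16) model params,
    # 2) refresh the optimizer's fp32 master copies from those params,
    #    so the first optimizer step does not clobber the loaded weights.
    load_fn(model, checkpoint_path)      # touches model parameters only
    optimizer.reload_model_params()      # no-op for the plain fp32 optimizer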
