-
Notifications
You must be signed in to change notification settings - Fork 1k
/
checkpointing.py
247 lines (207 loc) · 8.97 KB
/
checkpointing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# coding=utf-8
# Copyright (c) 2021, EleutherAI contributors
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input/output checkpointing."""
import os
import re
import shutil
import random
import sys
import numpy as np
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from glob import glob
from megatron import mpu, get_args
from megatron import get_args
from megatron import print_rank_0
from megatron.utils import natural_sort
def check_checkpoint_args(checkpoint_args):
    """Check that the checkpoint's fixed model arguments match this run.

    Every argument name listed below is read from both ``checkpoint_args``
    and the object returned by ``get_args()``. An ``AssertionError`` naming
    the argument and both values is raised on the first mismatch.
    """
    current_args = get_args()

    def _compare(arg_name):
        saved_value = getattr(checkpoint_args, arg_name)
        current_value = getattr(current_args, arg_name)
        mismatch_message = (
            '{} value from checkpoint ({}) is not equal to the '
            'input argument value ({}).'
        ).format(arg_name, saved_value, current_value)
        assert saved_value == current_value, mismatch_message

    # Arguments that must be identical between the checkpoint and this run.
    fixed_arg_names = (
        'num_layers',
        'hidden_size',
        'num_attention_heads',
        'max_position_embeddings',
        'make_vocab_size_divisible_by',
        'padded_vocab_size',
        'tokenizer_type',
        'model_parallel_size',
    )
    for fixed_arg_name in fixed_arg_names:
        _compare(fixed_arg_name)
def ensure_directory_exists(filename):
    """Create the parent directory of ``filename`` if it does not exist yet.

    Args:
        filename: Path of a file whose containing directory should exist.

    A bare file name without a directory component lives in the current
    working directory, so nothing needs to be created for it.
    """
    dirname = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so skip an empty directory part.
    if dirname and not os.path.exists(dirname):
        # exist_ok=True tolerates another process creating the directory
        # between the existence check above and this call.
        os.makedirs(dirname, exist_ok=True)
def get_checkpoint_name(checkpoints_path, iteration,
                        release=False, mp_rank=None):
    """Build the unified path of a checkpoint file.

    The path is ``<checkpoints_path>/<dir>/mp_rank_XX/model_optim_rng.pt``
    where ``<dir>`` is ``release`` for release checkpoints and
    ``iter_NNNNNNN`` otherwise. When ``mp_rank`` is None the rank is taken
    from ``mpu.get_model_parallel_rank()``.
    """
    subdirectory = 'release' if release else 'iter_{:07d}'.format(iteration)
    # Only query mpu when the caller did not supply an explicit rank.
    if mp_rank is None:
        mp_rank = mpu.get_model_parallel_rank()
    rank_directory = 'mp_rank_{:02d}'.format(mp_rank)
    return os.path.join(checkpoints_path, subdirectory, rank_directory,
                        'model_optim_rng.pt')
def get_checkpoint_tracker_filename(checkpoints_path):
    """Return the path of the tracker file inside ``checkpoints_path``.

    The tracker file records the latest checkpoint written during training
    so that a run can be restarted from it.
    """
    tracker_basename = 'latest_checkpointed_iteration.txt'
    return os.path.join(checkpoints_path, tracker_basename)
def delete_old_checkpoints(save_dir, n_to_keep):
    """Delete all but the last ``n_to_keep`` checkpoint directories.

    Only the process whose ``torch.distributed`` rank is 0 does any work.
    Candidate checkpoints are the sub-directories of ``save_dir`` whose own
    name matches ``global_step<digits>``. They are ordered with
    ``natural_sort`` and the leading entries of that ordering are removed
    (presumably the oldest ones - confirm against ``natural_sort``).

    Args:
        save_dir: Directory containing the checkpoint sub-directories.
        n_to_keep: Number of checkpoint directories to leave in place.
    """
    if torch.distributed.get_rank() != 0:
        return

    ckpt_dir_regex = r'global_step[\d]*'
    # Drop trailing separators only: str.strip('/') would also remove a
    # leading '/' and silently turn an absolute path into a relative one.
    if save_dir.endswith('/'):
        save_dir = save_dir.rstrip('/')

    # Match against the directory's own name so that a parent directory that
    # happens to contain 'global_step' does not make every sibling match.
    all_ckpts = natural_sort([
        path for path in glob(f'{save_dir}/*')
        if os.path.isdir(path)
        and re.search(ckpt_dir_regex, os.path.basename(path))
    ])

    n_to_delete = len(all_ckpts) - n_to_keep
    if n_to_delete <= 0:
        return

    to_delete = all_ckpts[:n_to_delete]
    print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
    for ckpt in to_delete:
        try:
            shutil.rmtree(ckpt)
        except FileNotFoundError:
            # Best effort: the directory is already gone, which is the goal.
            pass
def save_ds_checkpoint(iteration, model, args):
    """Save a checkpoint through the model engine's ``save_checkpoint``.

    The iteration (and, unless ``args.no_save_rng`` is set, the Python,
    NumPy, torch, CUDA and mpu tracker RNG states) is passed along as
    ``client_state`` and written to ``args.save``.
    """
    client_state = {'iteration': iteration}

    # rng states.
    if not args.no_save_rng:
        client_state.update(
            random_rng_state=random.getstate(),
            np_rng_state=np.random.get_state(),
            torch_rng_state=torch.get_rng_state(),
            cuda_rng_state=torch.cuda.get_rng_state(),
            rng_tracker_states=mpu.get_cuda_rng_tracker().get_states(),
        )

    # Without pipeline parallelism the megatron model exposes
    # state_dict_for_save_checkpoint instead of the standard state_dict;
    # deepspeed saves the module through state_dict, so point it at the right
    # function. With pipeline parallelism the state_dict is managed elsewhere
    # and nothing has to be patched.
    if args.pipe_parallel_size == 0:
        wrapped_module = model.module
        wrapped_module.state_dict = wrapped_module.state_dict_for_save_checkpoint

    model.save_checkpoint(args.save, client_state=client_state)
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint and record it as the latest one.

    Raises ``ValueError`` when deepspeed is not enabled. ``optimizer`` and
    ``lr_scheduler`` are accepted but not used by this function.
    """
    args = get_args()

    if not args.deepspeed:
        raise ValueError('Must be using deepspeed to use neox')
    save_ds_checkpoint(iteration, model, args)

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()

    # Rank 0 records the iteration that was just saved in the tracker file.
    if torch.distributed.get_rank() == 0:
        tracker_path = get_checkpoint_tracker_filename(args.save)
        with open(tracker_path, 'w') as tracker_file:
            tracker_file.write(str(iteration))

    # Wait so everyone is done (necessary)
    torch.distributed.barrier()

    # Optionally prune older checkpoints.
    if args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(args.save, args.keep_last_n_checkpoints)

    # Wait so everyone is done (not necessary)
    torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load'):
    """Load a model checkpoint and return the iteration.

    Args:
        model: Model engine exposing ``load_checkpoint``; a torch
            ``DistributedDataParallel`` wrapper is unwrapped first.
        optimizer: Not used by this function.
        lr_scheduler: Not used by this function.
        load_arg: Name of the attribute on the global args that holds the
            directory to load from.

    Returns:
        0 when no tracker file exists, or when finetuning or loading a
        release checkpoint; the iteration read from the tracker file when
        the engine reports that no checkpoint could be loaded; otherwise the
        iteration stored in the checkpoint's state dict.

    Raises:
        ValueError: If deepspeed is not enabled.
        AssertionError: If the tracker file holds a non-positive iteration,
            or the checkpoint args differ from the current args.
        SystemExit: With a non-zero status when the tracker file is invalid,
            or the checkpoint lacks the iteration or the RNG states.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    if isinstance(model, torchDDP):
        model = model.module

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == 'release'
        if not release:
            print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                tracker_filename))
            # Non-zero status so launchers see the failure.
            sys.exit(1)

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    if not args.deepspeed:
        raise ValueError('Must be using deepspeed to use neox')

    # TODO: These should be configured by separate args
    load_optim_and_scheduler = not args.no_load_optim
    checkpoint_name, state_dict = model.load_checkpoint(
        load_dir,
        load_optimizer_states=load_optim_and_scheduler,
        load_lr_scheduler_states=load_optim_and_scheduler)

    if checkpoint_name is None:
        if mpu.get_data_parallel_rank() == 0:
            print("Unable to load checkpoint.")
        return iteration

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit(1)

    # Check arguments.
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            # This branch is about the RNG states, not the optimizer, so the
            # message names what actually failed to load.
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit(1)

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return iteration