"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports
"""Training loop and Decoding of the model."""
# Calling jax.device_count here prevents a "TPU platform already registered" error.
# See https://github.com/google/maxtext/issues/20 for more details.
import datetime
import os
import sys
import functools
from typing import Sequence
from absl import app
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning
import grain.python as grain
import jax
import numpy as np
import optax
import orbax.checkpoint
import checkpointing
import max_utils
import maxtext_utils
import max_logging
import optimizers
import pyconfig
# pylint: disable-next=unused-import
import register_jax_proxy_backend
from vertex_tensorboard import VertexTensorboardManager
from input_pipeline.input_pipeline_interface import create_data_iterator_with_tokenizer
from layers import models
import jax.numpy as jnp
from jax import random
from jax.sharding import Mesh
from jax.experimental import checkify
from cloud_tpu_diagnostics import diagnostic
from cloud_tpu_diagnostics.configuration import debug_configuration
from cloud_tpu_diagnostics.configuration import diagnostic_configuration
from cloud_tpu_diagnostics.configuration import stack_trace_configuration
from layers import quantizations
from ml_goodput_measurement import goodput
Transformer = models.Transformer
EPS = 1e-8
def validate_train_config(config):
"""Validates the configuration is set correctly for train.py"""
assert config.run_name, "Erroring out, need a real run_name"
  if not config.dataset_path.startswith("gs://"):
    max_logging.log("WARNING: 'dataset_path' may be pointing to your local file system")
  if not config.base_output_directory.startswith("gs://"):
    max_logging.log("WARNING: 'base_output_directory' may be pointing to your local file system")
assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer."
def get_first_step(state):
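  """Returns the step count stored in the train state."""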
with jax.spmd_mode("allow_all"):
return int(state.step)
def load_next_batch(train_iter, example_batch, config):
"""Loads the next batch. Can keep reusing the same batch for performance reasons"""
if config.reuse_example_batch and example_batch is not None:
return example_batch
else:
return next(train_iter)
def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr):
"""Records scalar metrics to be written to tensorboard"""
metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()})
metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops})
metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()})
metrics["scalar"].update({"learning/current_learning_rate": lr})
_buffered_step = None
_buffered_metrics = None
def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config):
"""Entry point for all metrics writing in Train's Main.
TODO: would be better as a Class in the future (that initialized all state!)
To avoid introducing an unnecessary dependency, we "double buffer" -- we hold
onto the last metrics and step and only publish when we receive a new metrics and step.
The logic is that this ensures that Jax is able to queues train_steps and we
don't block when turning "lazy" Jax arrays into real Python numbers.
"""
global _buffered_step, _buffered_metrics
if _buffered_metrics is not None:
if _buffered_step is None:
raise ValueError(f"When writing metrics, {_buffered_step=} was none")
write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config)
if config.metrics_file:
max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file)
if config.gcs_metrics and jax.process_index() == 0:
running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics)
_buffered_step = step
_buffered_metrics = metrics
def write_metrics_to_tensorboard(writer, metrics, step, config):
"""Writes metrics to tensorboard"""
with jax.spmd_mode("allow_all"):
if jax.process_index() == 0:
for metric_name in metrics.get("scalar", []):
writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
for metric_name in metrics.get("scalars", []):
writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
full_log = step % config.log_period == 0
max_logging.log(
f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, "
f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, "
f"loss: {metrics['scalar']['learning/loss']:.3f}"
)
if full_log and jax.process_index() == 0:
max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
writer.flush()
def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None):
"""Wrapper for saving checkpoint"""
if dataset_type == "c4-array_record":
return checkpoint_manager.save(
step,
args=orbax.checkpoint.args.Composite(
items=orbax.checkpoint.args.PyTreeSave(item=state),
iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator),
),
)
else:
return checkpoint_manager.save(
step, args=orbax.checkpoint.args.Composite(items=orbax.checkpoint.args.PyTreeSave(item=state))
)
# -----------------------------------------------------------------------------
# Top-level Functions
# -----------------------------------------------------------------------------
def record_activation_metrics(output_metrics, intermediate_outputs, config):
"""Adds the activation metrics to the metrics dict"""
if config.scan_layers:
metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"]
for layer_num in range(config.num_decoder_layers):
output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = metrics_dict["activation_fraction_zero"][0][
layer_num
]
output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num]
output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num]
else:
for layer_num in range(config.num_decoder_layers):
layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"]
output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0]
output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0]
output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0]
def loss_fn(model, config, data, dropout_rng, params, is_train=True):
"""loss_fn for both train and eval.
Args:
model: A nn.Module
config: Config of parameters
data: Batch of data to apply to the model
dropout_rng: A key to use to generate rng for dropout
params: Model params
is_train: True for train_step and False for eval_step
Returns:
loss: average loss
aux: a dictionary including intermediate_outputs, total_loss, and total_weights
"""
# inputs, targets, segments, positions = apply_args
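  # Split the key: rng1 drives dropout; aqt_rng feeds the "params" rng stream used by AQT quantization.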
rng1, aqt_rng = jax.random.split(dropout_rng)
# decimate proportion of data when per_device_batch_size<1
if is_train:
for k, v in data.items():
data[k] = v[: config.global_batch_size_to_train_on, :]
logits, intermediate_outputs = model.apply(
params,
data["inputs"],
data["inputs_position"],
decoder_segment_ids=data["inputs_segmentation"],
enable_dropout=config.enable_dropout if is_train else False,
rngs={"dropout": rng1, "params": aqt_rng},
mutable="intermediates",
)
one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size)
xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0)
xent = nn.with_logical_constraint(xent, ("activation_batch", "activation_length"))
# Mask out paddings at the end of each example.
xent = xent * (data["targets_segmentation"] != 0)
total_loss = jnp.sum(xent)
total_weights = jnp.sum(data["targets_segmentation"] != 0)
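  # Average loss over non-padding tokens; EPS guards against an all-padding batch.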
loss = total_loss / (total_weights + EPS)
aux = {
"intermediate_outputs": intermediate_outputs,
"total_loss": total_loss,
"total_weights": total_weights,
}
return loss, aux
def train_step(model, config, state, data, dropout_rng):
"""
Args:
model: A nn.Module
state: A pytree of the current state of the model
data: Batch of data to apply to the model
dropout_rng: A key to use to generate rng for dropout
Returns:
new_state: Same format as state.
metrics: Dictionary of model metrics such as loss, training rate, etc.
rng2: A new rng key that can be used in future calls.
"""
train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True)
grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True)
(loss, aux), raw_grads = grad_fn(state.params)
intermediate_outputs = aux["intermediate_outputs"]
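  # Optionally clip gradients by their global L2 norm.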
if config.gradient_clipping_threshold > 0:
grads, _ = optax.clip_by_global_norm(config.gradient_clipping_threshold).update(raw_grads, state, None)
else:
grads = raw_grads
new_state = state.apply_gradients(grads=grads)
metrics = {
"scalar": {
"learning/loss": loss,
"learning/grad_norm": max_utils.l2norm_pytree(grads),
"learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads),
"learning/param_norm": max_utils.l2norm_pytree(new_state.params),
},
"scalars": {},
}
if config.record_internal_nn_metrics:
record_activation_metrics(metrics, intermediate_outputs, config)
return new_state, metrics
def eval_step(model, config, state, data, dropout_rng):
"""eval_step no backprop and new state compared with train_step."""
eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False)
loss, aux = eval_loss_fn(state.params)
total_loss = aux["total_loss"]
total_weights = aux["total_weights"]
metrics = {
"scalar": {"evaluation/loss": loss, "evaluation/total_loss": total_loss, "evaluation/total_weights": total_weights}
}
return metrics
def create_goodput_recorder(config):
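  """Creates a GoodputRecorder if goodput recording is enabled, else returns None."""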
if config.enable_goodput_recording:
logger_name = f"goodput_{config.run_name}"
recorder = goodput.GoodputRecorder(config.run_name, logger_name, jax.process_index() == 0)
return recorder
return None
def record_goodput(recorder, config, step=None, job_start=False, job_end=False):
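  """Records job start/end and per-step start times for goodput measurement."""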
if recorder and config.enable_goodput_recording:
if job_start and step is None:
recorder.record_job_start_time()
if job_end and step is None:
recorder.record_job_end_time()
if step is not None:
recorder.record_step_start_time(step)
def check_example_batch(config, example_batch):
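  """Checks the example batch for bad synthetic data when config.max_checkify is set."""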
if config.max_checkify:
jittable_f = checkify.checkify(
lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!")
)
    # Check whether the inputs in the batch contain bad synthetic data.
err, _ = jax.jit(jittable_f)(example_batch['inputs'][: config.global_batch_size_to_train_on, :])
err.throw()
def setup_mesh_and_model(config):
"""Set up the mesh and the model for training
Args:
config
Returns:
init_rng: RNG key
writer: Summary writer for tensorboard
checkpoint_manager: Orbax checkpointer
state_mesh_annotations: the mesh annotations for the train state
model:
mesh:
learning_rate_schedule:
tx:
"""
init_rng = random.PRNGKey(config.init_weights_seed)
writer = max_utils.initialize_summary_writer(config)
checkpoint_manager = checkpointing.create_orbax_checkpoint_manager(
config.checkpoint_dir,
config.enable_checkpointing,
config.async_checkpointing,
config.checkpoint_period,
config.dataset_type,
)
# Mesh definition
devices_array = max_utils.create_device_mesh(config)
mesh = Mesh(devices_array, config.mesh_axes)
# Model and Optimizer definition
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh, quant=quant)
learning_rate_schedule = max_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx
def setup_train_loop(config):
"""Set up prerequisites for the training loop -
checkpoint_manager, PRNG keys, Mesh, Model and optimizer.
Set up data iterator and tokenizer, initialize the model.
Args:
config
Returns:
init_rng:
writer: Summary writer for tensorboard
checkpoint_manager: Orbax checkpointer
state_mesh_annotations: the mesh annotations for the train state
model:
mesh:
learning_rate_schedule:
data_iterator:
state: the initialized train state
"""
init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config)
data_iterator, eval_data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh)
state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(
model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
)
maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh)
return (
init_rng,
writer,
checkpoint_manager,
state_mesh_annotations,
model,
mesh,
learning_rate_schedule,
data_iterator,
eval_data_iterator,
state,
)
def train_loop(config, state=None):
"""Main Training loop.
Args:
config:
state:
ckpt_path:
Returns:
"""
# Create a GoodputRecorder to log information
recorder = create_goodput_recorder(config)
record_goodput(recorder, config, job_start=True)
(
init_rng,
writer,
checkpoint_manager,
state_mesh_annotations,
model,
mesh,
learning_rate_schedule,
data_iterator,
eval_data_iterator,
state,
) = setup_train_loop(config)
# pylint: disable=line-too-long
(
functional_train,
in_shard_train,
out_shard_train,
static_argnums_train,
donate_argnums_train,
) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config)
if eval_data_iterator:
# pylint: disable=line-too-long
(
functional_eval,
in_shard_eval,
out_shard_eval,
static_argnums_eval,
donate_argnums_eval,
) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config)
num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion")
per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config)
# Write train config params, num model params, and XLA flags to tensorboard
max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
max_utils.add_config_to_summary_writer(config, writer)
# Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit
if config.compiled_trainstep_file != "":
print("Loading the compiled function...", flush=True)
# Need to pass train signature and state to determine i/o shapes of train_state for now.
p_train_step = maxtext_utils.load_compiled(config, functional_train, state)
# TODO: p_eval_step is not yet supported in load_compiled
p_eval_step = None
print("Loaded compiled function!", flush=True)
else:
p_train_step = jax.jit(
functional_train,
in_shardings=in_shard_train,
out_shardings=out_shard_train,
static_argnums=static_argnums_train,
donate_argnums=donate_argnums_train,
)
if eval_data_iterator:
p_eval_step = jax.jit(
functional_eval,
in_shardings=in_shard_eval,
out_shardings=out_shard_eval,
static_argnums=static_argnums_eval,
donate_argnums=donate_argnums_eval,
)
else:
p_eval_step = None
local_metrics_file = open(config.metrics_file, "a", encoding="utf8") if config.metrics_file else None
running_gcs_metrics = [] if config.gcs_metrics else None
start_step = get_first_step(state) # this is the start_step for training
first_profiling_step = start_step + config.skip_first_n_steps_for_profiler
if config.enable_profiler and first_profiling_step >= config.steps:
raise ValueError("Profiling requested but initial profiling step set past training final step")
last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1)
example_batch = None
last_step_completion = datetime.datetime.now()
for step in np.arange(start_step, config.steps):
if step == first_profiling_step:
max_utils.activate_profiler(config)
with jax.profiler.StepTraceAnnotation("train", step_num=step):
example_batch = load_next_batch(data_iterator, example_batch, config)
check_example_batch(config, example_batch=example_batch)
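      # Deterministically derive this step's rng key from init_rng.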
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
record_goodput(recorder, config, step=step)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
state, metrics = p_train_step(state, example_batch, nextrng)
new_time = datetime.datetime.now()
record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step))
last_step_completion = new_time
if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, step, state, config.dataset_type, data_iterator):
max_logging.log(f"saved a checkpoint at step {step}")
# Upon preemption, exit when and only when all ongoing saves are complete.
if checkpoint_manager.reached_preemption(step):
checkpoint_manager.wait_until_finished()
sys.exit()
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config)
if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0:
assert eval_data_iterator
cumulative_eval_metrics = {"total_loss": 0.0, "total_weights": 0.0}
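      # Accumulate unnormalized loss and weights across all eval batches, then normalize once at the end.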
for eval_batch in eval_data_iterator:
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
eval_metrics = p_eval_step(state, eval_batch, nextrng)
cumulative_eval_metrics["total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"])
cumulative_eval_metrics["total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"])
eval_loss = cumulative_eval_metrics["total_loss"] / (cumulative_eval_metrics["total_weights"] + EPS)
max_logging.log(f"average loss after {step=}: {eval_loss=}, total_weights={cumulative_eval_metrics['total_weights']}")
if eval_loss <= config.target_eval_loss:
max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}")
max_utils.deactivate_profiler(config)
break
if step == last_profiling_step:
max_utils.deactivate_profiler(config)
if checkpoint_manager is not None:
checkpoint_manager.wait_until_finished()
write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics
max_utils.close_summary_writer(writer)
record_goodput(recorder, config, job_end=True)
return state
def main(argv: Sequence[str]) -> None:
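  # "unsafe_rbg" is a faster, non-cryptographically-safe PRNG implementation.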
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
pyconfig.initialize(argv)
config = pyconfig.config
validate_train_config(config)
os.environ["TFDS_DATA_DIR"] = config.dataset_path
vertex_tensorboard_manager = VertexTensorboardManager()
if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"):
vertex_tensorboard_manager.configure_vertex_tensorboard(config)
debug_config = debug_configuration.DebugConfig(
stack_trace_config=stack_trace_configuration.StackTraceConfig(
collect_stack_trace=config.collect_stack_trace,
stack_trace_to_cloud=config.stack_trace_to_cloud,
stack_trace_interval_seconds=config.stack_trace_interval_seconds,
)
)
diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config)
with diagnostic.diagnose(diagnostic_config):
train_loop(config)
if __name__ == "__main__":
app.run(main)
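# Example invocation (paths and flag values are illustrative):
#   python3 train.py configs/base.yml run_name=my_run \
#     base_output_directory=gs://my-bucket/output dataset_path=gs://my-dataset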