Skip to content

Commit

Permalink
Make broadcasting from one replica to all more memory efficient
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 646526020
  • Loading branch information
maxtext authors committed Jun 25, 2024
1 parent e7c1f01 commit 9482bf1
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions MaxText/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from typing import Optional, Union
from etils import epath

from orbax.checkpoint.checkpoint_manager import CheckpointManager, CheckpointManagerOptions
from orbax.checkpoint.logging import abstract_logger, cloud_logger, standard_logger, composite_logger
import jax
Expand Down Expand Up @@ -134,21 +135,29 @@ def map_to_pspec(data):
pspec = data.sharding.spec
mesh = data.sharding.mesh
if not enable_single_replica_ckpt_restoring:
return orbax.checkpoint.type_handlers.ArrayRestoreArgs(mesh=mesh, mesh_axes=pspec)
orbax.checkpoint.type_handlers.register_type_handler(
jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True
)
orbax.checkpoint.type_handlers.register_type_handler(
jax.Array, orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(), override=True
)
replica_axis_index = 0 # for maxtext data is the first dimension
return orbax.checkpoint.type_handlers.ArrayRestoreArgs(
mesh=mesh, mesh_axes=pspec)
replica_axis_index = 0
replica_devices = _replica_devices(mesh.devices, replica_axis_index)
replica_mesh = jax.sharding.Mesh(replica_devices, mesh.axis_names)
single_replica_sharding = jax.sharding.NamedSharding(replica_mesh, pspec)
single_replica_sharding = jax.sharding.NamedSharding(
replica_mesh, pspec)

array_handler = (
orbax.checkpoint.type_handlers.SingleReplicaArrayHandler(
replica_axis_index=0,
broadcast_memory_limit_bytes=1024 * 1024 * 1000 # 1000 MB limit
)
)
orbax.checkpoint.type_handlers.register_type_handler(
jax.Array,
array_handler,
override=True
)

return orbax.checkpoint.type_handlers.SingleReplicaArrayRestoreArgs(
sharding=jax.sharding.NamedSharding(mesh, pspec),
single_replica_sharding=single_replica_sharding,
replica_axis_index=replica_axis_index,
global_shape=data.shape,
dtype=data.dtype,
)
Expand Down

0 comments on commit 9482bf1

Please sign in to comment.