Skip to content

Commit

Permalink
Stop writing msgpack file for new checkpoints and update empty nodes …
Browse files Browse the repository at this point in the history
…handling so that it no longer depends on this file.

PiperOrigin-RevId: 649179323
  • Loading branch information
cpgaffney1 authored and t5-copybara committed Jul 3, 2024
1 parent f975a01 commit 9c0cb7a
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,7 +2154,7 @@ def _construct_orbax_restoration_transforms(
)
assert state_subdir.is_dir(), state_subdir
use_orbax_format = state_subdir.stem == _STATE_KEY # Standard Orbax format
structure = state_handler._handler_impl._read_aggregate_file( # pylint: disable=protected-access
structure, _ = state_handler._get_internal_metadata( # pylint: disable=protected-access
state_subdir
)
# Note: Ideally we would use Orbax's `transform_fn` to do this logic, but
Expand Down Expand Up @@ -2187,20 +2187,16 @@ def _transform_fn(
del structure_, param_infos_

def _make_orbax_internal_metadata(value: Any, args: ocp.RestoreArgs):
if ocp.utils.leaf_is_placeholder(value):
if isinstance(value, ocp.metadata.tree.ValueMetadataEntry):
if value.value_type == 'scalar':
return ocp.metadata.tree.ValueMetadataEntry(value_type='scalar')
if isinstance(args, ocp.ArrayRestoreArgs):
restore_type = 'jax.Array'
value_type = 'jax.Array'
else:
restore_type = 'np.ndarray'
return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access
restore_type=restore_type
)
value_type = 'np.ndarray'
return ocp.metadata.tree.ValueMetadataEntry(value_type=value_type)
else:
return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access
restore_type=None,
skip_deserialize=True,
aggregate_value=value,
)
return value

directory_ = ocp.utils.get_save_directory(
step, directory, name=_STATE_KEY, step_prefix=get_checkpoint_prefix()
Expand All @@ -2223,6 +2219,7 @@ def _modify_orbax_param_info(info, value):
item_,
None,
None,
None,
)
param_infos_ = jax.tree_util.tree_map(
_modify_orbax_param_info, param_infos_, state_dict_to_restore
Expand Down

0 comments on commit 9c0cb7a

Please sign in to comment.