Commit

Update jax.tree_map to jax.tree_util.tree_map
RissyRan committed Apr 25, 2024
1 parent f6060b0 commit 967941b
Showing 9 changed files with 32 additions and 32 deletions.
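
Background for the rename: recent JAX releases (around v0.4.25) deprecate the top-level jax.tree_map alias in favor of jax.tree_util.tree_map (and the newer jax.tree.map), so every call site below is switched to the fully qualified name. A minimal sketch of the equivalence, assuming a current JAX install and a made-up pytree for illustration:

    import jax
    import jax.numpy as jnp

    # Hypothetical pytree, standing in for the checkpoint weights in this diff.
    params = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}

    # Deprecated alias; newer JAX versions emit a DeprecationWarning here:
    #   casted = jax.tree_map(lambda x: x.astype(jnp.bfloat16), params)

    # Fully qualified name used throughout this commit; the behavior is identical.
    casted = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
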
4 changes: 2 additions & 2 deletions MaxText/convert_gemma_chkpt.py
@@ -143,15 +143,15 @@ def main(raw_args=None) -> None:

layer_weight["self_attention"] = copy.deepcopy(self_attention)
jax_weights["decoder"]["layers"] = copy.deepcopy(layer_weight)
- jax_weights = jax.tree_map(jnp.array, jax_weights)
+ jax_weights = jax.tree_util.tree_map(jnp.array, jax_weights)

def astype_fn(x):
if isinstance(x, jnp.ndarray):
return x.astype(jnp.bfloat16)
else:
return x

- jax_weights = jax.tree_map(astype_fn, jax_weights)
+ jax_weights = jax.tree_util.tree_map(astype_fn, jax_weights)

enable_checkpointing = True
async_checkpointing = False
10 changes: 5 additions & 5 deletions MaxText/generate_param_only_checkpoint.py
@@ -54,13 +54,13 @@ def _possibly_unroll_params(config, training_state, training_state_annotations,
def new_pspec(x):
return jax.sharding.PartitionSpec(*x[0 : config.param_scan_axis] + x[config.param_scan_axis + 1 :])

- new_per_layer_state_annotation = jax.tree_map(new_pspec, training_state_annotations_layers)
- new_per_layer_state_sharding = jax.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)
+ new_per_layer_state_annotation = jax.tree_util.tree_map(new_pspec, training_state_annotations_layers)
+ new_per_layer_state_sharding = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(mesh, x), new_per_layer_state_annotation)

for i in range(config.num_decoder_layers):

def slice_ith(input_layers):
- return jax.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers)
+ return jax.tree_util.tree_map(lambda x: jax.numpy.take(x, i, axis=config.param_scan_axis), input_layers)

new_layer = jax.jit(slice_ith, out_shardings=new_per_layer_state_sharding)(training_state_layers)

@@ -70,7 +70,7 @@ def slice_ith(input_layers):
del training_state.params["params"]["decoder"]["layers"]
del training_state_annotations.params["params"]["decoder"]["layers"]

- jax.tree_map(lambda x: x.delete(), training_state_layers)
+ jax.tree_util.tree_map(lambda x: x.delete(), training_state_layers)


def _read_train_checkpoint(config, checkpoint_manager, mesh):
@@ -90,7 +90,7 @@ def _read_train_checkpoint(config, checkpoint_manager, mesh):
def _save_decode_checkpoint(config, state, checkpoint_manager):
"""Generate checkpoint for decode from the training_state."""
with jax.spmd_mode("allow_all"):
- decode_state = max_utils.init_decode_state(None, jax.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
+ decode_state = max_utils.init_decode_state(None, jax.tree_util.tree_map(lambda x: x.astype(jax.numpy.bfloat16), state.params))
if checkpoint_manager is not None:
if save_checkpoint(checkpoint_manager, 0, decode_state):
max_logging.log(f"saved an decode checkpoint at {config.checkpoint_dir}")
2 changes: 1 addition & 1 deletion MaxText/input_pipeline/input_pipeline_interface.py
@@ -98,7 +98,7 @@ def __init__(self, config, mesh):
self.mesh = mesh
self.config = config
data_pspec = P(*config.data_sharding)
- data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+ data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
self.data_generator = jax.jit(
SyntheticDataIterator.raw_generate_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
)
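
The pattern in this hunk, repeated in maxtext_utils.py and shardings.py below, maps a pytree of PartitionSpec leaves to NamedSharding objects for a device mesh. A small standalone sketch, assuming a one-axis mesh over whatever local devices are available (the axis name "data" is illustrative, not MaxText's configured value):

    import jax
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Hypothetical 1-D mesh spanning all local devices.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

    # A PartitionSpec pytree; the spec is treated as a single leaf here, so
    # tree_map replaces each spec (not its string elements) with a NamedSharding.
    data_pspec = P("data")
    data_pspec_shardings = jax.tree_util.tree_map(
        lambda p: NamedSharding(mesh, p), data_pspec
    )

    # The resulting shardings can then be passed to jax.jit(..., out_shardings=...),
    # as the surrounding MaxText code does.
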
2 changes: 1 addition & 1 deletion MaxText/llama_or_mistral_ckpt.py
@@ -323,7 +323,7 @@ def checkpoint_device_put(arr):
return jax.device_put(arr, device=s3)

# convert all weights to jax.numpy with sharding if applicable
- jax_weights = jax.tree_map(checkpoint_device_put, jax_weights)
+ jax_weights = jax.tree_util.tree_map(checkpoint_device_put, jax_weights)

# dummy configs for the checkpoint_manager
step_number_to_save_new_ckpt = 0
4 changes: 2 additions & 2 deletions MaxText/max_utils.py
@@ -49,7 +49,7 @@ def find_nans_and_infs(pytree):
def finder(x):
return jnp.any(jnp.isinf(x) | jnp.isnan(x))

- bad_pytree = jax.tree_map(finder, pytree)
+ bad_pytree = jax.tree_util.tree_map(finder, pytree)
return jax.tree_util.tree_flatten(bad_pytree)


@@ -660,7 +660,7 @@ def delete_leaf(leaf):
leaf.delete()
del leaf

- jax.tree_map(delete_leaf, p)
+ jax.tree_util.tree_map(delete_leaf, p)


def summarize_pytree_data(params, name="Params", raw=False):
12 changes: 6 additions & 6 deletions MaxText/maxengine.py
@@ -79,14 +79,14 @@ def load_params(self, *args, **kwargs) -> Params:
"""Load Parameters, typically from GCS"""
# pylint: disable=unused-argument
state, self.state_mesh_annotations = max_utils.setup_decode_state(self.model, self.config, self.rng, self._mesh, None)
- self.abstract_params = jax.tree_map(
+ self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
)
self.kv_cache_annotations = max_utils.get_kv_cache_annotations(self.model, self.config, self.rng, self._mesh)
- self.kv_cache_shardings = jax.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)
+ self.kv_cache_shardings = jax.tree_util.tree_map(lambda x: jax.sharding.NamedSharding(self._mesh, x), self.kv_cache_annotations)

if not self.model.quant:
- self.abstract_params = jax.tree_map(
+ self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), state.params
)
return state.params
@@ -113,7 +113,7 @@ def model_apply(_p, _rng):
# Remove param values which have corresponding qtensors in aqt to save memory.
params["params"] = quantizations.remove_quantized_params(state.params["params"], new_vars["aqt"])

- self.abstract_params = jax.tree_map(
+ self.abstract_params = jax.tree_util.tree_map(
lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype, sharding=x.sharding), params
)

@@ -342,13 +342,13 @@ def init(abstract_params):
with self._mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
mesh_annotations = nn.logical_to_mesh(logical_annotations)

- shardings = jax.tree_map(
+ shardings = jax.tree_util.tree_map(
lambda mesh_annotation: jax.sharding.NamedSharding(self._mesh, mesh_annotation), mesh_annotations
)

@functools.partial(jax.jit, out_shardings=shardings)
def initialize():
- return jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs)
+ return jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), abstract_outputs)

cache = initialize()["cache"]

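
The maxengine.py hunks also use tree_map to strip concrete parameter values down to abstract shape/dtype descriptors. A brief sketch of that pattern, with a hypothetical params pytree in place of state.params:

    import jax
    import jax.numpy as jnp

    # Stand-in for state.params; shapes and dtypes are illustrative only.
    params = {"decoder": {"w": jnp.ones((4, 8), dtype=jnp.bfloat16)}}

    # Replace every array leaf with a jax.ShapeDtypeStruct so the tree can serve
    # as an abstract spec (e.g. for tracing or shape checks) without keeping the
    # underlying buffers alive.
    abstract_params = jax.tree_util.tree_map(
        lambda x: jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype), params
    )
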
8 changes: 4 additions & 4 deletions MaxText/maxtext_utils.py
@@ -33,8 +33,8 @@ def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations
functional_train = get_functional_train_step(train_step, model, config)
functional_train.__name__ = "train_step"
data_pspec = P(*config.data_sharding)
- state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
- data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+ state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+ data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
out_shardings = (state_mesh_shardings, None) # State, metrics
static_argnums = () # We partial out the static argnums of model and config
@@ -51,8 +51,8 @@ def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations,
functional_eval = get_functional_eval_step(eval_step, model, config)
functional_eval.__name__ = "eval_step"
data_pspec = P(*config.data_sharding)
- state_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
- data_sharding = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+ state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+ data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng
out_shardings = None # metrics
static_argnums = () # We partial out the static argnums of model, config
12 changes: 6 additions & 6 deletions MaxText/optimizers.py
@@ -129,19 +129,19 @@ def _update_momentum(update, mu, nu):
nu = (1.0 - beta2_decay) * (update**2) + beta2_decay * nu
return _slot_opt_state(mu=mu, nu=nu)

- updated_moments = jax.tree_map(_update_momentum, updates, state.mu, state.nu)
+ updated_moments = jax.tree_util.tree_map(_update_momentum, updates, state.mu, state.nu)

- mu = jax.tree_map(lambda x: x.mu, updated_moments)
- nu = jax.tree_map(lambda x: x.nu, updated_moments)
+ mu = jax.tree_util.tree_map(lambda x: x.mu, updated_moments)
+ nu = jax.tree_util.tree_map(lambda x: x.nu, updated_moments)

- updates = jax.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)
+ updates = jax.tree_util.tree_map(lambda mu, nu: mu / (jnp.sqrt(nu + epsilon_root) + epsilon), mu, nu)

if weight_decay > 0:
- updates = jax.tree_map(lambda x, v: x + weight_decay * v, updates, params)
+ updates = jax.tree_util.tree_map(lambda x, v: x + weight_decay * v, updates, params)

step_size = -1.0 * learning_rate_fn(count)
# Finally, fold in step size.
- updates = jax.tree_map(lambda x: step_size * x, updates)
+ updates = jax.tree_util.tree_map(lambda x: step_size * x, updates)

updated_states = optax.ScaleByAdamState(count=count + 1, mu=mu, nu=nu)
return updates, updated_states
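
The optimizer hunk above relies on the multi-tree form of tree_map: updates, state.mu, and state.nu share one structure, and the mapped function receives one leaf from each tree. A small sketch with made-up values standing in for the real optimizer state:

    import jax
    import jax.numpy as jnp

    beta = 0.9  # illustrative decay rate, not the repository's configured value

    # Three pytrees with identical structure, as the multi-tree map requires.
    grads = {"w": jnp.array([0.1, -0.2]), "b": jnp.array([0.05])}
    mu = {"w": jnp.zeros(2), "b": jnp.zeros(1)}
    nu = {"w": jnp.zeros(2), "b": jnp.zeros(1)}

    # The lambda gets one leaf per tree, in the order the trees are passed.
    new_mu = jax.tree_util.tree_map(lambda g, m: (1 - beta) * g + beta * m, grads, mu)
    new_nu = jax.tree_util.tree_map(lambda g, n: (1 - beta) * g**2 + beta * n, grads, nu)
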
10 changes: 5 additions & 5 deletions pedagogical_examples/shardings.py
@@ -207,28 +207,28 @@ def multiply_layers_with_loss(in_act, in_layers):

def training_step(in_act, in_layers):
_, grad_layers = multiply_layers_and_grad(in_act, in_layers)
- out_layers = jax.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0])
+ out_layers = jax.tree_util.tree_map(lambda param, grad: param - 1e-4 * grad, in_layers, grad_layers[0])
return out_layers

print("finished includes ", flush=True)

replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding)

- parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+ parameter_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)

- data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+ data_pspec_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)

jit_func = jax.jit(
training_step,
in_shardings=(replicated_sharding, parameter_mesh_shardings),
out_shardings=data_pspec_shardings,
)

- data_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)
+ data_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)

jit_gen_data = jax.jit(gen_data, in_shardings=None, out_shardings=data_mesh_shardings)

- parameter_mesh_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+ parameter_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)

jit_gen_layers = jax.jit(gen_layers, in_shardings=None, out_shardings=parameter_mesh_shardings)

