Skip to content

Commit

Permalink
Merge pull request #635 from google:mattdavidow-numpy-prod
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 631470826
  • Loading branch information
maxtext authors committed May 7, 2024
2 parents a28f518 + 3075bbe commit d590328
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions MaxText/max_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -289,7 +289,7 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
), f"Found unspecified values (-1) for more than one {parallelism_type}\
parallelism axis. At most one axis can be unspecified."

determined_val = target_product / np.product(parallelism_vals) * -1
determined_val = target_product / np.prod(parallelism_vals) * -1

assert (
determined_val >= 1 and determined_val.is_integer
Expand All @@ -301,9 +301,9 @@ def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_typ
target_type = "slices" if parallelism_type == "DCN" else "devices per slice"

assert (
np.product(parallelism_vals) == target_product
np.prod(parallelism_vals) == target_product
), f"Number of {target_type} {target_product} does not match\
the product of the {parallelism_type} parallelism {np.product(parallelism_vals)}"
the product of the {parallelism_type} parallelism {np.prod(parallelism_vals)}"

return parallelism_vals

Expand Down
4 changes: 2 additions & 2 deletions pedagogical_examples/shardings.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -127,9 +127,9 @@ def simple_timeit(f, tries=5, verbose=True):

# Assert that we have correct inputs of sharding that fit the number of chips
assert (
np.product(dcn_parallelism) * np.product(ici_parallelism) == num_devices
np.prod(dcn_parallelism) * np.prod(ici_parallelism) == num_devices
), f"Number of devices {num_devices} \
does not match the product of the parallelism {np.product(dcn_parallelism) * np.product(ici_parallelism)}"
does not match the product of the parallelism {np.prod(dcn_parallelism) * np.prod(ici_parallelism)}"

multi_slice_env = hasattr(jax.devices()[0], "slice_index")
# Create device mesh
Expand Down

0 comments on commit d590328

Please sign in to comment.