Fix Mesh setup for multiprocess CPUs. #723

Merged 1 commit on Jun 24, 2024


RoshaniN
Collaborator

  1. The condition for JAX distributed initialization needed to be reordered so that the GPU and CPU initializations can be triggered.
    The following condition evaluates to true under the base.yaml defaults, unless it is overridden from the command line:
    if (
        raw_keys["enable_checkpointing"]
        and raw_keys["async_checkpointing"]
        and raw_keys["compile_topology_num_slices"] == -1
        and not raw_keys["enable_single_controller"]
    )
  2. get_num_slices(raw_keys) has new logic that computes the number of slices and the number of devices in a slice. This change adapts that logic for CPUs: slices have no meaning for CPUs because they do not support a hierarchical network, so num_slices is set to 1 and the existing ICI parallelism logic in max_utils is used.
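
To illustrate the two points above, here is a minimal sketch and not the actual diff. The "hardware" key, its values ("cpu", "gpu_multiprocess", "tpu"), the returned labels, the helper names, and the slice_index attribute are all assumptions made for illustration. The sample values in DEFAULTS are chosen only so that the checkpointing condition is true, as the description says it is under base.yaml.

```python
from types import SimpleNamespace

# Illustrative values under which the checkpointing condition is true.
# The "hardware" key is an assumption, not taken from this PR.
DEFAULTS = {
    "enable_checkpointing": True,
    "async_checkpointing": True,
    "compile_topology_num_slices": -1,
    "enable_single_controller": False,
    "hardware": "tpu",
}


def choose_initializer(raw_keys):
    """Return a label for the initialization path that would run (hypothetical)."""
    # Hardware-specific branches are checked first. If the checkpointing
    # condition came first, it would win by default and these branches
    # would never be reached.
    if raw_keys["hardware"] == "gpu_multiprocess":
        return "gpu"
    if raw_keys["hardware"] == "cpu":
        return "cpu"
    if (
        raw_keys["enable_checkpointing"]
        and raw_keys["async_checkpointing"]
        and raw_keys["compile_topology_num_slices"] == -1
        and not raw_keys["enable_single_controller"]
    ):
        return "checkpointing"
    return "none"


def get_num_slices_sketch(raw_keys, devices):
    """Hypothetical stand-in for get_num_slices(raw_keys)."""
    # Slices have no meaning on CPUs, so always report a single slice.
    if raw_keys["hardware"] == "cpu":
        return 1
    # Assumed behaviour for other hardware: highest slice index plus one.
    return 1 + max((getattr(d, "slice_index", 0) for d in devices), default=0)


cpu_keys = {**DEFAULTS, "hardware": "cpu"}
print(choose_initializer(cpu_keys))          # CPU path, despite the defaults
print(get_num_slices_sketch(cpu_keys, []))   # single slice on CPU
```

With the hardware checks first, a CPU run takes the CPU path even though the checkpointing defaults are still in effect, and the slice count stays at 1 regardless of how many hosts take part.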

Testing on multiprocess CPUs:
To verify this change, I tested standalone_checkpointer.py end-to-end on 2 nodepools with 2 hosts each.

copybara-service bot merged commit 5db6d73 into main on Jun 24, 2024
14 checks passed
copybara-service bot deleted the fix_jax_coordinator branch on June 24, 2024, 17:16