Should non-pod multihost be possible on TPU v2s/v3s? #372
Comments
I found this multislice announcement, which mentions multislice working on v4s/v5s but says nothing about v2s/v3s. That leads me to believe I'd need v4s/v5s to use multihost in this setup.
No, this indeed isn't supported. (It would also be VERY hard to make it performant, because the DCN (data-center network) bandwidth per ICI (inter-chip interconnect) domain is so small.)
Got it, thanks!
I'm trying to use MaxText's `multihost_runner` to run across multiple TPU v2-8/v3-8 VMs. In doing so, I'm running into the same issue described here when I try to set up multihost: google/jax#16708
I have changed the `jax.distributed.initialize()` call to take a host IP and the number of processes as inputs. This works without crashing if I set the environment variable `JAX_PLATFORM_NAME=cpu`: each process sees the global CPU devices from the other processes and has its own `jax.process_index()`.

But when running with the TPU backend, each VM only sees its 8 local devices, and every VM has `jax.process_index() == 0`. This makes the MaxText `multihost_runner` fail, since multiple processes try to write checkpoints, etc. Even if I disable writing, the processes aren't actually communicating; each one just runs its own local copy of the program.

One potential culprit is that I did not use queued resources to start my TPUs. But when I try running
I get
Using queued resources to allocate them one at a time (using `--node-id` instead of `--node-count`) works without issue, which leads me to suspect that if `--node-count` does something to connect the TPUs being spun up, it isn't supported with v2s/v3s. Is this expected behavior?