Should non-pod multihost be possible on TPU v2s/v3s? #372

Closed
tensorpro opened this issue Jan 24, 2024 · 3 comments

Comments


tensorpro commented Jan 24, 2024

I'm trying to use MaxText's multihost_runner to run across multiple TPU v2-8/v3-8 VMs.
In doing so, I'm running into the same issue described here when I try to set up multihost: google/jax#16708

I have changed the jax.distributed.init() call to take a host IP and the number of processes as inputs. This works without crashing if I set the environment variable JAX_PLATFORM_NAME=cpu: each process sees the global CPU devices from the other processes and gets its own jax.process_index().
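For reference, a minimal sketch of that explicit initialization; the coordinator address, process count, and per-VM process id below are placeholder values, not my actual setup:

```python
import jax

# Placeholder values: in practice these would be passed in per VM
# (e.g. via flags or environment variables set by the runner script).
jax.distributed.init(
    coordinator_address="10.128.0.2:8476",  # internal IP:port of one VM (assumed)
    num_processes=2,                         # total number of v2-8/v3-8 VMs
    process_id=0,                            # unique per VM: 0, 1, ...
)
```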

But when running with the TPU backend, each VM only sees its 8 local devices, and every VM reports jax.process_index() == 0. This causes the MaxText multihost_runner to fail, since multiple processes try to write checkpoints etc. Even if I disable writing, the processes aren't actually communicating; they are basically just running their own local copies of the program.
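A quick check that shows the symptom (assuming two v3-8 VMs have each called jax.distributed.init() as above):

```python
import jax

# On a working 2-VM multihost setup this would print 16 global devices and a
# distinct process_index per VM; here every VM prints 8 devices and index 0.
print("process_index :", jax.process_index())
print("local devices :", jax.local_device_count())  # 8 on a v3-8
print("global devices:", jax.device_count())        # stays 8 instead of 16
```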

One potential culprit is that I did not use queued resources to start my TPUs. But when I try running

```
gcloud alpha compute tpus queued-resources create my-queued-resource --accelerator-type=v3-8 --runtime-version=tpu-vm-base --node-count=2 --zone=us-central1-a --project=myproject-123456
```

I get

```
(gcloud.alpha.compute.tpus.queued-resources.create) INVALID_ARGUMENT: Cloud TPU was unable to complete the operation. Please try again, or contact support if the problem persists. [EID: 0x9640cfdf11d084b6]
```

Using queued resources to allocate the TPUs one at a time (with --node-id instead of --node-count) works without issue, which leads me to suspect that whatever --node-count does to connect the TPUs being spun up isn't supported on v2s/v3s. A sketch of the one-at-a-time variant is below.
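The per-node variant that does work looks roughly like this (resource and node names are placeholders); each VM then comes up as an independent v3-8 rather than as part of a connected group:

```bash
gcloud alpha compute tpus queued-resources create my-queued-resource-0 \
  --node-id=my-node-0 \
  --accelerator-type=v3-8 \
  --runtime-version=tpu-vm-base \
  --zone=us-central1-a \
  --project=myproject-123456
```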

Is this expected behavior?

@tensorpro tensorpro changed the title Should non-pod multihost be possible on older TPU architectures? Should non-pod multihost be possible on TPU v2s/v3s? Jan 24, 2024
tensorpro (Author) commented Jan 24, 2024

I found this multislice announcement, which mentions support for v4s/v5s but makes no mention of v2s/v3s. This leads me to believe I'd need v4s/v5s to use multihost in this setup.

rwitten (Collaborator) commented Jan 26, 2024

No -- this indeed isn't supported. (It would additionally be VERY hard to make it performant because the DCN bandwidth per ICI domain is so small.)

@rwitten rwitten closed this as completed Jan 26, 2024
tensorpro (Author) commented

Got it, thanks!
