Everything you want to know about Google Cloud TPU
- 1. Community
- 2. Introduction to TPU
- 3. Introduction to the TRC Program
- 4. Create a TPU VM Instance
- 5. Environment Setup
- 6. Development Environment Setup
- 7. JAX Basics
- 8. TPU Best Practices
- 9. JAX Best Practices
- 9.1. Import convention
- 9.2. Manage random keys in JAX
- 9.3. Serialize model parameters
- 9.4. Conversion between NumPy arrays and JAX arrays
- 9.5. Conversion between PyTorch tensors and JAX arrays
- 9.6. Type annotation
- 9.7. Check if an array is either a NumPy array or a JAX array
- 9.8. Get the shapes of all parameters in a nested dictionary
- 9.9. The correct way to generate random numbers on CPU
- 9.10. Use optimizers from Optax
- 9.11. Use the cross-entropy loss implementation from Optax
- 10. Working With Pods
- 11. Common Gotchas
- 11.1. External IP of TPU machine changes occasionally
- 11.2. One TPU device can only be used by one process at a time
- 11.3. TCMalloc breaks several programs
- 11.4. There is no TPU counterpart of nvidia-smi
- 11.5. libtpu.so already in use by another process
- 11.6. JAX does not support the multiprocessing fork strategy
This project is inspired by Cloud Run FAQ, a community-maintained knowledge base of another Google Cloud product.
As of 23 Feb 2022, there is no official chat group for Cloud TPUs. You can join the @cloudtpu chat group on Telegram or TPU Podcast on Discord, which are connected with each other. There is also an official TRC Cloud TPU v4 user group in Google Chat.
TL;DR: TPU is to GPU as GPU is to CPU.
TPU is specialized hardware designed for machine learning. There is a performance comparison in Hugging Face Transformers:
Moreover, the TRC program provides researchers with free TPUs. If you have ever been short of computing resources for training models, this is, as far as I know, the best solution. For more details on the TRC program, see below.
If you want to use PyTorch, TPU may not be suitable for you. TPU is poorly supported by PyTorch. In one of my experiments, one batch took about 14 seconds to run on CPU, but over 4 hours to run on TPU. Twitter user @mauricetpunkt also thinks PyTorch's performance on TPUs is bad.
Another problem is that although a single TPU v3-8 device has 8 cores (16 GiB memory for each core), you need to write extra code to make use of all the 8 cores (see below). Otherwise, only the first core is used.
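You can check how many cores JAX actually sees with jax.device_count(). This snippet runs anywhere; on a TPU v3-8 it should report 8, while on a plain CPU machine it reports 1:

```python
import jax

# Number of accelerator cores visible to JAX
# (8 on a single TPU v3-8, 1 when running on CPU)
print(jax.device_count())
print(jax.devices())
```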
Unfortunately, in most cases you cannot touch a TPU physically. TPU is only available through cloud services.
You can create TPU instances on Google Cloud Platform. For more information on setting up TPU, see below.
You can also use Google Colab, but I don't recommend this way. Moreover, if you get free access to TPU from the TRC program, you will be using Google Cloud Platform, not Google Colab.
After creating a TPU v3-8 instance on Google Cloud Platform, you will get an Ubuntu 20.04 cloud server with sudo access, 96 cores, 335 GiB memory and one TPU device with 8 cores (128 GiB TPU memory in total).
This is similar to the way we use GPU. In most cases, when you use a GPU, you use a Linux server that connects with a GPU. When you use a TPU, you use a Linux server that connects with a TPU.
Besides its homepage, Shawn has written a wonderful article about the TRC program in google/jax#2108. Anyone who is interested in TPU should read it immediately.
For the first three months, it is completely free because all fees are covered by the Google Cloud free trial. After that, I pay only about HK$13.95 (approx. US$1.78) per month for outbound Internet traffic.
You need to loosen the restrictions of the firewall so that Mosh and other programs will not be blocked.
Open the Firewall management page in VPC network.
Click the button to create a new firewall rule.
Set name to 'allow-all', targets to 'All instances in the network', source filter to 0.0.0.0/0, protocols and ports to 'Allow all', and then click 'Create'.
Stricter firewall rules may be required if you are working with confidential datasets or in other situations that demand a higher level of security.
Open Google Cloud Platform and navigate to the TPU management page.
Click the console button on the top-right corner to activate Cloud Shell.
In Cloud Shell, type the following command to create a Cloud TPU VM v3-8 with TPU software version v2-nightly20210914:
gcloud alpha compute tpus tpu-vm create node-1 --project tpu-develop --zone europe-west4-a --accelerator-type v3-8 --version v2-nightly20210914
If the command fails because there are no more TPUs to allocate, you can re-run it.
Besides, it is more convenient to have the gcloud command installed on your local machine, so that you do not need to open a Cloud Shell to run it.
To create a TPU Pod, run the following command:
gcloud alpha compute tpus tpu-vm create node-3 --project tpu-advanced-research --zone us-central2-b --accelerator-type v4-16 --version v2-alpha-tpuv4
To SSH into the TPU VM:
gcloud alpha compute tpus tpu-vm ssh node-1 --zone europe-west4-a
To SSH into one of the TPU Pods:
gcloud alpha compute tpus tpu-vm ssh node-3 --zone us-central2-b --worker 0
Save the following script to setup.sh and run it.
gcloud alpha compute tpus tpu-vm ssh node-2 --zone us-central2-b --worker all --command '
# Confirm that the script is running on the host
uname -a
# Install common packages
export DEBIAN_FRONTEND=noninteractive
sudo apt-get update -y -qq
sudo apt-get upgrade -y -qq
sudo apt-get install -y -qq golang neofetch zsh mosh byobu aria2
# Install Python 3.10
sudo apt-get install -y -qq software-properties-common
sudo add-apt-repository -y ppa:deadsnakes/ppa
sudo apt-get install -y -qq python3.10-full python3.10-dev
# Install Oh My Zsh
sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)" "" --unattended
sudo chsh $USER -s /usr/bin/zsh
# Change timezone
# timedatectl list-timezones # list timezones
sudo timedatectl set-timezone Asia/Hong_Kong # change to your timezone
# Create venv
python3.10 -m venv $HOME/.venv310
. $HOME/.venv310/bin/activate
# Install JAX with TPU support
pip install -U pip
pip install -U wheel
pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
'
The script will create a venv in ~/.venv310, so you will need to run . ~/.venv310/bin/activate when you start a new shell, or call the Python interpreter as ~/.venv310/bin/python.
Clone this repository. In the root directory of this repository, run:
pip install -r requirements.txt
If you connect to the server directly with SSH, you risk losing the connection; if that happens, the training script running in the foreground will be terminated.
Mosh and Byobu are two programs that solve this problem. Byobu ensures that the script continues to run on the server even if the connection is lost, while Mosh keeps the connection from being lost in the first place.
Install Mosh on your local device, then log in to the server with:
mosh tpu1 -- byobu
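Here tpu1 is assumed to be a host alias for the TPU VM configured in your ~/.ssh/config; the IP and username below are placeholders:

```
Host tpu1
    HostName 34.90.0.1  # replace with the external IP of your TPU VM
    User your-username
```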
You can learn more about Byobu from the video Learn Byobu while listening to Mozart.
Open VSCode. Open the 'Extensions' panel on the left. Search for 'Remote - SSH' and install.
Press F1 to open the command palette. Type 'ssh', then select 'Remote-SSH: Connect to Host...'. Input the server name you would like to connect to and press Enter.
Wait for VSCode to be set up on the server. After it is finished, you can develop on the server using VSCode.
Run this command:
~/.venv310/bin/python -c 'import jax; print(jax.devices())' # should print TpuDevice
For TPU Pods, run the following command locally:
gcloud alpha compute tpus tpu-vm ssh node-2 --zone us-central2-b --worker all --command '~/.venv310/bin/python -c "import jax; jax.process_index() == 0 and print(jax.devices())"'
JAX is the next generation of deep learning libraries, with excellent support for TPU. To get started quickly with JAX, you can read the official tutorial.
There are four key points here.
1. params and opt_state should be replicated across the devices:
replicated_params = jax.device_put_replicated(params, jax.devices())
2. data and labels should be split to the devices:
n_devices = jax.device_count()
batch_size, *data_shapes = data.shape
assert batch_size % n_devices == 0, 'The data cannot be split evenly to the devices'
data = data.reshape(n_devices, batch_size // n_devices, *data_shapes)
3. Decorate the target function with jax.pmap:
@partial(jax.pmap, axis_name='num_devices')
4. In the loss function, use jax.lax.pmean to calculate the mean value across devices:
grads = jax.lax.pmean(grads, axis_name='num_devices') # calculate mean across devices
See 01-basics/test_pmap.py for a complete working example.
See also https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html#example.
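Putting the four points together, here is a minimal sketch (the function and array names are illustrative; it also runs on a CPU, where there is only one device):

```python
from functools import partial

import jax
import jax.numpy as jnp

n_devices = jax.device_count()

@partial(jax.pmap, axis_name='num_devices')
def device_mean_of_sums(x):
    # Each device computes a partial sum over its shard, then the
    # results are averaged across all devices with jax.lax.pmean
    return jax.lax.pmean(jnp.sum(x), axis_name='num_devices')

# Split the leading batch dimension across the devices
data = jnp.arange(4 * n_devices, dtype=jnp.float32).reshape(n_devices, 4)
out = device_mean_of_sums(data)  # one (identical) value per device
```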
key, subkey = (lambda keys: (keys[0], keys[1:]))(rand.split(key, num=9))
Note that you cannot use the regular way to split the keys:
key, *subkey = rand.split(key, num=9)
Because in this way, subkey is a list rather than an array.
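With the lambda trick, subkey is a single stacked array of keys, which you can index or split further:

```python
import jax.random as rand

key = rand.PRNGKey(42)
key, subkey = (lambda keys: (keys[0], keys[1:]))(rand.split(key, num=9))
print(subkey.shape)  # (8, 2): an array of 8 keys, not a Python list
```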
opt_state should be replicated as well.
Use optax.set_to_zero together with optax.multi_transform.
params = {
'a': { 'x1': ..., 'x2': ... },
'b': { 'x1': ..., 'x2': ... },
}
param_labels = {
'a': { 'x1': 'freeze', 'x2': 'train' },
'b': 'train',
}
optimizer_scheme = {
'train': optax.adam(...),
'freeze': optax.set_to_zero(),
}
optimizer = optax.multi_transform(optimizer_scheme, param_labels)
See Freeze Parameters Example for details.
Google Colab only provides TPU v2-8 devices, while on Google Cloud Platform you can select TPU v2-8 and TPU v3-8.
Besides, on Google Colab you can only use TPU through the Jupyter Notebook interface. Even if you log in to the Colab server via SSH, it is a Docker container in which you don't have root access. On Google Cloud Platform, however, you have full access to the TPU VM.
If you really want to use TPU on Google Colab, you need to run the following script to set up TPU:
import jax
from jax.tools.colab_tpu import setup_tpu
setup_tpu()
devices = jax.devices()
print(devices) # should print TpuDevice
When you are creating a TPU instance, you need to choose between TPU VM and TPU node. Always prefer TPU VM because it is the new architecture in which TPU devices are connected to the host VM directly. This will make it easier to set up the TPU device.
After setting up Remote-SSH, you can work with Jupyter notebook files in VSCode.
Alternatively, you can run a regular Jupyter Notebook server on the TPU VM, forward the port to your PC and connect to it. However, you should prefer VSCode because it is more powerful, offers better integration with other tools and is easier to set up.
TPU VM instances in the same zone are connected with internal IPs, so you can create a shared file system using NFS.
Example: Tensorboard
Although every TPU VM is allocated a public IP, in most cases you should not expose a server to the Internet because doing so is insecure.
Port forwarding via SSH
ssh -C -N -L 127.0.0.1:6006:127.0.0.1:6006 tpu1
You may see two different kinds of import conventions. One is to import jax.numpy as np and import the original numpy as onp. The other is to import jax.numpy as jnp and leave the original numpy as np.
On 16 Jan 2019, Colin Raffel wrote in a blog article that the convention at that time was to import original numpy as onp.
On 5 Nov 2020, Niru Maheswaranathan said in a tweet that he thinks the convention at that time was to import jax.numpy as jnp and to leave the original numpy as np.
Judging from the official JAX documentation, the jnp convention seems to have become the dominant one.
The regular way is this:
key, *subkey = rand.split(key, num=4)
print(subkey[0])
print(subkey[1])
print(subkey[2])
Normally, the model parameters are represented by a nested dictionary like this:
{
"embedding": DeviceArray,
"ff1": {
"kernel": DeviceArray,
"bias": DeviceArray
},
"ff2": {
"kernel": DeviceArray,
"bias": DeviceArray
}
}
You can use flax.serialization.msgpack_serialize to serialize the parameters into bytes, and flax.serialization.msgpack_restore to convert them back.
Use np.asarray and onp.asarray.
import jax.numpy as np
import numpy as onp
a = np.array([1, 2, 3]) # JAX array
b = onp.asarray(a) # converted to NumPy array
c = onp.array([1, 2, 3]) # NumPy array
d = np.asarray(c) # converted to JAX array
Convert a PyTorch tensor to a JAX array:
import jax.numpy as np
import torch
a = torch.rand(2, 2) # PyTorch tensor
b = np.asarray(a.numpy()) # JAX array
Convert a JAX array to a PyTorch tensor:
import jax.numpy as np
import numpy as onp
import torch
a = np.zeros((2, 2)) # JAX array
b = torch.from_numpy(onp.asarray(a)) # PyTorch tensor
This will result in a warning:
UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:178.)
If you need writable tensors, you can use onp.array instead of onp.asarray to make a copy of the original array.
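For example, the copy made by onp.array is writable, whereas the array returned by onp.asarray may share the read-only buffer of the JAX array:

```python
import jax.numpy as np
import numpy as onp

a = np.zeros((2, 2))      # JAX array
b = onp.array(a)          # writable copy
b[0, 0] = 1.0             # safe to modify
print(b.flags.writeable)  # True
```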
isinstance(a, (np.ndarray, onp.ndarray))
jax.tree_map(lambda x: x.shape, params)
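For example, with a small parameter dictionary (jax.tree_util.tree_map is the same function that jax.tree_map refers to):

```python
import jax
import jax.numpy as jnp

params = {'ff1': {'kernel': jnp.ones((3, 4)), 'bias': jnp.zeros((4,))}}

# Map each leaf array to its shape, preserving the nested structure
shapes = jax.tree_util.tree_map(lambda x: x.shape, params)
print(shapes)
```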
Use the jax.default_device() context manager:
import jax
import jax.random as rand
device_cpu = jax.devices('cpu')[0]
with jax.default_device(device_cpu):
key = rand.PRNGKey(42)
a = rand.poisson(key, 3, shape=(1000,))
print(a.device()) # TFRT_CPU_0
See jax-ml/jax#9691 (comment).
optax.softmax_cross_entropy_with_integer_labels
See also: §8.4.
#!/bin/bash

# Run the given command on every host listed in external-ips.txt...
while read p; do
  ssh "$p" "cd $PWD; rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs; . ~/.venv310/bin/activate; $@" &
done < external-ips.txt

# ...and on the current host as well
rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs; . ~/.venv310/bin/activate; "$@"

# Wait for all remote commands to finish
wait
As of 17 Jul 2022, the external IP address may change if there is a maintenance event.
Therefore, we should use the gcloud command instead of connecting to the machine directly with SSH. However, if we want to use VSCode, SSH is the only choice.
The system will also be rebooted.
Unlike GPUs, a TPU will give an error if a second process tries to use it at the same time:
I0000 00:00:1648534265.148743 625905 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
Even if a TPU device has 8 cores and one process only utilizes the first core, the other processes will not be able to utilize the rest of the cores.
TCMalloc is Google's customized memory allocation library. On TPU VM, LD_PRELOAD is set to use TCMalloc by default:
$ echo $LD_PRELOAD
/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
However, using TCMalloc in this manner may break several programs like gsutil:
$ gsutil --help
/snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/python3: /snap/google-cloud-sdk/232/platform/bundledpythonunix/bin/../../../lib/x86_64-linux-gnu/libm.so.6: version `GLIBC_2.29' not found (required by /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4)
The homepage of TCMalloc also indicates that LD_PRELOAD is tricky and this mode of usage is not recommended.
If you encounter problems related to TCMalloc, you can disable it in the current shell using the command:
unset LD_PRELOAD
See https://twitter.com/ayaka14732/status/1565016471323156481.
See google/jax#9756.
if ! pgrep -a -u $USER python ; then
killall -q -w -s SIGKILL ~/.venv310/bin/python
fi
rm -rf /tmp/libtpu_lockfile /tmp/tpu_logs
See also jax-ml/jax#9220 (comment).
Use the spawn or forkserver strategy instead.
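For example, with Python's standard multiprocessing module:

```python
import multiprocessing as mp

def main():
    # 'spawn' starts fresh interpreter processes instead of forking,
    # which avoids the fork-related issues with JAX
    ctx = mp.get_context('spawn')
    with ctx.Pool(2) as pool:
        return pool.map(abs, [-1, -2])

if __name__ == '__main__':
    print(main())  # [1, 2]
```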