Load correct libtpu for pytorch & jax.
It turns out jax & pytorch are incompatible (require different libtpu
versions). In order to support importing EITHER of them (but not both)
we will swap in the correct libtpu during import (by monkey-patching the
import code for both).

http://b/213335159
djherbis committed May 26, 2022
1 parent 11f01ce commit a54fc96
Showing 2 changed files with 12 additions and 6 deletions.
15 changes: 11 additions & 4 deletions tpu/Dockerfile
@@ -1,8 +1,6 @@
ARG BASE_IMAGE_TAG
ARG LIBTPU_IMAGE_TAG
ARG TENSORFLOW_VERSION

FROM gcr.io/cloud-tpu-v2-images/libtpu:${LIBTPU_IMAGE_TAG} as libtpu
FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl
FROM gcr.io/kaggle-images/python:${BASE_IMAGE_TAG}

@@ -12,20 +10,29 @@ ARG TORCH_VERSION

ENV ISTPUVM=1

COPY --from=libtpu /libtpu.so /lib

COPY --from=tensorflow_whl /tmp/tensorflow_pkg/tensorflow*.whl /tmp/tensorflow_pkg/
RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \
rm -rf /tmp/tensorflow_pkg && \
/tmp/clean-layer.sh

# LIBTPU installed here:
ENV DEFAULT_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/libtpu.so
ENV PYTORCH_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/torch-libtpu.so
ENV JAX_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/jax-libtpu.so

# https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version
RUN pip uninstall -y torch && \
pip install torch==${TORCH_VERSION} && \
# The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0
pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp37-cp37m-linux_x86_64.whl && \
cp $DEFAULT_LIBTPU $PYTORCH_LIBTPU && \
/tmp/clean-layer.sh

# https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm
RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \
cp $DEFAULT_LIBTPU $JAX_LIBTPU && \
/tmp/clean-layer.sh

# Monkey-patch JAX & PYTORCH to load the correct libtpu.so when they are imported:
RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \
sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py
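
The two `sed` commands above can be tried out on a small stand-in file. The sketch below is an illustration only: the `sample` text is a hypothetical three-line module, not the real `torch_xla/__init__.py` or `cloud_tpu_init.py`, and it relies on GNU sed, which accepts `\n` in the replacement text. Group `\1` captures the whole matching line and group `\2` captures whatever precedes the `libtpu.configure_library_path` call (here, just the indentation), so the appended assignment lands on the next line at the same indent, presumably so that it overrides whatever the call configured.

```shell
# Hypothetical stand-in for the module being patched (NOT the real source file).
sample='def setup():
    import libtpu
    libtpu.configure_library_path()'

# Same value the Dockerfile assigns to PYTORCH_LIBTPU.
PYTORCH_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/torch-libtpu.so

# Same substitution as in the Dockerfile, but reading stdin instead of editing in place.
# \1 = the full matching line, \2 = the text before the call (the indentation here).
# The \n in the replacement is a GNU sed feature.
patched=$(printf '%s\n' "$sample" | sed "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|")

printf '%s\n' "$patched"
```

Because `\(.*\)` is greedy, `\2` is only pure indentation when the call is the first statement on its line; a line with code before the call would have that code copied into the inserted line as well.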
3 changes: 1 addition & 2 deletions tpu/config.txt
@@ -1,5 +1,4 @@
# TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable.
-BASE_IMAGE_TAG=v108
-LIBTPU_IMAGE_TAG=libtpu_1.1.0_RC00
+BASE_IMAGE_TAG=v115
TENSORFLOW_VERSION=2.8.0
TORCH_VERSION=1.11.0
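
The Dockerfile comment about the wheel URL needing `1.11` rather than `1.11.0` is handled by the `${TORCH_VERSION%.*}` parameter expansion, which strips the shortest suffix matching `.*` (the trailing patch component). A minimal sketch using the value from this config; the `minor_version` and `wheel_url` variable names are hypothetical and exist only for this illustration:

```shell
# Value taken from tpu/config.txt.
TORCH_VERSION=1.11.0

# %.* removes the shortest trailing match of ".*", i.e. the final ".0" here.
minor_version="${TORCH_VERSION%.*}"

# URL pattern copied from the Dockerfile's torch_xla install step.
wheel_url="https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${minor_version}-cp37-cp37m-linux_x86_64.whl"

echo "$minor_version"   # prints 1.11
echo "$wheel_url"
```

A two-component value such as `1.11` would be shortened to `1`, so this expansion assumes `TORCH_VERSION` always carries a patch component.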
