diff --git a/setup.sh b/setup.sh index 1b41f9e53..f653c40b8 100644 --- a/setup.sh +++ b/setup.sh @@ -148,8 +148,8 @@ elif [[ $MODE == "nightly" ]]; then exit 1 fi echo "Installing jax-head, jaxlib-nightly" - # Install jax from GitHub head - pip3 install git+https://github.com/google/jax + # Install jax-nightly + pip3 install --pre -U jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html # Install jaxlib-nightly pip3 install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html