Skip to content

Commit

Permalink
Fix GPU support for JAX (Kaggle#1179)
Browse files Browse the repository at this point in the history
* Fix GPU support for JAX

Added also a test to prevent regression.

http:https://b/239603020

* Rename var in test
  • Loading branch information
rosbo committed Jul 20, 2022
1 parent 029dea1 commit f82a0ad
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Dockerfile.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \

# Install JAX
{{ if eq .Accelerator "gpu" }}
RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
/tmp/clean-layer.sh
{{ else }}
RUN pip install jax[cpu] && \
Expand Down
3 changes: 2 additions & 1 deletion tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import time

import jax
import jax.numpy as np

from common import gpu_test
Expand All @@ -21,4 +22,4 @@ def test_grad(self):

def test_backend(self):
expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu'

self.assertEqual(expected_backend, jax.default_backend())

0 comments on commit f82a0ad

Please sign in to comment.