Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix GPU support for JAX #1179

Merged
merged 2 commits into from
Jul 20, 2022
Merged

Fix GPU support for JAX #1179

merged 2 commits into from
Jul 20, 2022

Conversation

rosbo
Copy link
Contributor

@rosbo rosbo commented Jul 19, 2022

Fixes #1178

Added also a test to prevent regression.

http:https://b/239603020

Added also a test to prevent regression.

http:https://b/239603020
@rosbo rosbo requested a review from Philmod July 19, 2022 22:49
@rosbo rosbo merged commit f82a0ad into main Jul 20, 2022
@rosbo rosbo deleted the fix-gpu-JAX branch July 20, 2022 02:40
@darien-schettler
Copy link

darien-schettler commented Aug 2, 2022

The original problem is fixed but support for GPU is still not "fixed". It throws warnings for every single operation.

2022-08-02 14:12:00.008264: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:111] *** WARNING *** You are using ptxas 11.0.221, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.
2022-08-02 14:12:00.039652: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:111] *** WARNING *** You are using ptxas 11.0.221, which is older than 11.1. ptxas before 11.1 is known to miscompile XLA code, leading to incorrect results or invalid-address errors.

You may not need to update to CUDA 11.1; cherry-picking the ptxas binary is often sufficient.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Jax is not compiled for GPU
3 participants