Skip to content

Commit

Permalink
Avoid use of deprecated device_buffer attriutes of jax.Array
Browse files Browse the repository at this point in the history
These have been deprecated as of JAX v0.4.22

PiperOrigin-RevId: 589254584
  • Loading branch information
Jake VanderPlas authored and romanngg committed Dec 11, 2023
1 parent 429dc13 commit ad3d524
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
5 changes: 4 additions & 1 deletion neural_tangents/_src/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -1347,7 +1347,10 @@ def _is_on_cpu(x: PyTree) -> bool:
def _arr_is_on_cpu(x: jnp.ndarray) -> bool:
# TODO(romann): revisit when https://github.com/google/jax/issues/1431 and
# https://github.com/google/jax/issues/1432 are fixed.
if hasattr(x, 'device_buffer'):
if hasattr(x, 'addressable_shards'):
# device_buffer is deprecated, so try addressable_shards first.
return 'cpu' in str(x.addressable_shards[0].device).lower()
elif hasattr(x, 'device_buffer'):
return 'cpu' in str(x.device_buffer.device()).lower()

if isinstance(x, (np.ndarray, jnp.ndarray)):
Expand Down
8 changes: 5 additions & 3 deletions tests/predict_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,9 +857,11 @@ def testPredictOnCPU(self):
def is_on_cpu(x):
return jax.tree_util.tree_all(
jax.tree_map(
lambda x: 'cpu' in str(x.device_buffer.device()
).lower(),
x))
lambda x: 'cpu'
in str(x.addressable_shards[0].device).lower(),
x,
)
)

self.assertEqual(on_cpu, is_on_cpu(predict_inf))
self.assertEqual(on_cpu, is_on_cpu(predict_none))
Expand Down

0 comments on commit ad3d524

Please sign in to comment.