Skip to content

Commit

Permalink
Remove tests for zeros_like and add_any primitives from neural_tangents.
Browse files Browse the repository at this point in the history
These primitives have been removed from JAX at head, and these tests will fail with a current JAX.

PiperOrigin-RevId: 595974695
  • Loading branch information
hawkinsp authored and romanngg committed Feb 1, 2024
1 parent ad47437 commit f96f176
Showing 1 changed file with 0 additions and 6 deletions.
6 changes: 0 additions & 6 deletions tests/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,6 @@ def _concat_shapes(max_n_args: int = 4, *shapes):

jax.lax.copy_p: lambda s, _: [{}],

ad.zeros_like_p: lambda s, _: [{}],

lax.neg_p: lambda s, _: [{}],

lax.transpose_p: lambda s, _: [
Expand Down Expand Up @@ -372,10 +370,6 @@ def _concat_shapes(max_n_args: int = 4, *shapes):


_BINARY_PRIMITIVES = {
# TODO(romann): what is the purpose of this primitive?
ad.add_jaxvals_p:
lambda s1, s2: ([{}] if s1 == s2 else []),

lax.mul_p:
lambda s1, s2: ([{}] if _is_broadcastable(s1, s2) else []),

Expand Down

0 comments on commit f96f176

Please sign in to comment.