[LSC] change uses of jax.random.KeyArray and jax.random.PRNGKeyArray to jax.Array

This change replaces uses of jax.random.KeyArray and jax.random.PRNGKeyArray in the context of type annotations with jax.Array, which is the correct annotation for JAX PRNG keys moving forward.

The purpose of this change is to remove references to KeyArray and PRNGKeyArray, which are deprecated (google/jax#17594) and will soon be removed from JAX. The design and thought process behind this is described in https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html.

Note that KeyArray and PRNGKeyArray have always been aliased to Any, so the new type annotation is far more specific than the old one.
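
A hedged illustration of the annotation pattern this change adopts (the function, names, and shapes below are hypothetical and not part of the commit): both legacy uint32 keys from jax.random.PRNGKey and new-style typed keys from jax.random.key are instances of jax.Array, so the new annotation covers both.

import jax

# Hypothetical example: a PRNG key parameter annotated as jax.Array,
# replacing the deprecated jax.random.KeyArray / jax.random.PRNGKeyArray aliases.
def init_params(rng: jax.Array, dim: int) -> jax.Array:
  w_key, _ = jax.random.split(rng)
  return jax.random.normal(w_key, (dim,))

legacy_key = jax.random.PRNGKey(0)  # uint32 key; an instance of jax.Array
typed_key = jax.random.key(0)       # new-style typed key; also a jax.Array
params = init_params(typed_key, 4)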

PiperOrigin-RevId: 574254218
Jake VanderPlas authored and romanngg committed Nov 21, 2023
1 parent 136338d commit d816c8f
Showing 4 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions neural_tangents/_src/monte_carlo.py
@@ -33,6 +33,7 @@

from .batching import batch
from .empirical import empirical_kernel_fn, NtkImplementation, DEFAULT_NTK_IMPLEMENTATION, _DEFAULT_NTK_FWD, _DEFAULT_NTK_S_RULES, _DEFAULT_NTK_J_RULES
+ import jax
from jax import random
import jax.numpy as jnp
from jax.tree_util import tree_map
@@ -54,7 +55,7 @@ def _sample_once_kernel_fn(
def kernel_fn_sample_once(
x1: NTTree[jnp.ndarray],
x2: Optional[NTTree[jnp.ndarray]],
- key: random.KeyArray,
+ key: jax.Array,
get: Get,
**apply_fn_kwargs):
init_key, dropout_key = random.split(key, 2)
@@ -66,7 +67,7 @@ def kernel_fn_sample_once(

def _sample_many_kernel_fn(
kernel_fn_sample_once,
- key: random.KeyArray,
+ key: jax.Array,
n_samples: set[int],
get_generator: bool):
def normalize(sample: PyTree, n: int) -> PyTree:
@@ -117,7 +118,7 @@ def get_sampled_kernel(
def monte_carlo_kernel_fn(
init_fn: InitFn,
apply_fn: ApplyFn,
- key: random.KeyArray,
+ key: jax.Array,
n_samples: Union[int, Iterable[int]],
batch_size: int = 0,
device_count: int = -1,
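
As a hedged usage sketch (not part of the diff; the network, widths, and shapes are made up), the updated monte_carlo_kernel_fn signature can be driven with an ordinary jax.Array key:

import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(1))
key = jax.random.PRNGKey(0)  # a jax.Array, matching the new `key: jax.Array` annotation
kernel_fn = nt.monte_carlo_kernel_fn(init_fn, apply_fn, key, n_samples=16)

x1 = jax.random.normal(jax.random.PRNGKey(1), (3, 8))
x2 = jax.random.normal(jax.random.PRNGKey(2), (4, 8))
k_ntk = kernel_fn(x1, x2, 'ntk')  # Monte Carlo estimate of the empirical NTK
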
3 changes: 2 additions & 1 deletion neural_tangents/_src/stax/combinators.py
@@ -19,6 +19,7 @@
import warnings

import frozendict
+ import jax
from jax import random, lax
import jax.example_libraries.stax as ostax
from .requirements import Diagonal, get_req, layer, requires
@@ -174,7 +175,7 @@ def parallel(*layers: Layer) -> InternalLayer:
init_fns, apply_fns, kernel_fns = zip(*layers)
init_fn_stax, apply_fn_stax = ostax.parallel(*zip(init_fns, apply_fns))

- def init_fn(rng: random.KeyArray, input_shape: Shapes):
+ def init_fn(rng: jax.Array, input_shape: Shapes):
return type(input_shape)(init_fn_stax(rng, input_shape))

def apply_fn(params, inputs, **kwargs):
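
A hedged sketch of the parallel combinator in use, with the inner init_fn now receiving a jax.Array key (layer widths and input shape are illustrative only):

import jax
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.FanOut(2),
    stax.parallel(stax.Dense(32), stax.Dense(32)),
    stax.FanInSum())

key = jax.random.PRNGKey(0)  # jax.Array forwarded to the per-branch init_fns
output_shape, params = init_fn(key, (8, 4))
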
4 changes: 2 additions & 2 deletions neural_tangents/_src/utils/typing.py
@@ -16,7 +16,7 @@

from typing import Any, Generator, Optional, Sequence, TYPE_CHECKING, TypeVar, Union, Protocol

- from jax import random
+ import jax
import jax.numpy as jnp

from .kernel import Kernel
@@ -77,7 +77,7 @@ class InitFn(Protocol):

def __call__(
self,
- rng: random.KeyArray,
+ rng: jax.Array,
input_shape: Shapes,
**kwargs
) -> tuple[Shapes, PyTree]:
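
For reference, a minimal function conforming to the updated InitFn protocol (a hypothetical dense layer; the output width is arbitrary):

import jax

def dense_init_fn(rng: jax.Array, input_shape: tuple[int, ...]):
  out_dim = 16
  w_key, b_key = jax.random.split(rng)
  params = (jax.random.normal(w_key, (input_shape[-1], out_dim)),
            jax.random.normal(b_key, (out_dim,)))
  return input_shape[:-1] + (out_dim,), params
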
2 changes: 1 addition & 1 deletion tests/test_utils.py
@@ -428,7 +428,7 @@ def mask(
x: jnp.ndarray,
mask_constant: Optional[float],
mask_axis: Sequence[int],
- key: jax.random.KeyArray,
+ key: jax.Array,
p: float
) -> jnp.ndarray:
if mask_constant is not None:
