No-op refactoring: use the modern jnp/np aliases convention for `jax.numpy`/`numpy`. Standardize hanging indentation in functions with many arguments. Fix minor typos / linter issues.

PiperOrigin-RevId: 561063090
romanngg committed Aug 29, 2023
1 parent ae888ca commit 3c3dc9f
Showing 50 changed files with 1,878 additions and 1,710 deletions.
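For readers skimming the diff, here is a minimal sketch of the two conventions this commit standardizes; the function below is hypothetical, not taken from the repository:

```python
# Aliases: plain `np` is reserved for classic NumPy; `jax.numpy` is `jnp`
# (previously many files imported `jax.numpy` as `np`).
import numpy as np
import jax.numpy as jnp


def hypothetical_scale_and_shift(
    x: jnp.ndarray,
    scale: float = 1.0,
    shift: float = 0.0
) -> jnp.ndarray:
  """Illustrates the 4-space hanging indentation for long signatures."""
  return jnp.asarray(x) * scale + shift


print(hypothetical_scale_and_shift(np.arange(3), scale=2.0))
```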
34 changes: 17 additions & 17 deletions README.md
@@ -111,7 +111,7 @@ key = random.PRNGKey(1)
x = random.normal(key, (10, 100))
_, params = init_fn(key, input_shape=x.shape)

-y = apply_fn(params, x) # (10, 1) np.ndarray outputs of the neural network
+y = apply_fn(params, x) # (10, 1) jnp.ndarray outputs of the neural network
```

Neural Tangents is designed to serve as a drop-in replacement for `stax`, extending the `(init_fn, apply_fn)` tuple to a triple `(init_fn, apply_fn, kernel_fn)`, where `kernel_fn` is the kernel function of the infinite network (GP) of the given architecture. Below is an example of computing the covariances of the GP between two batches of inputs `x1` and `x2`.
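The example referred to here sits in a collapsed part of this diff; for context, a minimal sketch of what it looks like (layer widths and input shapes are assumptions):

```python
from jax import random
from neural_tangents import stax

# Infinite-width limit of a fully-connected network; widths are illustrative.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (10, 100))
x2 = random.normal(key2, (20, 100))

kernels = kernel_fn(x1, x2, ('nngp', 'ntk'))  # GP covariances between x1 and x2
```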
@@ -137,8 +137,8 @@ Note that `kernel_fn` can compute _two_ covariance matrices corresponding to the

```python
# Get kernel of a single type
-nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) np.ndarray
-ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) np.ndarray
+nngp = kernel_fn(x1, x2, 'nngp') # (10, 20) jnp.ndarray
+ntk = kernel_fn(x1, x2, 'ntk') # (10, 20) jnp.ndarray

# Get kernels as a namedtuple
both = kernel_fn(x1, x2, ('nngp', 'ntk'))
@@ -169,10 +169,10 @@ predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train,
y_train)

y_test_nngp = predict_fn(x_test=x_test, get='nngp')
-# (20, 1) np.ndarray test predictions of an infinite Bayesian network
+# (20, 1) jnp.ndarray test predictions of an infinite Bayesian network

y_test_ntk = predict_fn(x_test=x_test, get='ntk')
-# (20, 1) np.ndarray test predictions of an infinite continuous
+# (20, 1) jnp.ndarray test predictions of an infinite continuous
# gradient descent trained network at convergence (t = inf)

# Get predictions as a namedtuple
@@ -288,22 +288,22 @@ post-activations which are substantially more nonlinear.
#### Example:

```python
-import jax.numpy as np
+import jax.numpy as jnp
import neural_tangents as nt

def apply_fn(params, x):
W, b = params
-return np.dot(x, W) + b
+return jnp.dot(x, W) + b

-W_0 = np.array([[1., 0.], [0., 1.]])
-b_0 = np.zeros((2,))
+W_0 = jnp.array([[1., 0.], [0., 1.]])
+b_0 = jnp.zeros((2,))

apply_fn_lin = nt.linearize(apply_fn, (W_0, b_0))
-W = np.array([[1.5, 0.2], [0.1, 0.9]])
+W = jnp.array([[1.5, 0.2], [0.1, 0.9]])
b = b_0 + 0.2

-x = np.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
-logits = apply_fn_lin((W, b), x) # (3, 2) np.ndarray
+x = jnp.array([[0.3, 0.2], [0.4, 0.5], [1.2, 0.2]])
+logits = apply_fn_lin((W, b), x) # (3, 2) jnp.ndarray
```

### Function space:
@@ -314,17 +314,17 @@ Outputs of a linearized model [evolve identically to those of an infinite one](h

```python
import jax.random as random
-import jax.numpy as np
+import jax.numpy as jnp
import neural_tangents as nt


def apply_fn(params, x):
W, b = params
-return np.dot(x, W) + b
+return jnp.dot(x, W) + b


-W_0 = np.array([[1., 0.], [0., 1.]])
-b_0 = np.zeros((2,))
+W_0 = jnp.array([[1., 0.], [0., 1.]])
+b_0 = jnp.zeros((2,))
params = (W_0, b_0)

key1, key2 = random.split(random.PRNGKey(1), 2)
@@ -341,7 +341,7 @@ t = 5.
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
y_train_t, y_test_t = mse_predictor(t, y_train_0, y_test_0, ntk_test_train)
-# (3, 2) and (4, 2) np.ndarray train and test outputs after `t` units of time
+# (3, 2) and (4, 2) jnp.ndarray train and test outputs after `t` units of time
# training with continuous gradient descent
```

42 changes: 28 additions & 14 deletions examples/datasets.py
@@ -37,13 +37,15 @@ def _one_hot(x, k, dtype=np.float32):
return np.array(x[:, None] == np.arange(k), dtype)


-def get_dataset(name,
-                n_train=None,
-                n_test=None,
-                permute_train=False,
-                do_flatten_and_normalize=True,
-                data_dir=None,
-                input_key='image'):
+def get_dataset(
+    name,
+    n_train=None,
+    n_test=None,
+    permute_train=False,
+    do_flatten_and_normalize=True,
+    data_dir=None,
+    input_key='image'
+):
"""Download, parse and process a dataset to unit scale and one-hot labels."""
# Need this following http://cl/378185881 to prevent GPU test breakages.
tf.config.set_visible_devices([], 'GPU')
@@ -112,10 +114,17 @@ def embed_glove(xs, glove_path, max_sentence_length=1000, mask_constant=1000.):
Adapted from https://keras.io/examples/pretrained_word_embeddings/.
Args:
-    xs: list of string numpy arrays to embed.
-    glove_path: path to the GloVe embedding file.
-    max_sentence_length: pad/truncate embeddings to this length.
-    mask_constant: mask padding with this constant.
+    xs:
+      list of string numpy arrays to embed.
+    glove_path:
+      path to the GloVe embedding file.
+    max_sentence_length:
+      pad/truncate embeddings to this length.
+    mask_constant:
+      mask padding with this constant.
Returns:
xs with words replaced by word embeddings, padded/truncated to a fixed
@@ -157,9 +166,14 @@ def _get_glove_embedding_layer(tokenizer, glove_path, max_sentence_length):
Adapted from https://keras.io/examples/pretrained_word_embeddings/.
Args:
-    tokenizer: the `keras.preprocessing.text.Tokenizer` used to tokenize inputs.
-    glove_path: path to the GloVe embedding file.
-    max_sentence_length: pad/truncate embeddings to this length.
+    tokenizer:
+      the `keras.preprocessing.text.Tokenizer` used to tokenize inputs.
+    glove_path:
+      path to the GloVe embedding file.
+    max_sentence_length:
+      pad/truncate embeddings to this length.
Returns:
Keras embedding layer for a given GloVe embeddings.
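For context, a usage sketch of the reformatted `get_dataset` signature; the dataset name, sizes, and the returned 4-tuple are assumptions based on how the other examples call it:

```python
from examples import datasets

x_train, y_train, x_test, y_test = datasets.get_dataset(
    name='mnist',
    n_train=1000,
    n_test=100,
    do_flatten_and_normalize=True
)
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
```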
8 changes: 4 additions & 4 deletions examples/elementwise.py
@@ -19,7 +19,7 @@
"""

from absl import app
-from jax import numpy as np
+from jax import numpy as jnp
from jax import random
from neural_tangents import stax

@@ -28,8 +28,8 @@ def main(unused_argv):
# Consider the normalized exponential kernel from
# https://arxiv.org/abs/2003.02237 (page 6).
def nngp_fn(cov12, var1, var2):
-prod = np.sqrt(var1 * var2)
-return prod * np.exp(cov12 / prod - 1)
+prod = jnp.sqrt(var1 * var2)
+return prod * jnp.exp(cov12 / prod - 1)

# This kernel has no known corresponding elementwise nonlinearity.
# `stax.Elementwise` derives the NTK kernel automatically under the hood using
@@ -50,7 +50,7 @@ def nngp_fn(cov12, var1, var2):
k_manual = kernel_fn_manual(x1, x2, 'ntk')

# The two kernels match!
-assert np.max(np.abs(k_manual - k_auto)) < 1e-6
+assert jnp.max(jnp.abs(k_manual - k_auto)) < 1e-6
print('NTK derived via autodiff matches the hand-derived NTK!')


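The collapsed part of this example builds the same kernel by hand; below is a sketch of the automatic route via `stax.Elementwise(nngp_fn=...)`, with layer widths and input shapes as assumptions:

```python
from jax import numpy as jnp
from jax import random
from neural_tangents import stax


def nngp_fn(cov12, var1, var2):
  prod = jnp.sqrt(var1 * var2)
  return prod * jnp.exp(cov12 / prod - 1)


# `stax.Elementwise` derives the NTK from `nngp_fn` automatically via autodiff.
_, _, kernel_fn = stax.serial(
    stax.Dense(1),
    stax.Elementwise(nngp_fn=nngp_fn),
    stax.Dense(1)
)

key1, key2 = random.split(random.PRNGKey(1))
x1 = random.normal(key1, (5, 3))
x2 = random.normal(key2, (6, 3))
print(kernel_fn(x1, x2, 'ntk').shape)  # (5, 6)
```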
7 changes: 4 additions & 3 deletions examples/elementwise_numerical.py
@@ -20,7 +20,7 @@
"""

from absl import app
-from jax import numpy as np
+from jax import numpy as jnp
from jax import random
import jax.nn
from neural_tangents import stax
@@ -50,8 +50,9 @@ def main(unused_argv):
kernel_numerical = kernel_fn_numerical(x1, x2)

# The two kernels are close!
-  assert np.max(np.abs(kernel_closed_form.nngp - kernel_numerical.nngp)) < 1e-3
-  assert np.max(np.abs(kernel_closed_form.ntk - kernel_numerical.ntk)) < 1e-3
+  assert jnp.max(jnp.abs(kernel_closed_form.nngp -
+                         kernel_numerical.nngp)) < 1e-3
+  assert jnp.max(jnp.abs(kernel_closed_form.ntk - kernel_numerical.ntk)) < 1e-3
print('Gaussian quadrature approximation of the kernel is accurate!')


6 changes: 3 additions & 3 deletions examples/empirical_ntk.py
@@ -23,7 +23,7 @@

from absl import app
import jax
-from jax import numpy as np
+from jax import numpy as jnp
from jax import random
import neural_tangents as nt
from neural_tangents import stax
@@ -57,7 +57,7 @@ def main(unused_argv):
**kwargs,
implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)

-# (6, 3, 10, 10) full `np.ndarray` test-train NTK
+# (6, 3, 10, 10) full `jnp.ndarray` test-train NTK
ntk_jc = jacobian_contraction(x2, x1, params)

# NTK-vector products-based implementation.
@@ -84,7 +84,7 @@
# Check that implementations match
for ntk1 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
for ntk2 in [ntk_jc, ntk_vp, ntk_sd, ntk_auto]:
-diff = np.max(np.abs(ntk1 - ntk2))
+diff = jnp.max(jnp.abs(ntk1 - ntk2))
print(f'NTK implementation diff {diff}.')
assert diff < (1e-4 if jax.default_backend() != 'tpu' else 0.1), diff

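The four results compared above (`ntk_jc`, `ntk_vp`, `ntk_sd`, `ntk_auto`) come from the members of `nt.NtkImplementation`; a condensed sketch of the comparison, with the network and shapes as assumptions:

```python
import jax.numpy as jnp
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, f, _ = stax.serial(stax.Dense(32), stax.Relu(), stax.Dense(10))
key = random.PRNGKey(1)
x1 = random.normal(key, (3, 8))  # "train" inputs
x2 = random.normal(key, (6, 8))  # "test" inputs
_, params = init_fn(key, x1.shape)

ntks = []
for implementation in (nt.NtkImplementation.JACOBIAN_CONTRACTION,
                       nt.NtkImplementation.NTK_VECTOR_PRODUCTS,
                       nt.NtkImplementation.STRUCTURED_DERIVATIVES,
                       nt.NtkImplementation.AUTO):
  ntk_fn = nt.empirical_ntk_fn(
      f,
      trace_axes=(),
      vmap_axes=0,
      implementation=implementation
  )
  ntks.append(ntk_fn(x2, x1, params))  # (6, 3, 10, 10) full test-train NTK

# All implementations should agree up to numerical error.
assert all(jnp.max(jnp.abs(ntks[0] - k)) < 1e-4 for k in ntks)
```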
2 changes: 1 addition & 1 deletion examples/experimental/empirical_ntk_tf.py
@@ -42,7 +42,7 @@ def _get_ntks(f, x1, x2, params, vmap_axes):
jacobian_contraction = nt.experimental.empirical_ntk_fn_tf(
**kwargs,
implementation=nt.NtkImplementation.JACOBIAN_CONTRACTION)
-# (6, 3, 10, 10) full `np.ndarray` test-train NTK
+# (6, 3, 10, 10) full `jnp.ndarray` test-train NTK
ntk_jc = jacobian_contraction(x2, x1, params)

# NTK-vector products-based implementation.
4 changes: 2 additions & 2 deletions examples/function_space.py
@@ -24,7 +24,7 @@
from jax import jit
from jax import random
from jax.example_libraries import optimizers
-import jax.numpy as np
+import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
@@ -58,7 +58,7 @@ def main(unused_argv):
state = opt_init(params)

# Create an mse loss function and a gradient function.
-loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
+loss = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2)
grad_loss = jit(grad(lambda params, x, y: loss(apply_fn(params, x), y)))

# Create an MSE predictor to solve the NTK equation in function space.
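For context, a sketch of the function-space step mentioned above (solving the NTK equation with an MSE predictor); the network, shapes, learning rate, and time `t` are assumptions:

```python
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(1))
key = random.PRNGKey(0)
x_train = random.normal(key, (6, 4))
x_test = random.normal(key, (3, 4))
y_train = random.normal(key, (6, 1))
_, params = init_fn(key, x_train.shape)

# Empirical NTK blocks needed by the predictor.
ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=0)
k_train_train = ntk_fn(x_train, None, params)
k_test_train = ntk_fn(x_test, x_train, params)

predictor = nt.predict.gradient_descent_mse(
    k_train_train,
    y_train,
    learning_rate=1.0
)
fx_train_0 = apply_fn(params, x_train)
fx_test_0 = apply_fn(params, x_test)
fx_train_t, fx_test_t = predictor(5.0, fx_train_0, fx_test_0, k_test_train)
```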
29 changes: 15 additions & 14 deletions examples/imdb.py
@@ -15,7 +15,7 @@

from absl import app
from jax import random
-import jax.numpy as np
+import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
@@ -93,35 +93,36 @@ def main(*args, use_dummy_data: bool = False, **kwargs) -> None:
print(f'Kernel construction and inference done in {duration} seconds.')

# Print out accuracy and loss for infinite network predictions.
-loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
+loss = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2)
util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)


-def _get_dummy_data(mask_constant: float
-                    ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+def _get_dummy_data(
+    mask_constant: float
+) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Return dummy data for when downloading embeddings is not feasible."""
n_train, n_test = 6, 6

def get_x(shape, key):
key_x, key_mask = random.split(key)
x = random.normal(key_x, shape)
mask = random.bernoulli(key_mask, 0.6, shape)
-x = np.where(mask, mask_constant, x)
+x = jnp.where(mask, mask_constant, x)
return x

def get_y(x):
-x = np.where(x == mask_constant, 0., x)
+x = jnp.where(x == mask_constant, 0., x)

def weighted_sum(x, start, end):
-      return np.sum(x[..., start:end] *
-                    np.arange(x.shape[1])[None, ..., None],
-                    axis=(1, 2))
-
-    y_label = np.stack([weighted_sum(x, 0, x.shape[-1] // 2),
-                        weighted_sum(x, x.shape[-1] // 2, x.shape[-1])],
-                       axis=-1) > 0
-    y = np.where(y_label, 0.5, -0.5)
+      return jnp.sum(x[..., start:end] *
+                     jnp.arange(x.shape[1])[None, ..., None],
+                     axis=(1, 2))
+
+    y_label = jnp.stack([weighted_sum(x, 0, x.shape[-1] // 2),
+                         weighted_sum(x, x.shape[-1] // 2, x.shape[-1])],
+                        axis=-1) > 0
+    y = jnp.where(y_label, 0.5, -0.5)
return y

rng_train, rng_test = random.split(random.PRNGKey(1), 2)
4 changes: 2 additions & 2 deletions examples/infinite_fcn.py
@@ -19,7 +19,7 @@

import time
from absl import app
-import jax.numpy as np
+import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
@@ -62,7 +62,7 @@ def main(unused_argv):
print('Kernel construction and inference done in %s seconds.' % duration)

# Print out accuracy and loss for infinite network predictions.
-loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat) ** 2)
+loss = lambda fx, y_hat: 0.5 * jnp.mean((fx - y_hat) ** 2)
util.print_summary('NNGP test', y_test, fx_test_nngp, None, loss)
util.print_summary('NTK test', y_test, fx_test_ntk, None, loss)

6 changes: 3 additions & 3 deletions examples/util.py
@@ -16,12 +16,12 @@
"""


-import jax.numpy as np
+import jax.numpy as jnp


def _accuracy(y, y_hat):
"""Compute the accuracy of the predictions with respect to one-hot labels."""
-return np.mean(np.argmax(y, axis=1) == np.argmax(y_hat, axis=1))
+return jnp.mean(jnp.argmax(y, axis=1) == jnp.argmax(y_hat, axis=1))


def print_summary(name, labels, net_p, lin_p, loss):
@@ -34,5 +34,5 @@ def print_summary(name, labels, net_p, lin_p, loss):
print('Linearization Accuracy = {}'.format(_accuracy(lin_p, labels)))
print('Linearization Loss = {}'.format(loss(lin_p, labels)))
print('RMSE of predictions: {}'.format(
-np.sqrt(np.mean((net_p - lin_p) ** 2))))
+jnp.sqrt(jnp.mean((net_p - lin_p) ** 2))))
print('---------------------------------------')
4 changes: 2 additions & 2 deletions examples/weight_space.py
@@ -28,7 +28,7 @@
from jax import random
from jax.example_libraries import optimizers
from jax.nn import log_softmax
-import jax.numpy as np
+import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax
from examples import datasets
@@ -66,7 +66,7 @@ def main(unused_argv):
state_lin = opt_init(params)

# Create a cross-entropy loss function.
-loss = lambda fx, y_hat: -np.mean(log_softmax(fx) * y_hat)
+loss = lambda fx, y_hat: -jnp.mean(log_softmax(fx) * y_hat)

# Specialize the loss function to compute gradients for both linearized and
# full networks.
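The "specialize the loss" step above pairs the full network with its linearization; a minimal sketch of that pattern using `nt.linearize` (network and data shapes are assumptions):

```python
import jax
from jax import random
from jax.nn import log_softmax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(64), stax.Relu(), stax.Dense(10))
key = random.PRNGKey(0)
x = random.normal(key, (8, 32))
y = jax.nn.one_hot(jnp.arange(8) % 10, 10)
_, params = init_fn(key, x.shape)

loss = lambda fx, y_hat: -jnp.mean(log_softmax(fx) * y_hat)

# Gradients for the full network and for its linearization around `params`.
apply_fn_lin = nt.linearize(apply_fn, params)
grad_loss = jax.jit(jax.grad(lambda p, x, y: loss(apply_fn(p, x), y)))
grad_loss_lin = jax.jit(jax.grad(lambda p, x, y: loss(apply_fn_lin(p, x), y)))

print(jax.tree_util.tree_map(jnp.shape, grad_loss(params, x, y)))
```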
2 changes: 1 addition & 1 deletion neural_tangents/__init__.py
@@ -16,7 +16,7 @@
"""Public Neural Tangents modules and functions."""


-__version__ = '0.6.4'
+__version__ = '0.6.5'

from . import experimental
from . import predict