GPJax


Quickstart | Install guide | Documentation | Slack Community

GPJax aims to provide a low-level interface to Gaussian process (GP) models in JAX, structured to give researchers maximum flexibility in extending the code to suit their own needs. The idea is that the code should be as close as possible to the maths we write on paper when working with GP models.
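
As an illustration of this correspondence, the generative model fitted in the Simple example below is, on paper,

$$f \sim \mathcal{GP}(m, k), \qquad y_i = f(x_i) + \varepsilon_i, \qquad \varepsilon_i \sim \mathcal{N}(0, \sigma^2),$$

with a zero mean function $m$, an RBF kernel $k$, and Gaussian observation noise; the corresponding GPJax code constructs one object per mathematical component.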

Package support

GPJax was founded by Thomas Pinder. Today, the maintenance of GPJax is undertaken by Thomas Pinder and Daniel Dodd.

We would be delighted to receive contributions from interested individuals and groups. To learn how you can get involved, please read our guide for contributing. If you have any questions, we encourage you to open an issue. For broader conversations, such as best GP fitting practices or questions about the mathematics of GPs, we invite you to open a discussion.

Feel free to join our Slack Channel, where we can discuss the development of GPJax and broader support for Gaussian process modelling.

Supported methods and interfaces

Notebook examples

Guides for customisation

Conversion between .ipynb and .py

The above examples are stored in the examples directory in the double-percent (py:percent) format. Check out the jupytext CLI documentation for more information.

  • To convert example.py to example.ipynb, run:
jupytext --to notebook example.py
  • To convert example.ipynb to example.py, run:
jupytext --to py:percent example.ipynb
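
For reference, the py:percent format delimits notebook cells with # %% comment markers. A minimal, hypothetical example.py in this format might look as follows (the file contents are illustrative, not taken from the examples directory):

# %% [markdown]
# # A toy notebook
# Markdown cells carry a [markdown] tag after the marker.

# %%
import jax.numpy as jnp

x = jnp.linspace(0.0, 1.0, 10)  # an ordinary code cell

Running jupytext --to notebook on such a file would produce a notebook with one markdown cell and one code cell.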

Simple example

Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.

import gpjax as gpx
from jax import jit
import jax.numpy as jnp
import jax.random as jr
import optax as ox

key = jr.PRNGKey(123)

f = lambda x: 10 * jnp.sin(x)

n = 50
x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n, 1))
x = jnp.sort(x, axis=0)  # sort the inputs along the data axis for plotting
key, subkey = jr.split(key)  # avoid reusing a JAX PRNG key
y = f(x) + jr.normal(subkey, shape=(n, 1))
D = gpx.Dataset(X=x, y=y)

# Construct the prior
meanf = gpx.mean_functions.Zero()
kernel = gpx.kernels.RBF()
prior = gpx.Prior(mean_function=meanf, kernel=kernel)

# Define a likelihood
likelihood = gpx.Gaussian(num_datapoints=n)

# Construct the posterior
posterior = prior * likelihood
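
# With a Gaussian likelihood the posterior is conjugate, so the marginal
# log-likelihood optimised below is available in closed form.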

# Define an optimiser
optimiser = ox.adam(learning_rate=1e-2)

# Define the marginal log-likelihood
negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))

# Obtain Type 2 MLEs of the hyperparameters
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=negative_mll,
    train_data=D,
    optim=optimiser,
    num_iters=500,
    safe=True,
    key=key,
)

# Infer the predictive posterior distribution
xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
latent_dist = opt_posterior(xtest, D)
predictive_dist = opt_posterior.likelihood(latent_dist)

# Obtain the predictive mean and standard deviation
pred_mean = predictive_dist.mean()
pred_std = predictive_dist.stddev()
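
To visualise the fit, one can plot the observations alongside the predictive mean and a two-standard-deviation band. The sketch below is illustrative rather than part of GPJax, and assumes matplotlib is installed:

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(x, y, "x", label="Observations", alpha=0.5)
ax.plot(xtest, pred_mean, label="Predictive mean")
ax.fill_between(
    xtest.squeeze(),
    pred_mean.squeeze() - 2 * pred_std.squeeze(),
    pred_mean.squeeze() + 2 * pred_std.squeeze(),
    alpha=0.2,
    label="Two standard deviations",
)
ax.legend()
plt.show()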

Installation

Stable version

The latest stable version of GPJax can be installed via pip:

pip install gpjax

Note

We recommend verifying the installed version:

python -c 'import gpjax; print(gpjax.__version__)'

Development version

Warning

This version is possibly unstable and may contain bugs.

Note

We advise you to create a virtual environment before installing:

conda create -n gpjax_experimental python=3.10.0
conda activate gpjax_experimental

Clone a copy of the repository to your local machine and install the package in development mode:

git clone https://github.com/JaxGaussianProcesses/GPJax.git
cd GPJax
poetry install

We recommend checking that your installation passes the supplied unit tests:

poetry run pytest

Citing GPJax

If you use GPJax in your research, please cite our JOSS paper.

@article{Pinder2022,
  doi = {10.21105/joss.04455},
  url = {https://doi.org/10.21105/joss.04455},
  year = {2022},
  publisher = {The Open Journal},
  volume = {7},
  number = {75},
  pages = {4455},
  author = {Thomas Pinder and Daniel Dodd},
  title = {GPJax: A Gaussian Process Framework in JAX},
  journal = {Journal of Open Source Software}
}
