Skip to content

martiningram/jax_advi

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

20 Commits
 
 
 
 
 
 
 
 
 
 

Repository files navigation

JAX ADVI

What is this?

JAX ADVI is a small library designed to do black-box variational inference using automatic differentiation variational inference (ADVI) using JAX. Specifically, it implements the variant proposed by Giordano et al. which allows the use of (approximate) second order methods, which can be more robust. You can read about the methodology on my blog.

How do I use it?

Installation

To use this library, you'll need JAX. If you want to run the examples, you'll also need:

  • Pandas
  • Scikit-learn
  • Matplotlib
  • Pystan

To install, clone this repository and run python setup.py develop.

Usage

The key function is optimize_advi_mean_field. Here's a simple example of a logistic regression:

Model:

log_reg_model

Code:

from jax_advi.advi import optimize_advi_mean_field
from jax import jit
from jax.scipy.stats import norm
from functools import partial
from jax.nn import log_sigmoid

# Define parameter shapes:
theta_shapes = {
    'beta': (K),
    'gamma': ()
}

# Define a function to calculate the log likelihood
def calculate_likelihood(theta, X, y):
    
    logit_prob = X @ theta['beta'] + theta['gamma']
    
    prob_pres = log_sigmoid(logit_prob)
    prob_abs = log_sigmoid(-logit_prob)
    
    return jnp.sum(y * prob_pres + (1 - y) * prob_abs)

# Define a function to calculate the log prior
def calculate_prior(theta):
    
    beta_prior = jnp.sum(norm.logpdf(theta['beta']))
    gamma_prior = jnp.sum(norm.logpdf(theta['gamma']))
    
    return beta_prior + gamma_prior
	
# The partial application basically conditions on the data (not defined in this
# little snippet)
log_lik_fun = jit(partial(calculate_likelihood, X=X, y=y))
log_prior_fun = jit(calculate_prior)

# Call the optimisation function
result = optimize_advi_mean_field(theta_shapes, log_prior_fun, log_lik_fun, n_draws=None)

JAX ADVI should typically get the means right, but the variances may be off. For a full example comparing against Stan's HMC, see the example notebook. Here is a comparison for some toy data, taken from that example:

comparison As promised, good means, less reliable standard deviations.

Example 2

A more complex example is the following hierarchical model:

bradley_terry

You can read a bit more about it in the blog post if you like. Here's what the code looks like:

Define shapes:

theta_shapes = {
    'player_skills': (n_p),
    'skill_prior_sd': ()
}

Constrain the prior standard deviation to be positive:

from jax_advi.constraints import constrain_positive

theta_constraints = {
    'skill_prior_sd': constrain_positive
}

Define the log likelihood and prior:

from jax.scipy.stats import norm
from jax import jit
from jax.nn import log_sigmoid
import jax.numpy as jnp

@jit
def log_prior_fun(theta):
    
    # Prior
    skill_prior = jnp.sum(norm.logpdf(theta['player_skills'], 0., theta['skill_prior_sd']))
    
    # hyperpriors
    hyper_sd = norm.logpdf(theta['skill_prior_sd'])
    
    return skill_prior + hyper_sd

def log_lik_fun(theta, winner_ids, loser_ids):
    
    logit_probs = theta['player_skills'][winner_ids] - theta['player_skills'][loser_ids]
    
    return jnp.sum(log_sigmoid(logit_probs))

from functools import partial

curried_lik = jit(partial(log_lik_fun, winner_ids=winner_ids, loser_ids=loser_ids))

Finally, optimize:

result = optimize_advi_mean_field(
	theta_shapes, log_prior_fun, curried_lik, 
	constrain_fun_dict=theta_constraints, verbose=True, M=100)

You can run the example and compare it against Stan using the example notebook. Again, the means are good and the variances are a bit off:

tennis_comparison The fit took 10s on my laptop with a GTX 2070, compared to 100 minutes in Stan, so you do get quite a speedup. It's also more accurate than Stan's ADVI, which gives the following means:

stan_advi I hope you find this library useful. Please raise issues if anything doesn't work. Please note that this is still a new package and I wouldn't trusting it blindly yet. Its mean estimates seem reliable so far, but I recommend checking them against Stan's (or another MCMC framework) to be sure. If you find examples that break it, I'd be very interested to see them.

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages