Variational inference with PyMC (#1306)
* variational inference fit

* remove variational from sample

* make pymc object accessible

* save as McmcPtResult

* tests added

* add warning in write_result()
arrjon committed May 21, 2024
1 parent 7fe40ba commit fefefd5
Showing 6 changed files with 420 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pypesto/store/save_to_hdf5.py
@@ -316,3 +316,9 @@ def write_result(
    if sample:
        pypesto_sample_writer = SamplingResultHDF5Writer(filename)
        pypesto_sample_writer.write(result, overwrite=overwrite)

    if hasattr(result, "variational_result"):
        logger.warning(
            "Results from variational inference are not saved in the hdf5 file. "
            "You have to save them manually."
        )
9 changes: 9 additions & 0 deletions pypesto/variational/__init__.py
@@ -0,0 +1,9 @@
"""
Variational inference
=====================
Find the best variational approximation to the target distribution within a given family of distributions, from which we can then sample.
"""

from .pymc import PymcVariational
from .variational_inference import variational_fit
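The subpackage exposes two public entry points, the low-level sampler wrapper and the high-level fitting routine; a minimal import sketch:

from pypesto.variational import PymcVariational, variational_fit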
196 changes: 196 additions & 0 deletions pypesto/variational/pymc.py
@@ -0,0 +1,196 @@
"""Pymc v4 Sampler for Variational Inference."""

import logging
from typing import Optional

import numpy as np
import pytensor.tensor as pt
from scipy import stats

from ..objective import FD
from ..result import McmcPtResult
from ..sample.pymc import PymcObjectiveOp, PymcSampler
from ..sample.sampler import SamplerImportError

logger = logging.getLogger(__name__)


# implementation based on the pymc sampler code in pypesto and:
# https://www.pymc.io/projects/examples/en/latest/variational_inference/variational_api_quickstart.html


class PymcVariational(PymcSampler):
    """Wrapper around Pymc v4 variational inference.

    Parameters
    ----------
    step_function:
        A pymc step function, e.g. NUTS, Slice. If not specified, pymc
        determines one automatically (preferred).
    **kwargs:
        Options are directly passed on to `pymc.fit`.
    """

    def fit(
        self,
        n_iterations: int,
        method: str = "advi",
        random_seed: Optional[int] = None,
        start_sigma: Optional[dict[str, np.ndarray]] = None,
        inf_kwargs: Optional[dict] = None,
        beta: float = 1.0,
        **kwargs,
    ):
        """
        Fit the variational approximation to the problem.

        Parameters
        ----------
        n_iterations:
            Number of iterations.
        method: str or :class:`Inference` of pymc
            string name is case-insensitive in:
            - 'advi' for ADVI
            - 'fullrank_advi' for FullRankADVI
            - 'svgd' for Stein Variational Gradient Descent
            - 'asvgd' for Amortized Stein Variational Gradient Descent
        random_seed: int
            Random seed for reproducibility.
        start_sigma: `dict[str, np.ndarray]`
            Starting standard deviation for inference; only available for method 'advi'.
        inf_kwargs: dict
            Additional kwargs passed to `pymc.Inference`.
        beta:
            Inverse temperature (e.g. in parallel tempering).
        """
        try:
            import pymc
        except ImportError:
            raise SamplerImportError("pymc") from None

        problem = self.problem
        if not problem.objective.has_grad:
            logger.info(
                "The objective function does not provide gradients. "
                "Finite differences will be used."
            )
            problem.objective = FD(obj=problem.objective)
        log_post = PymcObjectiveOp.create_instance(problem.objective, beta)

        x0 = None
        x_names_free = problem.get_reduced_vector(problem.x_names)
        if self.x0 is not None:
            x0 = {
                x_name: val
                for x_name, val in zip(problem.x_names, self.x0)
                if x_name in x_names_free
            }

        # create model context
        with pymc.Model():
            # parameter bounds as uniform prior
            _k = [
                pymc.Uniform(x_name, lower=lb, upper=ub)
                for x_name, lb, ub in zip(
                    x_names_free,
                    problem.lb,
                    problem.ub,
                )
            ]

            # convert parameters to PyTensor tensor variable
            theta = pt.as_tensor_variable(_k)

            # define distribution with log-posterior as density
            pymc.Potential("potential", log_post(theta))

            # record function values
            pymc.Deterministic("loggyposty", log_post(theta))

            # perform the actual fitting
            data = pymc.fit(
                n=int(n_iterations),
                method=method,
                random_seed=random_seed,
                start=x0,
                start_sigma=start_sigma,
                inf_kwargs=inf_kwargs,
                **kwargs,
            )

        self.data = data

    def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult:
        """
        Sample from the variational approximation and return a McmcPtResult object.

        Parameters
        ----------
        n_samples:
            Number of samples to be computed.
        """
        # get InferenceData object
        pymc_data = self.data.sample(n_samples)
        x_names_free = self.problem.get_reduced_vector(self.problem.x_names)
        post_samples = np.concatenate(
            [pymc_data.posterior[name].values for name in x_names_free]
        ).T
        return McmcPtResult(
            trace_x=post_samples[np.newaxis, :],
            trace_neglogpost=pymc_data.posterior.loggyposty.values,
            trace_neglogprior=np.full(
                pymc_data.posterior.loggyposty.values.shape, np.nan
            ),
            betas=np.array([1.0] * post_samples.shape[0]),
            burn_in=0,
            auto_correlation=0,
            effective_sample_size=n_samples,
            message="variational inference results",
        )

    def get_variational_parameters(self) -> tuple[list, list]:
        """Get the internal pymc variational parameters."""
        return (
            [param.name for param in self.data.params],
            [param.eval() for param in self.data.params],
        )

    def set_variational_parameters(self, param_list: list):
        """
        Set the internal pymc variational parameters.

        Parameters
        ----------
        param_list:
            List of values for the variational parameters, in the order
            returned by `get_variational_parameters`.
        """
        if len(param_list) != len(self.data.params):
            raise ValueError(
                "The number of parameters does not match the number of variational parameters."
            )
        for i, param in enumerate(param_list):
            self.data.params[i].set_value(param)

    def eval_variational_log_density(self, x: np.ndarray) -> np.ndarray:
        """
        Evaluate the log density of the variational approximation at the given points.

        Parameters
        ----------
        x:
            The points at which to evaluate the log density.
        """
        # TODO: add support for other methods
        logger.warning(
            "Currently only the methods `advi` and `fullrank_advi` are supported."
        )

        if x.ndim == 1:
            x = x.reshape(1, -1)
        # one log-density value per point
        vi_log_density = np.zeros(x.shape[0])
        for i, point in enumerate(x):
            vi_log_density[i] = stats.multivariate_normal.logpdf(
                point, mean=self.data.mean.eval(), cov=self.data.cov.eval()
            )
        return vi_log_density
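To illustrate how these methods compose, a short sketch assuming `variational` is a `PymcVariational` instance that has already been initialized on a problem and fitted via `fit` (the query point is made up):

import numpy as np

# draw samples from the fitted approximation as a pypesto McmcPtResult
mcmc_result = variational.sample(n_samples=1000)

# evaluate the Gaussian variational log density at a single query point
x = np.zeros((1, mcmc_result.trace_x.shape[-1]))
log_q = variational.eval_variational_log_density(x)

# round-trip the internal variational parameters
names, values = variational.get_variational_parameters()
variational.set_variational_parameters(values)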
136 changes: 136 additions & 0 deletions pypesto/variational/variational_inference.py
@@ -0,0 +1,136 @@
"""Functions for variational inference accessible to the user. Currently only pymc is supported."""

import logging
from time import process_time
from typing import Callable, List, Optional, Union

import numpy as np

from ..problem import Problem
from ..result import Result
from ..sample.util import bound_n_samples_from_env
from ..store import autosave
from .pymc import PymcVariational

logger = logging.getLogger(__name__)


def variational_fit(
    problem: Problem,
    n_iterations: int,
    method: str = "advi",
    n_samples: Optional[int] = None,
    random_seed: Optional[int] = None,
    start_sigma: Optional[dict[str, np.ndarray]] = None,
    x0: Union[np.ndarray, List[np.ndarray]] = None,
    result: Result = None,
    filename: Union[str, Callable, None] = None,
    overwrite: bool = False,
    **kwargs,
) -> Result:
    """
    Call to perform variational inference.

    Parameters
    ----------
    problem:
        The problem to be solved.
    n_iterations:
        Number of iterations for the optimization.
    method: str or :class:`Inference` of pymc (only interface currently supported)
        string name is case-insensitive in:
        - 'advi' for ADVI
        - 'fullrank_advi' for FullRankADVI
        - 'svgd' for Stein Variational Gradient Descent
        - 'asvgd' for Amortized Stein Variational Gradient Descent
    n_samples:
        Number of samples to generate after optimization.
    random_seed: int
        Random seed for reproducibility.
    start_sigma: `dict[str, np.ndarray]`
        Starting standard deviation for inference; only available for method 'advi'.
    x0:
        Initial parameter for the variational optimization. If None, the best
        parameter found in optimization is used.
    result:
        A result to write to. If None provided, one is created from the
        problem.
    filename:
        Name of the hdf5 file, where the result will be saved. Default is
        None, which deactivates automatic saving. If set to
        "Auto" it will automatically generate a file named
        `year_month_day_sampling_result.hdf5`.
        Optionally a method, see docs for `pypesto.store.auto.autosave`.
    overwrite:
        Whether to overwrite `result/sampling` in the autosave file
        if it already exists.

    Returns
    -------
    result:
        A result with filled-in sample_result part.
    """
    # prepare result object
    if result is None:
        result = Result(problem)

    # bound number of iterations, e.g. from environment variables
    if n_iterations is not None:
        n_iterations = bound_n_samples_from_env(n_iterations)

    # try to find initial parameters
    if x0 is None:
        result.optimize_result.sort()
        if len(result.optimize_result.list) > 0:
            x0 = problem.get_reduced_vector(
                result.optimize_result.list[0]["x"]
            )

    # set up variational inference
    # currently we only support pymc
    variational = PymcVariational()

    # initialize sampler to problem
    variational.initialize(problem=problem, x0=x0)

    # perform the variational fit and track time
    t_start = process_time()
    variational.fit(
        n_iterations=n_iterations,
        method=method,
        random_seed=random_seed,
        start_sigma=start_sigma,
        **kwargs,
    )
    t_elapsed = process_time() - t_start
    logger.info("Elapsed time: " + str(t_elapsed))

    # extract results and save samples to pypesto result
    if n_samples is None or n_samples == 0:
        # construct a McmcPtResult object with a nearly empty trace_x
        n_samples = 1

    result.sample_result = variational.sample(n_samples)
    result.sample_result.time = t_elapsed

    autosave(
        filename=filename,
        result=result,
        store_type="sample",
        overwrite=overwrite,
    )

    # make pymc object available in result
    # TODO: if needed, we can add a result object for variational inference methods
    result.variational_result = variational
    (
        result.sample_result.variational_parameters_names,
        result.sample_result.variational_parameters,
    ) = variational.get_variational_parameters()
    if filename is not None:
        logger.warning(
            "Variational parameters are not saved in the hdf5 file. "
            "You have to save them manually."
        )

    return result
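For context, an end-to-end sketch of calling `variational_fit` on a toy problem; the quadratic objective and bounds are made up, and any `pypesto.Problem` should work the same way:

import numpy as np

import pypesto
from pypesto.variational import variational_fit

# toy negative log-posterior: a 2-d quadratic bowl (gradients via the FD fallback)
objective = pypesto.Objective(fun=lambda x: float(np.sum(x**2)))
problem = pypesto.Problem(objective=objective, lb=[-5, -5], ub=[5, 5])

# fit an ADVI approximation and draw 1000 samples from it
result = variational_fit(
    problem=problem,
    n_iterations=10_000,
    method="advi",
    n_samples=1000,
)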
1 change: 1 addition & 0 deletions test/variational/__init__.py
@@ -0,0 +1 @@
"""Variational inference tests."""