Planning for handling/tracking of deterministic nodes #168

Closed
damonbayer opened this issue Jun 7, 2024 · 1 comment
@damonbayer
Collaborator

damonbayer commented Jun 7, 2024

So far, when we have wanted to track a "generated quantity" (a quantity that is not sampled directly but depends on quantities that are sampled), we have littered numpyro.deterministic calls throughout the model code. This can be confusing because the name supplied to numpyro.deterministic may not correspond to the variable with that name in other parts of the code. For example, there may be a numpyro site called Rt, but later in the model the variable Rt is padded. The padded Rt will not be present in the posterior samples, even though it would be easier to use in post-processing than the unpadded Rt.

Perhaps @dylanhmorris can tell us whether there is a more "correct" way to achieve this, but I propose adding a generated_quantities flag (called gq in the example below) to the model arguments, collecting all of the numpyro.deterministic calls at the end of the model, and only calling numpyro.deterministic if generated_quantities = True.

Example:

```python
import numpy as np
import jax
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

# Define the model
def linear_regression(X, y=None, gq=False):
    # Priors for unknown parameters
    alpha = numpyro.sample('alpha', dist.Normal(0, 10))
    beta = numpyro.sample('beta', dist.Normal(0, 10))
    sigma = numpyro.sample('sigma', dist.Exponential(1))

    # Linear model
    mean = alpha + beta * X

    # Likelihood (sampling distribution) of observations
    with numpyro.plate('data', X.shape[0]):
        numpyro.sample('obs', dist.Normal(mean, sigma), obs=y)

    # Record generated quantities only when requested
    if gq:
        numpyro.deterministic('mean', mean)
    return mean

# Generate synthetic data
np.random.seed(0)
N = 100
X = np.random.randn(N)
y = 1.0 + 2.0 * X + np.random.normal(0, 1.0, size=N)

# Fit with MCMC; gq defaults to False, so 'mean' is not stored
nuts_kernel = NUTS(linear_regression)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), X, y)
posterior_samples = mcmc.get_samples()
print(posterior_samples.keys())  # sampled sites only; no 'mean'

# Posterior predictive sampling with generated quantities turned on
predictive = Predictive(linear_regression, posterior_samples)
predictions = predictive(jax.random.PRNGKey(1), X, gq=True)
print(predictions.keys())  # includes 'obs' and the deterministic 'mean'
```

This keeps the MCMC objects lean and useful for diagnostics, while allowing us to produce generated quantities in a centralized location.

damonbayer changed the title from "Revise throughout to use Predictive" to "Revise throughout to be compatible with Predictive" on Jun 7, 2024
@damonbayer
Collaborator Author

damonbayer commented Jun 10, 2024

After discussion with @dylanhmorris, we think it would be better for now to keep tracking deterministic random variables the way we do now (via numpyro.deterministic, or by making them DeterministicVariables), but the models should be revised so that the quantities we actually need downstream, such as the padded Rt, are tracked.

damonbayer added this to the L Sprint milestone on Jun 10, 2024
dylanhmorris changed the title from "Revise throughout to be compatible with Predictive" to "Planning for handling/tracking of deterministic nodes" on Jun 11, 2024
damonbayer modified the milestones: 🐺 Lycorhinus, M Sprint on Jun 21, 2024
2 participants