Skip to content

Commit

Permalink
more specialized routes for log-prob
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes committed Aug 27, 2022
1 parent 36185dd commit a8963d3
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 6 deletions.
5 changes: 5 additions & 0 deletions src/approxcdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ void truncate_bvn_2by2block(const double mu1, const double mu2,
const double t1, const double t2,
double &restrict mu1_out, double &restrict mu2_out,
double &restrict v1_out, double &restrict v2_out, double &restrict cv_out);
void truncate_logbvn_2by2block(const double mu1, const double mu2,
const double v1, const double v2, const double cv,
const double t1, const double t2,
double &restrict mu1_out, double &restrict mu2_out,
double &restrict v1_out, double &restrict v2_out, double &restrict cv_out);
APPROXCDF_EXPORTED
double norm_cdf_tvbs
(
Expand Down
62 changes: 61 additions & 1 deletion src/bhat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,66 @@ void bv_trunc_std4d_loweronly(const double rho[6], const double tp[2],
#endif
}

static inline
void logbv_trunc_std4d_loweronly(const double rho[6], const double tp[2],
double *restrict mu_out, double *restrict Omega_out)
{
double detV11 = std::fma(-rho[0], rho[0], 1.);
double invV11v = 1. / detV11;
double invV11d = -rho[0] / detV11;
#ifdef REGULARIZE_BHAT
double reg = 1e-5;
double d = 1.;
while (detV11 <= 1e-2) {
d += reg;
detV11 = std::fma(-rho[0], rho[0], d*d);
invV11v = d / detV11;
invV11d = -rho[0] / detV11;
reg *= 1.5;
}
#endif

double Omega11[3];
double mu_half[2];
truncate_logbvn_2by2block(0., 0., 1., 1., rho[0], tp[0], tp[1],
mu_half[0], mu_half[1],
Omega11[0], Omega11[1], Omega11[2]);

mu_out[0] = (invV11v*rho[1] + invV11d*rho[3]) * (mu_half[0]) +
(invV11d*rho[1] + invV11v*rho[3]) * (mu_half[1]);
mu_out[1] = (invV11v*rho[2] + invV11d*rho[4]) * (mu_half[0]) +
(invV11d*rho[2] + invV11v*rho[4]) * (mu_half[1]);

double Omega11_invV11[] = {
Omega11[0]*invV11v + Omega11[2]*invV11d, Omega11[0]*invV11d + Omega11[2]*invV11v,
Omega11[2]*invV11v + Omega11[1]*invV11d, Omega11[2]*invV11d + Omega11[1]*invV11v
};
/* O12 */
double O12[] = {
Omega11_invV11[0]*rho[1] + Omega11_invV11[1]*rho[3], Omega11_invV11[0]*rho[2] + Omega11_invV11[1]*rho[4],
Omega11_invV11[2]*rho[1] + Omega11_invV11[3]*rho[3], Omega11_invV11[2]*rho[2] + Omega11_invV11[3]*rho[4]
};
/* V12 - O12 */
double temp1[] = {
rho[1] - O12[0], rho[2] - O12[1],
rho[3] - O12[2], rho[4] - O12[3]
};
/* iV11 * (V12 - O12) */
double temp2[] = {
invV11v*temp1[0] + invV11d*temp1[2], invV11v*temp1[1] + invV11d*temp1[3],
invV11d*temp1[0] + invV11v*temp1[2], invV11d*temp1[1] + invV11v*temp1[3]
};
/* V22 - V21 * (iV11 * (V12 - O12)) */
Omega_out[0] = 1. - rho[1]*temp2[0] - rho[3]*temp2[2];
Omega_out[1] = 1. - rho[2]*temp2[1] - rho[4]*temp2[3];
Omega_out[2] = rho[5] - rho[1]*temp2[1] - rho[3]*temp2[3];

#ifdef REGULARIZE_BHAT
Omega_out[0] = std::fmax(Omega_out[0], 0.005);
Omega_out[1] = std::fmax(Omega_out[1], 0.005);
#endif
}

/* This is an adaptation of 'cdfqvn'
Note: "Bhat" is the author's surname.
Expand Down Expand Up @@ -417,7 +477,7 @@ double norm_logcdf_4d_internal(const double x[4], const double rho[6])
{
double mu[2];
double Sigma[3];
bv_trunc_std4d_loweronly(rho, x, mu, Sigma);
logbv_trunc_std4d_loweronly(rho, x, mu, Sigma);

double p1 = norm_logcdf_1d((x[2] - mu[0]) / std::sqrt(Sigma[0]));
double p2 = norm_logcdf_2d(
Expand Down
101 changes: 96 additions & 5 deletions src/tvbs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,89 @@ void truncate_bvn_2by2block(const double mu1, const double mu2,
cv_out = s1 * s2 * orho;
}

void truncate_logbvn_2by2block(const double mu1, const double mu2,
const double v1, const double v2, const double cv,
const double t1, const double t2,
double &restrict mu1_out, double &restrict mu2_out,
double &restrict v1_out, double &restrict v2_out, double &restrict cv_out)
{
double s1 = std::sqrt(v1);
double s2 = std::sqrt(v2);
s1 = std::fmax(s1, 1e-8);
s2 = std::fmax(s2, 1e-8);
double ntp1 = (t1 - mu1) / s1;
double ntp2 = (t2 - mu2) / s2;
double rho = cv / (s1 * s2);

double logp = norm_logcdf_2d(ntp1, ntp2, rho);
double rhotilde = std::sqrt(std::fma(-rho, rho, 1.));
rhotilde = std::fmax(rhotilde, 1e-16);
double tr1 = std::fma(-rho, ntp2, ntp1) / rhotilde;
double tr2 = std::fma(-rho, ntp1, ntp2) / rhotilde;
double logpd1 = norm_logpdf_1d(ntp1);
double logpd2 = norm_logpdf_1d(ntp2);
double logcd1 = norm_logcdf_1d(tr1);
double logcd2 = norm_logcdf_1d(tr2);

double log_pd1_cd2 = logpd1 + logcd2;
double log_pd2_cd1 = logpd2 + logcd1;
double log_pdf_tr1 = norm_logpdf_1d(tr1);
double log_pdf_tr2 = norm_logpdf_1d(tr2);

double log_rho = std::log(std::fabs(rho));
double sign_rho = (rho >= 0.)? 1. : -1.;
double log_rho_pd2_cd1 = log_rho + log_pd2_cd1;
double log_rho_pd1_cd2 = log_rho + log_pd1_cd2;

double temp1 = sign_rho * std::exp(log_rho_pd2_cd1 - log_pd1_cd2);
double temp2 = sign_rho * std::exp(log_rho_pd1_cd2 - log_pd2_cd1);
double log_m1, log_m2;
if (temp1 > -1.) {
log_m1 = log_pd1_cd2 + std::log1p(temp1) - logp;
}
else {
log_m1 = log_pd1_cd2 - logp;
}
if (temp2 > -1.) {
log_m2 = log_pd2_cd1 + std::log1p(temp2) - logp;
}
else {
log_m2 = log_pd2_cd1 - logp;
}

double sign_ntp1 = (ntp1 >= 0.)? 1. : -1.;
double sign_ntp2 = (ntp2 >= 0.)? 1. : -1.;
double log_ntp1 = std::log(std::fabs(ntp1));
double log_ntp2 = std::log(std::fabs(ntp2));
double log_rhotilde = std::log(rhotilde);

double os1 = 1. - (
sign_ntp1 * std::exp(log_ntp1 + log_pd1_cd2 - logp)
+ sign_ntp2 * std::exp(log_ntp2 + 2. * log_rho + log_pd2_cd1 - logp)
- sign_rho * std::exp(log_rhotilde + log_rho + logpd2 + log_pdf_tr1 - logp)
) - std::exp(2. * log_m1);
double os2 = 1. - (
sign_ntp2 * std::exp(log_ntp2 + log_pd2_cd1 - logp)
+ sign_ntp1 * std::exp(log_ntp1 + 2. * log_rho + log_pd1_cd2 - logp)
- sign_rho * std::exp(log_rhotilde + log_rho + logpd1 + log_pdf_tr2 - logp)
) - std::exp(2. * log_m2);
double orho = rho * (
1.
- sign_ntp1 * std::exp(log_ntp1 + log_pd1_cd2 - logp)
- sign_ntp2 * std::exp(log_ntp2 + log_pd2_cd1 - logp)
) + std::exp(log_rhotilde + logpd1 + log_pdf_tr2 - logp)
- std::exp(log_m1 + log_m2);

mu1_out = std::fma(-std::exp(log_m1), s1, mu1);
mu2_out = std::fma(-std::exp(log_m2), s2, mu2);
v1_out = v1 * os1;
v2_out = v2 * os2;
cv_out = s1 * s2 * orho;

v1_out = std::fmax(v1_out, std::numeric_limits<double>::min());
v2_out = std::fmax(v1_out, std::numeric_limits<double>::min());
}

double norm_cdf_nd_tvbs_internal
(
double *restrict x_reordered,
Expand Down Expand Up @@ -232,11 +315,19 @@ double norm_cdf_nd_tvbs_internal
}
}

truncate_bvn_2by2block(mu_trunc[2*step], mu_trunc[2*step + 1],
D[2*step*(n+1)], D[(2*step+1)*(n+1)], D[2*step*(n+1) + 1],
x_reordered[2*step], x_reordered[2*step + 1],
bvn_trunc_mu[0], bvn_trunc_mu[1],
bvn_trunc_cv[0], bvn_trunc_cv[1], bvn_trunc_cv[2]);
if (likely(!logp)) {
truncate_bvn_2by2block(mu_trunc[2*step], mu_trunc[2*step + 1],
D[2*step*(n+1)], D[(2*step+1)*(n+1)], D[2*step*(n+1) + 1],
x_reordered[2*step], x_reordered[2*step + 1],
bvn_trunc_mu[0], bvn_trunc_mu[1],
bvn_trunc_cv[0], bvn_trunc_cv[1], bvn_trunc_cv[2]);
} else {
truncate_logbvn_2by2block(mu_trunc[2*step], mu_trunc[2*step + 1],
D[2*step*(n+1)], D[(2*step+1)*(n+1)], D[2*step*(n+1) + 1],
x_reordered[2*step], x_reordered[2*step + 1],
bvn_trunc_mu[0], bvn_trunc_mu[1],
bvn_trunc_cv[0], bvn_trunc_cv[1], bvn_trunc_cv[2]);
}

bvn_trunc_mu[0] -= mu_trunc[2*step];
bvn_trunc_mu[1] -= mu_trunc[2*step + 1];
Expand Down

0 comments on commit a8963d3

Please sign in to comment.