FCIT #315

Merged
merged 59 commits into from
May 13, 2022

Changes from 53 commits
992a2b4
Create smoothCF.py
MatthewZhao26 Oct 21, 2021
c4b8b6c
Added mean_embedding
MatthewZhao26 Oct 21, 2021
54bed36
updated inheritance from IndependenceTest
MatthewZhao26 Oct 28, 2021
6dd8f4d
Updated SmoothCF and ME tests to match author code output
MatthewZhao26 Nov 3, 2021
4b25431
Updated SmoothCF and ME inheritance, started editing tutorial
MatthewZhao26 Nov 11, 2021
beb4ecb
Updated documentation
MatthewZhao26 Nov 18, 2021
1f2ad30
Enabled numba, added docstrings
MatthewZhao26 Nov 30, 2021
ca8bdd1
Minor change
MatthewZhao26 Nov 30, 2021
1927935
Merge branch 'main' into fast_tstest
PSSF23 Dec 2, 2021
a789fa8
Response to initial PR comments
MatthewZhao26 Dec 10, 2021
0d57d2a
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 Dec 10, 2021
37b859e
Added random state to SmoothCF
MatthewZhao26 Dec 12, 2021
5e076d6
Merge branch 'dev' into fast_tstest
sampan501 Dec 12, 2021
e5d9cbf
ran black
MatthewZhao26 Dec 13, 2021
f35f7b3
ran black again
MatthewZhao26 Dec 13, 2021
84f75c6
remove blank
MatthewZhao26 Dec 13, 2021
1594f6d
quotes
MatthewZhao26 Dec 13, 2021
bbfa96a
Changed example
MatthewZhao26 Dec 13, 2021
398a408
black
MatthewZhao26 Dec 13, 2021
e0659a4
deleted
MatthewZhao26 Dec 13, 2021
002deef
black + more changes
MatthewZhao26 Dec 13, 2021
fa4ad44
random_state added to ME
MatthewZhao26 Dec 13, 2021
58352a3
coverage
MatthewZhao26 Dec 13, 2021
2f021c7
tests
MatthewZhao26 Dec 13, 2021
daf006a
black
MatthewZhao26 Dec 13, 2021
27638af
more tests
MatthewZhao26 Dec 13, 2021
e7df8ae
tests
MatthewZhao26 Dec 13, 2021
70e1b24
reformatting
MatthewZhao26 Dec 13, 2021
c685597
reformat
MatthewZhao26 Dec 13, 2021
c2348c0
reformat
MatthewZhao26 Dec 13, 2021
63a67a1
Merge branch 'dev' into fast_tstest
sampan501 Dec 13, 2021
797a62f
added journal to citation (Fast 2-sample)
MatthewZhao26 Dec 13, 2021
1127f71
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 Dec 13, 2021
bf0d7d3
corrected docstring math format
MatthewZhao26 Dec 13, 2021
1ec4696
formatting update, helper docstring update
MatthewZhao26 Dec 15, 2021
e7c56d0
Merge branch 'dev' into fast_tstest
sampan501 Dec 17, 2021
8c0c98d
Formatting changes, renaming of helpers, moved random_state
MatthewZhao26 Dec 17, 2021
5d52c84
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 Dec 17, 2021
3adbf35
fix tabbing issue in ksample
MatthewZhao26 Dec 17, 2021
ea2a1f6
added __init__ variable explanantions to tutorial
MatthewZhao26 Dec 18, 2021
db653fe
minor formatting change to tutorial
MatthewZhao26 Dec 18, 2021
cfb6dcb
fix tutorial rendering issue
sampan501 Dec 20, 2021
dd9587c
Create FCIT.py
MatthewZhao26 Apr 7, 2022
b850800
Creating new conditional independence module
MatthewZhao26 Apr 11, 2022
7b8ceb7
Update cond'l ind module + bib
MatthewZhao26 Apr 12, 2022
71001f6
updated index.rst
MatthewZhao26 Apr 13, 2022
358a459
updated config.yml
MatthewZhao26 Apr 13, 2022
6acfd56
added tests, updated hyppo init
MatthewZhao26 Apr 18, 2022
19174a0
updated tests FCIT
MatthewZhao26 Apr 18, 2022
6c57668
Merge branch 'dev' into fast_tstest
MatthewZhao26 Apr 21, 2022
6b5e095
Update FCIT.py
MatthewZhao26 Apr 21, 2022
2ef7143
added FCIT documentation
MatthewZhao26 May 5, 2022
e8b719f
Merge branch 'dev' into pr/315
sampan501 May 5, 2022
a6f4d6b
black update, documentation update FCIT
MatthewZhao26 May 11, 2022
bff8d53
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 May 11, 2022
f92697d
added tutorial
MatthewZhao26 May 13, 2022
d04742c
Update conditional.py
MatthewZhao26 May 13, 2022
dff6ceb
rst edits
MatthewZhao26 May 13, 2022
0dd9df6
minor doc changes
MatthewZhao26 May 13, 2022
2 changes: 1 addition & 1 deletion .circleci/config.yml
Expand Up @@ -31,4 +31,4 @@ jobs:
workflows:
  run-tests:
    jobs:
-      - build-and-test
+      - build-and-test
12 changes: 12 additions & 0 deletions docs/api/index.rst
Expand Up @@ -87,6 +87,18 @@ Discriminability
DiscrimOneSample
DiscrimTwoSample

.. automodule:: hyppo.conditional_independence

.. currentmodule:: hyppo.conditional_independence


Conditional Independence
_________________________

.. autosummary::
   :toctree: generated/

   FCIT


.. automodule:: hyppo.kgof
12 changes: 12 additions & 0 deletions docs/refs.bib
Expand Up @@ -499,6 +499,18 @@ @article{friedmanMultivariateGeneralizationsoftheWaldWolfowitzandSmirnovTwoSampl
annotation = {\_eprint: https://doi.org/10.1214/aos/1176344722},
}

@article{chalupka2018FastConditionalIndependence,
title={Fast Conditional Independence Test for Vector Variables with Large Sample Sizes},
author={Krzysztof Chalupka and Pietro Perona and Frederick Eberhardt},
year={2018},
journal={arXiv:1804.02747 [math, stat]},
eprint={1804.02747},
eprinttype={arxiv},
abstract={We present and evaluate the Fast (conditional) Independence Test (FIT) -- a nonparametric conditional independence test. The test is based on the idea that when P(X∣Y,Z)=P(X∣Y), Z is not useful as a feature to predict X, as long as Y is also a regressor. On the contrary, if P(X∣Y,Z)≠P(X∣Y), Z might improve prediction results. FIT applies to thousand-dimensional random variables with a hundred thousand samples in a fraction of the time required by alternative methods. We provide an extensive evaluation that compares FIT to six extant nonparametric independence tests. The evaluation shows that FIT has low probability of making both Type I and Type II errors compared to other tests, especially as the number of available samples grows. Our implementation of FIT is publicly available.},
archivePrefix={arXiv},
primaryClass={stat.ML}
}

@inproceedings{hellerMultivariateTestsOfAssociation2016,
title = {Multivariate Tests of Association Based on Univariate Tests},
author = {Heller, Ruth and Heller, Yair},
1 change: 1 addition & 0 deletions hyppo/__init__.py
Expand Up @@ -5,5 +5,6 @@
import hyppo.kgof
import hyppo.tools
import hyppo.d_variate
import hyppo.conditional_independence

__version__ = "0.3.2"
237 changes: 237 additions & 0 deletions hyppo/conditional_independence/FCIT.py
@@ -0,0 +1,237 @@
import time
import joblib

import numpy as np
from scipy.stats import ttest_1samp
from sklearn.metrics import mean_squared_error as mse
from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import ShuffleSplit
from sklearn.tree import DecisionTreeRegressor

from .base import ConditionalIndependenceTest, ConditionalIndependenceTestOutput


class FCIT(ConditionalIndependenceTest):
    r"""
    Fast Conditional Independence test statistic and p-value

    The Fast Conditional Independence Test is a non-parametric
    conditional independence test :footcite:p:`chalupka2018FastConditionalIndependence`.

    Parameters
    ----------
    model : sklearn regressor
        Regressor used to predict input data :math:`Y`.
    cv_grid : dict
        Dictionary of parameters to cross-validate over when training regressor.
    num_perm : int
        Number of data permutations to estimate the p-value from marginal stats.
    prop_test : float
        Proportion of data to evaluate test stat on.
    discrete : tuple of bool
        Whether :math:`X` or :math:`Y` are discrete.

    Notes
    -----
    The motivation for the test rests on the assumption that if
    :math:`X \not\!\perp\!\!\!\perp Y \mid Z`, then :math:`Y` should be more
    accurately predicted by using both :math:`X` and :math:`Z` as covariates
    than by using only :math:`Z` as a covariate. Likewise, if
    :math:`X \perp \!\!\! \perp Y \mid Z`, then :math:`Y` should be predicted
    just as accurately by using only :math:`Z` as by using both :math:`X`
    and :math:`Z` :footcite:p:`chalupka2018FastConditionalIndependence`.
    Thus, the test works by using a regressor (the default is a decision tree)
    to predict input :math:`Y` twice: once using both :math:`X` and :math:`Z`,
    and once using only :math:`Z`
    :footcite:p:`chalupka2018FastConditionalIndependence`. Then, the accuracy
    of both predictions is measured via mean-squared error (MSE).
    :math:`X \perp \!\!\! \perp Y \mid Z` is rejected only when the MSE of the
    regressor using both :math:`X` and :math:`Z` is significantly smaller than
    the MSE of the regressor trained using only :math:`Z`
    :footcite:p:`chalupka2018FastConditionalIndependence`.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        model=DecisionTreeRegressor(),
        cv_grid={"min_samples_split": [2, 8, 64, 512, 1e-2, 0.2, 0.4]},
        num_perm=8,
        prop_test=0.1,
        discrete=(False, False),
    ):

        self.model = model
        self.cv_grid = cv_grid
        self.num_perm = num_perm
        self.prop_test = prop_test
        self.discrete = discrete
        ConditionalIndependenceTest.__init__(self)

    def statistic(self, x, y, z=None):
        r"""
        Calculates the FCIT test statistic.

        Parameters
        ----------
        x,y,z : ndarray of float
            Input data matrices.

        Returns
        -------
        stat : float
            The computed FCIT test statistic.
        pvalue : float
            The computed one-sided FCIT p-value.
        """

        n_samples = x.shape[0]
        n_test = int(n_samples * self.prop_test)

        # One train/test shuffle per permutation, shared by both regressions.
        data_permutations = [
            np.random.permutation(x.shape[0]) for i in range(self.num_perm)
        ]

        # d1_stats: test-set MSEs when predicting y from both x and z.
        clf = cross_val(x, y, z, self.cv_grid, self.model, prop_test=self.prop_test)
        datadict = {
            "x": x,
            "y": y,
            "z": z,
            "data_permutation": data_permutations,
            "n_test": n_test,
            "reshuffle": False,
            "clf": clf,
        }
        d1_stats = np.array(
            joblib.Parallel(n_jobs=-1, max_nbytes=100e6)(
                joblib.delayed(obtain_error)((datadict, i))
                for i in range(self.num_perm)
            )
        )

        # With no conditioning variables, break the x-y link by shuffling x;
        # otherwise drop x entirely and predict y from z alone.
        if z.shape[1] == 0:
            x_indep_y = x[np.random.permutation(n_samples)]
        else:
            x_indep_y = np.empty([x.shape[0], 0])

        clf = cross_val(
            x_indep_y, y, z, self.cv_grid, self.model, prop_test=self.prop_test
        )

        # d0_stats: test-set MSEs for the baseline regressor.
        datadict["reshuffle"] = True
        datadict["x"] = x_indep_y
        d0_stats = np.array(
            joblib.Parallel(n_jobs=-1, max_nbytes=100e6)(
                joblib.delayed(obtain_error)((datadict, i))
                for i in range(self.num_perm)
            )
        )

        # Two-sided t-test of the MSE ratio against 1, converted to a
        # one-sided p-value for the alternative "ratio > 1".
        t, p_value = ttest_1samp(d0_stats / d1_stats, 1)
        if t < 0:
            p_value = 1 - p_value / 2
        else:
            p_value = p_value / 2

        return t, p_value

    def test(self, x, y, z=None):
        r"""
        Calculates the FCIT test statistic and p-value.

        Parameters
        ----------
        x,y : ndarray of float
            Input data matrices.
        z : ndarray of float, optional
            Conditioning data matrix. If ``None``, an empty matrix with
            zero columns is used.

        Returns
        -------
        stat : float
            The computed FCIT statistic.
        pvalue : float
            The computed FCIT p-value.

        Examples
        --------
        >>> import numpy as np
        >>> from hyppo.conditional_independence import FCIT
        >>> from sklearn.tree import DecisionTreeRegressor
        >>> np.random.seed(1234)
        >>> dim = 2
        >>> n = 100000
        >>> z1 = np.random.multivariate_normal(mean=np.zeros(dim), cov=np.eye(dim), size=(n))
        >>> A1 = np.random.normal(loc=0, scale=1, size=dim * dim).reshape(dim, dim)
        >>> B1 = np.random.normal(loc=0, scale=1, size=dim * dim).reshape(dim, dim)
        >>> x1 = (A1 @ z1.T + np.random.multivariate_normal(mean=np.zeros(dim), cov=np.eye(dim), size=(n)).T)
        >>> y1 = (B1 @ z1.T + np.random.multivariate_normal(mean=np.zeros(dim), cov=np.eye(dim), size=(n)).T)
        >>> model = DecisionTreeRegressor()
        >>> cv_grid = {"min_samples_split": [2, 8, 64, 512, 1e-2, 0.2, 0.4]}
        >>> stat, pvalue = FCIT(model=model, cv_grid=cv_grid).test(x1.T, y1.T, z1)
        >>> '%.2f, %.3f' % (stat, pvalue)
        '-3.59, 0.995'
        """

        n_samples = x.shape[0]

        if z is None:
            z = np.empty([n_samples, 0])

        if self.discrete[0] and not self.discrete[1]:
            # Only x is discrete: swap so the continuous variable is the input.
            x, y = y, x
        elif x.shape[1] < y.shape[1]:
            # Otherwise, make the lower-dimensional variable the target.
            x, y = y, x

        y = StandardScaler().fit_transform(y)

        stat, pvalue = self.statistic(x, y, z)

        return ConditionalIndependenceTestOutput(stat, pvalue)


def cross_val(x, y, z, cv_grid, model, prop_test):
    """
    Choose the regression hyperparameters by
    cross-validation.
    """

    splitter = ShuffleSplit(n_splits=3, test_size=prop_test)
    cv = GridSearchCV(estimator=model, cv=splitter, param_grid=cv_grid, n_jobs=-1)
    cv.fit(interleave(x, z), y)

    # Return a fresh, unfitted regressor of the same class that carries
    # only the best parameters found by the grid search.
    return type(model)(**cv.best_params_)


def interleave(x, z, seed=None):
    """Interleave x and z dimension-wise."""
    # Save the global RNG state so the column shuffle does not disturb it.
    state = np.random.get_state()
    # Note: `seed or ...` falls back to the clock for seed=None and seed=0.
    np.random.seed(seed or int(time.time()))
    total_ids = np.random.permutation(x.shape[1] + z.shape[1])
    np.random.set_state(state)
    out = np.zeros([x.shape[0], x.shape[1] + z.shape[1]])
    out[:, total_ids[: x.shape[1]]] = x
    out[:, total_ids[x.shape[1] :]] = z
    return out


def obtain_error(data_and_i):
    """
    A function used for parallel computation of the FCIT test statistic.
    Fits the regressor on the training split of permutation ``i`` and
    returns its MSE on the held-out split.
    """
    data, i = data_and_i
    x = data["x"]
    y = data["y"]
    z = data["z"]
    if data["reshuffle"]:
        # Shuffle the rows of x so that it carries no information about y.
        perm_ids = np.random.permutation(x.shape[0])
    else:
        perm_ids = np.arange(x.shape[0])
    data_permutation = data["data_permutation"][i]
    n_test = data["n_test"]
    clf = data["clf"]

    x_z = interleave(x[perm_ids], z, seed=i)

    # The first n_test permuted rows are held out; the rest are used to fit.
    clf.fit(x_z[data_permutation][n_test:], y[data_permutation][n_test:])
    return mse(
        y[data_permutation][:n_test], clf.predict(x_z[data_permutation][:n_test])
    )
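The one-sided conversion at the end of `FCIT.statistic` can be looked at in isolation. The sketch below is not code from this PR: the helper name `one_sided_pvalue` and the MSE values are made up for illustration, and only the halving logic mirrors the diff above. It assumes `d0` holds baseline MSEs (without `x`) and `d1` holds MSEs of the regressor that sees both `x` and `z`.

```python
import numpy as np
from scipy.stats import ttest_1samp


def one_sided_pvalue(d0_stats, d1_stats):
    """Test H1: mean(d0 / d1) > 1 using a two-sided one-sample t-test."""
    ratios = np.asarray(d0_stats, dtype=float) / np.asarray(d1_stats, dtype=float)
    t, p_two_sided = ttest_1samp(ratios, 1)
    # Halve the two-sided p-value when t points in the tested direction,
    # and take the complement of that half when it points the other way.
    if t < 0:
        return t, 1 - p_two_sided / 2
    return t, p_two_sided / 2


# Hypothetical test-set MSEs over 8 permutations (illustrative values only).
d1 = np.array([1.00, 1.02, 0.98, 1.01, 0.99, 1.00, 1.03, 0.97])
d0_worse = d1 * np.array([1.20, 1.25, 1.18, 1.22, 1.21, 1.19, 1.24, 1.23])
d0_better = d1 * np.array([0.80, 0.85, 0.78, 0.82, 0.81, 0.79, 0.84, 0.83])

# Dropping x makes predictions worse: positive t, small p-value.
t_dep, p_dep = one_sided_pvalue(d0_worse, d1)
# Dropping x makes predictions better: negative t, p-value close to 1.
t_ind, p_ind = one_sided_pvalue(d0_better, d1)
```

A negative statistic with a p-value near 1, as in the docstring example (`'-3.59, 0.995'`), therefore means the regressor that also sees `x` did not predict `y` any better. Recent SciPy releases also accept `alternative="greater"` in `ttest_1samp`, which should give the same one-sided p-value without the manual conversion.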
6 changes: 6 additions & 0 deletions hyppo/conditional_independence/__init__.py
@@ -0,0 +1,6 @@
from .FCIT import FCIT


__all__ = [s for s in dir()] # add imported tests to __all__

COND_INDEP_TESTS = {"fcit": FCIT}
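A registry dictionary like `COND_INDEP_TESTS` lets callers pick a test by string key. How hyppo consumes this dictionary is not shown in the diff, so the following is only a sketch under assumptions: `make_test` is a hypothetical helper, the case-insensitive lookup is an assumption, and the `FCIT` class here is a stand-in so the snippet runs without hyppo installed.

```python
class FCIT:
    """Stand-in for hyppo.conditional_independence.FCIT (illustration only)."""

    def __init__(self, num_perm=8):
        self.num_perm = num_perm


# Same shape as the registry added in this PR.
COND_INDEP_TESTS = {"fcit": FCIT}


def make_test(name, **kwargs):
    """Hypothetical helper: build a test object from its registry key."""
    key = name.lower()
    if key not in COND_INDEP_TESTS:
        raise ValueError(
            f"unknown test {name!r}; choose from {sorted(COND_INDEP_TESTS)}"
        )
    return COND_INDEP_TESTS[key](**kwargs)


test = make_test("FCIT", num_perm=16)
```

Keeping the keys lowercase and normalising the lookup means new tests can be added to the dictionary without touching the calling code.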
51 changes: 51 additions & 0 deletions hyppo/conditional_independence/base.py
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import NamedTuple


class ConditionalIndependenceTestOutput(NamedTuple):
    stat: float
    pvalue: float


class ConditionalIndependenceTest(ABC):
    """
    A base class for a conditional independence test.

    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def statistic(self, x, y, z):
        r"""
        Calculates the conditional independence test statistic.

        Parameters
        ----------
        x,y,z : ndarray of float
            Input data matrices.

        Returns
        -------
        stat : float
            The computed conditional independence test statistic.
        """

    @abstractmethod
    def test(self, x, y, z):
        r"""
        Calculates the conditional independence test statistic and p-value.

        Parameters
        ----------
        x,y,z : ndarray of float
            Input data matrices.

        Returns
        -------
        stat : float
            The computed conditional independence test statistic.
        pvalue : float
            The computed conditional independence test p-value.
        """
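To see what the abstract base class asks of a subclass, here is a minimal sketch. The base class and output tuple are re-declared in trimmed form so the snippet is self-contained; `ConstantTest` is a hypothetical toy subclass that never rejects and is not part of hyppo.

```python
from abc import ABC, abstractmethod
from typing import NamedTuple


class ConditionalIndependenceTestOutput(NamedTuple):
    stat: float
    pvalue: float


class ConditionalIndependenceTest(ABC):
    """Trimmed copy of the base class from the diff above."""

    @abstractmethod
    def statistic(self, x, y, z):
        """Return the test statistic."""

    @abstractmethod
    def test(self, x, y, z):
        """Return the test statistic and p-value."""


class ConstantTest(ConditionalIndependenceTest):
    """Toy subclass (hypothetical): statistic 0, p-value 1."""

    def statistic(self, x, y, z):
        return 0.0

    def test(self, x, y, z):
        return ConditionalIndependenceTestOutput(self.statistic(x, y, z), 1.0)


result = ConstantTest().test([1, 2], [3, 4], [5, 6])
stat, pvalue = result  # NamedTuple unpacks like a plain tuple

# The base class itself cannot be instantiated while methods are abstract.
try:
    ConditionalIndependenceTest()
    base_instantiable = True
except TypeError:
    base_instantiable = False
```

Both methods are abstract, so a subclass has to implement `statistic` and `test` before it can be constructed. Returning the `NamedTuple` allows both `stat, pvalue = ...` unpacking and attribute access such as `result.pvalue`.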