Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FCIT #315

Merged
merged 59 commits into from
May 13, 2022
Merged

FCIT #315

Changes from 1 commit
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
992a2b4
Create smoothCF.py
MatthewZhao26 Oct 21, 2021
c4b8b6c
Added mean_embedding
MatthewZhao26 Oct 21, 2021
54bed36
updated inheritance from IndependenceTest
MatthewZhao26 Oct 28, 2021
6dd8f4d
Updated SmoothCF and ME tests to match author code output
MatthewZhao26 Nov 3, 2021
4b25431
Updated SmoothCF and ME inheritance, started editing tutorial
MatthewZhao26 Nov 11, 2021
beb4ecb
Updated documentation
MatthewZhao26 Nov 18, 2021
1f2ad30
Enabled numba, added docstrings
MatthewZhao26 Nov 30, 2021
ca8bdd1
Minor change
MatthewZhao26 Nov 30, 2021
1927935
Merge branch 'main' into fast_tstest
PSSF23 Dec 2, 2021
a789fa8
Response to initial PR comments
MatthewZhao26 Dec 10, 2021
0d57d2a
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 Dec 10, 2021
37b859e
Added random state to SmoothCF
MatthewZhao26 Dec 12, 2021
5e076d6
Merge branch 'dev' into fast_tstest
sampan501 Dec 12, 2021
e5d9cbf
ran black
MatthewZhao26 Dec 13, 2021
f35f7b3
ran black again
MatthewZhao26 Dec 13, 2021
84f75c6
remove blank
MatthewZhao26 Dec 13, 2021
1594f6d
quotes
MatthewZhao26 Dec 13, 2021
bbfa96a
Changed example
MatthewZhao26 Dec 13, 2021
398a408
black
MatthewZhao26 Dec 13, 2021
e0659a4
deleted
MatthewZhao26 Dec 13, 2021
002deef
black + more changes
MatthewZhao26 Dec 13, 2021
fa4ad44
random_state added to ME
MatthewZhao26 Dec 13, 2021
58352a3
coverage
MatthewZhao26 Dec 13, 2021
2f021c7
tests
MatthewZhao26 Dec 13, 2021
daf006a
black
MatthewZhao26 Dec 13, 2021
27638af
more tests
MatthewZhao26 Dec 13, 2021
e7df8ae
tests
MatthewZhao26 Dec 13, 2021
70e1b24
reformatting
MatthewZhao26 Dec 13, 2021
c685597
reformat
MatthewZhao26 Dec 13, 2021
c2348c0
reformat
MatthewZhao26 Dec 13, 2021
63a67a1
Merge branch 'dev' into fast_tstest
sampan501 Dec 13, 2021
797a62f
added journal to citation (Fast 2-sample)
MatthewZhao26 Dec 13, 2021
1127f71
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 Dec 13, 2021
bf0d7d3
corrected docstring math format
MatthewZhao26 Dec 13, 2021
1ec4696
formatting update, helper docstring update
MatthewZhao26 Dec 15, 2021
e7c56d0
Merge branch 'dev' into fast_tstest
sampan501 Dec 17, 2021
8c0c98d
Formatting changes, renaming of helpers, moved random_state
MatthewZhao26 Dec 17, 2021
5d52c84
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 Dec 17, 2021
3adbf35
fix tabbing issue in ksample
MatthewZhao26 Dec 17, 2021
ea2a1f6
added __init__ variable explanantions to tutorial
MatthewZhao26 Dec 18, 2021
db653fe
minor formatting change to tutorial
MatthewZhao26 Dec 18, 2021
cfb6dcb
fix tutorial rendering issue
sampan501 Dec 20, 2021
dd9587c
Create FCIT.py
MatthewZhao26 Apr 7, 2022
b850800
Creating new conditional independence module
MatthewZhao26 Apr 11, 2022
7b8ceb7
Update cond'l ind module + bib
MatthewZhao26 Apr 12, 2022
71001f6
updated index.rst
MatthewZhao26 Apr 13, 2022
358a459
updated config.yml
MatthewZhao26 Apr 13, 2022
6acfd56
added tests, updated hyppo init
MatthewZhao26 Apr 18, 2022
19174a0
updated tests FCIT
MatthewZhao26 Apr 18, 2022
6c57668
Merge branch 'dev' into fast_tstest
MatthewZhao26 Apr 21, 2022
6b5e095
Update FCIT.py
MatthewZhao26 Apr 21, 2022
2ef7143
added FCIT documentation
MatthewZhao26 May 5, 2022
e8b719f
Merge branch 'dev' into pr/315
sampan501 May 5, 2022
a6f4d6b
black update, documentation update FCIT
MatthewZhao26 May 11, 2022
bff8d53
Merge branch 'fast_tstest' of https://github.com/MatthewZhao26/hyppo …
MatthewZhao26 May 11, 2022
f92697d
added tutorial
MatthewZhao26 May 13, 2022
d04742c
Update conditional.py
MatthewZhao26 May 13, 2022
dff6ceb
rst edits
MatthewZhao26 May 13, 2022
0dd9df6
minor doc changes
MatthewZhao26 May 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Create FCIT.py
  • Loading branch information
MatthewZhao26 committed Apr 7, 2022
commit dd9587c3928cb3a68a35512e6fab617bf4cf580a
175 changes: 175 additions & 0 deletions hyppo/independence/FCIT.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import time
import joblib

import numpy as np
from scipy.stats import ttest_1samp
from sklearn.metrics import mean_squared_error as mse
from sklearn.preprocessing import StandardScaler

from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import ShuffleSplit

from .base import IndependenceTest, IndependenceTestOutput


class FCIT(IndependenceTest):

def __init__(self, model, cv_grid, num_perm=8, prop_test=0.1, discrete=(False, False)):

self.model = model
self.cv_grid = cv_grid
self.num_perm = num_perm
self.prop_test = prop_test
self.discrete = discrete
IndependenceTest.__init__(self)


def statistic(self, x, y, z):

n_samples = x.shape[0]
n_test = int(n_samples * self.prop_test)

# Set up storage for true data and permuted data MSEs.
d0_stats = np.zeros(self.num_perm)
d1_stats = np.zeros(self.num_perm)
data_permutations = [np.random.permutation(x.shape[0]) for i in range(self.num_perm)]

########################################################
# Compute mses for y = f(x, z), varying train-test splits.
clf = cross_val(x, y, z, self.cv_grid, self.model, prop_test=self.prop_test)
datadict = {
'x': x,
'y': y,
'z': z,
'data_permutation': data_permutations,
'n_test': n_test,
'reshuffle': False,
'clf': clf,
}
d1_stats = np.array(joblib.Parallel(n_jobs=-1, max_nbytes=100e6)(
joblib.delayed(obtain_error)((datadict, i)) for i in range(self.num_perm)))

#####################################################################33
# Compute mses for y = f(x, reshuffle(z)), varying train-test splits.
if z.shape[1] == 0:
x_indep_y = x[np.random.permutation(n_samples)]
else:
x_indep_y = np.empty([x.shape[0], 0])

clf = cross_val(x_indep_y, y, z, self.cv_grid, self.model, prop_test=self.prop_test)

datadict['reshuffle'] = True
datadict['x'] = x_indep_y
d0_stats = np.array(joblib.Parallel(n_jobs=-1, max_nbytes=100e6)(
joblib.delayed(obtain_error)((datadict, i)) for i in range(self.num_perm)))


t, p_value = ttest_1samp(d0_stats / d1_stats, 1)
if t < 0:
p_value = 1 - p_value / 2
else:
p_value = p_value / 2

return t, p_value


def test(self, x, y, z):

###################################################
# Compute test set size.
n_samples = x.shape[0]
#n_test = int(n_samples * self.prop_test)

if z is None:
z = np.empty([n_samples, 0])

if self.discrete[0] and not self.discrete[1]:
# If x xor y is discrete, use the continuous variable as input.
x, y = y, x
elif x.shape[1] < y.shape[1]:
# Otherwise, predict the variable with fewer dimensions.
x, y = y, x

# Normalize y to make the decision tree stopping criterion meaningful.
y = StandardScaler().fit_transform(y)

stat, pvalue = self.statistic(x, y, z)

return IndependenceTestOutput(stat, pvalue)


def cross_val(x, y, z, cv_grid, model, prop_test):
""" Choose the best decision tree hyperparameters by
cross-validation. The hyperparameter to optimize is min_samples_split
(see sklearn's DecisionTreeRegressor).
Args:
x (n_samples, x_dim): Input data array.
y (n_samples, y_dim): Output data array.
z (n_samples, z_dim): Optional auxiliary input data.
cv_grid (list): List of hyperparameter values to try in CV.
regresor (sklearn classifier): Which regression model to use.
prop_test (float): Proportion of validation data to use.
Returns:
DecisionTreeRegressor with the best hyperparameter setting.
"""

splitter = ShuffleSplit(n_splits=3, test_size=prop_test)
cv = GridSearchCV(estimator=model, cv=splitter,
param_grid=cv_grid, n_jobs=-1)
cv.fit(interleave(x, z), y)

return type(model)(**cv.best_params_)


#for concat of x and z
def interleave(x, z, seed=None):
""" Interleave x and z dimension-wise.
Args:
x (n_samples, x_dim) array.
z (n_samples, z_dim) array.
Returns
An array of shape (n_samples, x_dim + z_dim) in which
the columns of x and z are interleaved at random.
"""
state = np.random.get_state()
np.random.seed(seed or int(time.time()))
total_ids = np.random.permutation(x.shape[1]+z.shape[1])
np.random.set_state(state)
out = np.zeros([x.shape[0], x.shape[1] + z.shape[1]])
out[:, total_ids[:x.shape[1]]] = x
out[:, total_ids[x.shape[1]:]] = z
return out


#computes MSE
def obtain_error(data_and_i):
"""
A function used for multithreaded computation of the fcit test statistic.
data['x']: First variable.
data['y']: Second variable.
data['z']: Conditioning variable.
data['data_permutation']: Permuted indices of the data.
data['perm_ids']: Permutation for the bootstrap.
data['n_test']: Number of test points.
data['clf']: Decision tree regressor.
"""
data, i = data_and_i
x = data['x']
y = data['y']
z = data['z']
if data['reshuffle']:
perm_ids = np.random.permutation(x.shape[0])
else:
perm_ids = np.arange(x.shape[0])
data_permutation = data['data_permutation'][i]
n_test = data['n_test']
clf = data['clf']

x_z = interleave(x[perm_ids], z, seed=i)

clf.fit(x_z[data_permutation][n_test:], y[data_permutation][n_test:])
return mse(y[data_permutation][:n_test],
clf.predict(x_z[data_permutation][:n_test]))