-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add discriminability (one sample and two sample) (#27)
* Adding one sample test for discriminability * Adding discriminability doc * Adding discriminability * adding unit test * Bug fixing unit test * Bug fixing unit test * Bug fixing unit test * adding correction for the one sample test * Fixed issues with docs, unit tests and addressing PR review comments * Update discrimTwoSample.py * merging master * adding correction for coding formats * adding correction for final pr * adding correction for final pr * adding correction for final pr * Adding correction for code format using black * Adding correction for coding format * Adding correction for coding format * Adding correction for coding format * added numba for two sample test * correcting doc formats * correcting doc formats * correcting doc formats * correcting doc formats
- Loading branch information
Showing
14 changed files
with
999 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
Discriminability | ||
**************** | ||
|
||
.. currentmodule:: mgc.discrim | ||
|
||
Discriminability one sample test | ||
-------------------------------- | ||
.. autoclass:: DiscrimOneSample | ||
|
||
Discriminability two sample test | ||
-------------------------------- | ||
.. autoclass:: DiscrimTwoSample |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,8 +4,9 @@ Reference | |
********* | ||
|
||
.. toctree:: | ||
:maxdepth: 2 | ||
:maxdepth: 3 | ||
|
||
discrim | ||
independence | ||
ksample | ||
sims |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,5 +2,6 @@ | |
import mgc.ksample | ||
import mgc.time_series | ||
import mgc.sims | ||
import mgc.discrim | ||
|
||
__version__ = "0.0.1" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .discrim_one_samp import DiscrimOneSample | ||
from .discrim_two_samp import DiscrimTwoSample | ||
|
||
__all__ = ["DiscrimOneSample", "DiscrimTwoSample"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import warnings | ||
import numpy as np | ||
from .._utils import ( | ||
contains_nan, | ||
check_ndarray_xy, | ||
convert_xy_float64, | ||
check_reps, | ||
euclidean, | ||
) | ||
|
||
|
||
class _CheckInputs: | ||
"""Checks inputs for discriminability tests""" | ||
|
||
def __init__(self, x, y, reps=None, is_dist=False, remove_isolates=True): | ||
self.x = x | ||
self.y = y | ||
self.reps = reps | ||
self.is_distance = is_dist | ||
self.remove_isolates = remove_isolates | ||
|
||
def __call__(self): | ||
if len(self.x) > 1: | ||
if self.x[0].shape[0] != self.x[1].shape[0]: | ||
msg = "The input matrices do not have the same number of rows." | ||
raise ValueError(msg) | ||
|
||
tmp_ = [] | ||
for x1 in self.x: | ||
check_ndarray_xy(x1, self.y) | ||
contains_nan(x1) | ||
contains_nan(self.y) | ||
check_min_samples(x1) | ||
x1, self.y = convert_xy_float64(x1, self.y) | ||
tmp_.append(self._condition_input(x1)) | ||
|
||
self.x = tmp_ | ||
|
||
if self.reps: | ||
check_reps(self.reps) | ||
|
||
return self.x, self.y | ||
|
||
def _condition_input(self, x1): | ||
"""Checks whether there is only one subject and removes | ||
isolates and calculate distance.""" | ||
uniques, counts = np.unique(self.y, return_counts=True) | ||
|
||
if (counts != 1).sum() <= 1: | ||
msg = "You have passed a vector containing only a single unique sample id." | ||
raise ValueError(msg) | ||
|
||
if self.remove_isolates: | ||
idx = np.isin(self.y, uniques[counts != 1]) | ||
self.y = self.y[idx] | ||
|
||
x1 = np.asarray(x1) | ||
if not self.is_distance: | ||
x1 = x1[idx] | ||
else: | ||
x1 = x1[np.ix_(idx, idx)] | ||
|
||
if not self.is_distance: | ||
x1 = euclidean(x1) | ||
|
||
return x1 | ||
|
||
|
||
def check_min_samples(x1): | ||
"""Check if the number of samples is at least 3""" | ||
nx = x1.shape[0] | ||
|
||
if nx <= 10: | ||
raise ValueError("Number of samples is too low") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from abc import ABC, abstractmethod | ||
import numpy as np | ||
from .._utils import euclidean | ||
|
||
|
||
class DiscriminabilityTest(ABC): | ||
r""" | ||
A base class for a Discriminability test. | ||
""" | ||
|
||
def __init__(self): | ||
self.pvalue_ = None | ||
super().__init__() | ||
|
||
# @abstractmethod | ||
def _statistic(self, x, y): | ||
r""" | ||
Calulates the independence test statistic. | ||
Parameters | ||
---------- | ||
x, y : ndarray | ||
Input data matrices. | ||
""" | ||
|
||
rdfs = self._discr_rdf(x, y) | ||
stat = np.nanmean(rdfs) | ||
|
||
return stat | ||
|
||
def _discr_rdf(self, dissimilarities, labels): | ||
# calculates test statistics distribution | ||
rdfs = [] | ||
|
||
for i, label in enumerate(labels): | ||
di = dissimilarities[i] | ||
|
||
# All other samples except its own label | ||
idx = labels == label | ||
Dij = di[~idx] | ||
|
||
# All samples except itself | ||
idx[i] = False | ||
Dii = di[idx] | ||
|
||
rdf = [ | ||
1 - ((Dij < d).sum() + 0.5 * (Dij == d).sum()) / Dij.size for d in Dii | ||
] | ||
rdfs.append(rdf) | ||
|
||
out = np.full((len(rdfs), max(map(len, rdfs))), np.nan) | ||
for i, rdf in enumerate(rdfs): | ||
out[i, : len(rdf)] = rdf | ||
|
||
return out | ||
|
||
@abstractmethod | ||
def _perm_stat(self, index): | ||
r""" | ||
Helper function that is used to calculate parallel permuted test | ||
statistics. | ||
Parameters | ||
---------- | ||
index : int | ||
Iterator used for parallel statistic calculation | ||
Returns | ||
------- | ||
perm_stat : float | ||
Test statistic for each value in the null distribution. | ||
""" | ||
|
||
@abstractmethod | ||
def test(self): | ||
r""" | ||
Calculates the test statistic and p-value for Discriminability one sample | ||
and two sample test. | ||
""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
from ._utils import _CheckInputs | ||
import numpy as np | ||
import random | ||
from .base import DiscriminabilityTest | ||
from scipy._lib._util import MapWrapper | ||
|
||
|
||
class DiscrimOneSample(DiscriminabilityTest): | ||
r""" | ||
A class that performs a one-sample test for discriminability. | ||
Discriminability index is a measure of whether a data acquisition and | ||
preprocessing pipeline is more discriminable among different subjects. | ||
The key insight is that each measurement of the same item should be more | ||
similar to other measurements of that item, as compared to measurements | ||
of any other item. One sample test measures whether the discriminability | ||
for a dataset differs from random chance. More details can be described | ||
in [#1Dscr]_. | ||
Parameters | ||
---------- | ||
is_dist : bool, optional (default: False) | ||
whether `x` is a distance matrix or not. | ||
remove_isolates : bool, optional (default: True) | ||
whether to remove the measurements with single instance or not. | ||
See Also | ||
-------- | ||
DiscrimTwoSample : Two sample test for comparing the discriminability of two data | ||
Notes | ||
----- | ||
With :math:`D_X` as the sample discriminability of :math:`X`, | ||
one sample test verifies whether | ||
.. math:: | ||
H_0: D_X = D_0 | ||
and | ||
.. math:: | ||
H_A: D_X > D_0 | ||
where :math:`D_0` is the discriminability that would be observed by random chance. | ||
References | ||
---------- | ||
.. [#1Dscr] Eric W. Bridgeford, et al. "Optimal Decisions for Reference | ||
Pipelines and Datasets: Applications in Connectomics." Bioarxiv (2019). | ||
""" | ||
|
||
def __init__(self, is_dist=False, remove_isolates=True): | ||
# set is_distance to true if compute_distance is None | ||
self.is_distance = is_dist | ||
self.remove_isolates = remove_isolates | ||
DiscriminabilityTest.__init__(self) | ||
|
||
def _statistic(self, x, y): | ||
""" | ||
Helper function that calculates the discriminability test statistics. | ||
""" | ||
stat = super(DiscrimOneSample, self)._statistic(x, y) | ||
|
||
return stat | ||
|
||
def test(self, x, y, reps=1000, workers=-1): | ||
r""" | ||
Calculates the test statistic and p-value for Discriminability one sample test. | ||
Parameters | ||
---------- | ||
x: ndarray | ||
An `(n, d)` data matrix with `n` samples in `d` dimensions, | ||
if flag is_dist = Flase and an `(n, n)` distance matrix, | ||
if flag is_dist = True | ||
y : ndarray | ||
a vector containing the sample ids for our :math:`n` samples. | ||
reps : int, optional (default: 1000) | ||
The number of replications used to estimate the null distribution | ||
when using the permutation test used to calculate the p-value. | ||
workers : int, optional (default: -1) | ||
The number of cores to parallelize the p-value computation over. | ||
Supply -1 to use all cores available to the Process. | ||
Returns | ||
------- | ||
stat : float | ||
The computed discriminability statistic. | ||
pvalue : float | ||
The computed one sample test p-value. | ||
Examples | ||
-------- | ||
>>> import numpy as np | ||
>>> from mgc.discrim import DiscrimOneSample | ||
>>> x = np.concatenate((np.zeros((50,2)) ,np.ones((50,2))), axis=0) | ||
>>> y = np.concatenate((np.zeros(50),np.ones(50)), axis= 0) | ||
>>> stat, p = DiscrimOneSample().test(x,y) | ||
>>> '%.1f, %.2f' % (stat, p) | ||
'1.0, 0.00' | ||
""" | ||
|
||
check_input = _CheckInputs( | ||
[x], | ||
y, | ||
reps=reps, | ||
is_dist=self.is_distance, | ||
remove_isolates=self.remove_isolates, | ||
) | ||
x, y = check_input() | ||
|
||
self.x = np.asarray(x[0]) | ||
self.y = y | ||
|
||
stat = self._statistic(self.x, self.y) | ||
self.stat = stat | ||
|
||
# use all cores to create function that parallelizes over number of reps | ||
mapwrapper = MapWrapper(workers) | ||
null_dist = np.array(list(mapwrapper(self._perm_stat, range(reps)))) | ||
self.null_dist = null_dist | ||
|
||
# calculate p-value and significant permutation map through list | ||
pvalue = ((null_dist >= stat).sum()) / reps | ||
|
||
# correct for a p-value of 0. This is because, with bootstrapping | ||
# permutations, a p-value of 0 is incorrect | ||
if pvalue == 0: | ||
pvalue = 1 / reps | ||
|
||
self.pvalue_ = pvalue | ||
|
||
return stat, pvalue | ||
|
||
def _perm_stat(self, index): | ||
permy = np.random.permutation(self.y) | ||
|
||
perm_stat = self._statistic(self.x, permy) | ||
|
||
return perm_stat |
Oops, something went wrong.