-
Notifications
You must be signed in to change notification settings - Fork 10
/
RCIT.R
161 lines (121 loc) · 4.33 KB
/
RCIT.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#' RCIT and RCoT - tests whether x and y are conditionally independent given z. Calls RIT if z is empty.
#'
#' Columns of \code{z} with zero standard deviation are dropped before testing;
#' if nothing remains, RIT is called instead. If \code{x} or \code{y} has zero
#' standard deviation, the result is \code{list(p = 1, Sta = 0)}.
#'
#' @param x Random variable x.
#' @param y Random variable y.
#' @param z Random variable z.
#' @param approx Method for approximating the null distribution. Options include:
#' "lpd4" for the Lindsay-Pilla-Basak method (default; falls back to "hbe" if it
#' fails or yields NaN),
#' "gamma" for the Satterthwaite-Welch method,
#' "hbe" for the Hall-Buckley-Eagleson method,
#' "chi2" for a normalized chi-squared statistic,
#' "perm" for permutation testing (warning: this one is slow but recommended for small samples generally <500 ).
#' When \code{num_f2} is 1, "hbe" is always used regardless of this argument.
#' Any other value raises an error.
#' @param num_f Number of features for conditioning set. Default is 100.
#' @param num_f2 Number of features for non-conditioning sets. Default is 5.
#' @param seed The seed for controlling random number generation. Use if you want to replicate results exactly. Default is NULL.
#' @return A list containing the p-value \code{p} and statistic \code{Sta}
#' @export
#' @examples
#' RCIT(rnorm(1000),rnorm(1000),rnorm(1000));
#'
#' x=rnorm(10000);
#' y=(x+rnorm(10000))^2;
#' z=rnorm(10000);
#' RCIT(x,y,z,seed=2);
RCIT <- function(x, y, z = NULL, approx = "lpd4", num_f = 100, num_f2 = 5, seed = NULL) {
  # Empty conditioning set: defer to the unconditional test.
  if (length(z) == 0) {
    return(RIT(x, y, approx = approx, seed = seed))
  }

  # NOTE(review): matrix2 is defined elsewhere; presumably it coerces its
  # input to a matrix -- confirm against its definition.
  x <- matrix2(x)
  y <- matrix2(y)
  z <- matrix2(z)

  # Keep only the conditioning columns that actually vary.
  z <- z[, apply(z, 2, sd) > 0]
  z <- matrix2(z)
  d <- ncol(z)

  if (length(z) == 0) {
    # Every conditioning column was constant.
    return(RIT(x, y, approx = approx, seed = seed))
  } else if (sd(x) == 0 || sd(y) == 0) {
    # A constant x or y: report p = 1 with a zero statistic.
    return(list(p = 1, Sta = 0))
  }

  r <- nrow(x)
  # The kernel widths below are medians of pairwise distances over at most
  # the first 500 rows.
  r1 <- if (r > 500) 500 else r

  x <- normalize(x)
  y <- normalize(y)
  z <- normalize(z)

  # Disabled alternative preprocessing, kept for reference:
  #for (t in seq_len(ncol(x))){
  #  x[,t] = pnorm(ecdf(x[,t])(x[,t]));
  #}
  #for (t in seq_len(ncol(y))){
  #  y[,t] = pnorm(ecdf(y[,t])(y[,t]));
  #}
  #for (t in seq_len(d)){
  #  z[,t] = pnorm(ecdf(z[,t])(z[,t]));
  #}

  # Augment y with the conditioning variables before computing its features.
  y <- cbind(y, z)

  # The same seed is passed to all three feature draws.
  four_z <- random_fourier_features(
    z[, 1:d],
    num_f = num_f, sigma = median(c(t(dist(z[1:r1, ])))), seed = seed
  )
  four_x <- random_fourier_features(
    x,
    num_f = num_f2, sigma = median(c(t(dist(x[1:r1, ])))), seed = seed
  )
  four_y <- random_fourier_features(
    y,
    num_f = num_f2, sigma = median(c(t(dist(y[1:r1, ])))), seed = seed
  )

  f_x <- normalize(four_x$feat)
  f_y <- normalize(four_y$feat)
  f_z <- normalize(four_z$feat)

  Cxy <- cov(f_x, f_y)
  Czz <- cov(f_z)
  # Small ridge term added to the diagonal before the Cholesky-based inverse.
  i_Czz <- chol2inv(chol(Czz + diag(num_f) * 1e-10))
  # i_Czz = ginv(Czz+diag(num_f)*1E-10); #requires library(MASS)
  Cxz <- cov(f_x, f_z)
  Czy <- cov(f_z, f_y)

  # Residuals of the x and y features after linear regression on the z features.
  z_i_Czz <- f_z %*% i_Czz
  e_x_z <- z_i_Czz %*% t(Cxz)
  e_y_z <- z_i_Czz %*% Czy
  res_x <- f_x - e_x_z
  res_y <- f_y - e_y_z

  # With a single feature per side, the method is forced to "hbe".
  if (num_f2 == 1) {
    approx <- "hbe"
  }

  # approximate null distributions
  if (approx == "perm") {
    Cxy_z <- cov(res_x, res_y)
    Sta <- r * sum(Cxy_z^2)
    nperm <- 1000
    # Permute the rows of the x residuals and recompute the statistic;
    # collected without growing a vector inside a loop.
    Stas <- unlist(lapply(seq_len(nperm), function(ps) {
      perm <- sample.int(r)
      Sta_perm(res_x[perm, ], res_y, r)
    }))
    p <- 1 - (sum(Sta >= Stas) / length(Stas))
  } else {
    Cxy_z <- Cxy - Cxz %*% i_Czz %*% Czy # less accurate for permutation testing
    Sta <- r * sum(Cxy_z^2)

    # All (x-feature, y-feature) column index pairs.
    pair_idx <- expand.grid(seq_len(ncol(f_x)), seq_len(ncol(f_y)))
    res <- res_x[, pair_idx[, 1]] * res_y[, pair_idx[, 2]]
    Cov <- 1 / r * (t(res) %*% res)

    if (approx == "chi2") {
      i_Cov <- ginv(Cov)
      Sta <- r * (c(Cxy_z) %*% i_Cov %*% c(Cxy_z))
      p <- 1 - pchisq(Sta, length(c(Cxy_z)))
    } else {
      eig_d <- eigen(Cov, symmetric = TRUE)
      # Only strictly positive eigenvalues are passed to the approximations.
      eig_vals <- eig_d$values[eig_d$values > 0]

      if (approx == "gamma") {
        p <- 1 - sw(eig_vals, Sta)
      } else if (approx == "hbe") {
        p <- 1 - hbe(eig_vals, Sta)
      } else if (approx == "lpd4") {
        p <- try(1 - lpb4(eig_vals, Sta), silent = TRUE)
        # Fall back to hbe when lpb4 errors (p is a try-error) or gives NaN.
        if (!is.numeric(p) || is.nan(p)) {
          p <- 1 - hbe(eig_vals, Sta)
        }
      } else {
        # Previously this surfaced later as "object 'p' not found".
        stop("Unknown approx method: ", approx, call. = FALSE)
      }
    }
  }

  # Clamp negative p-values to zero.
  if (p < 0) {
    p <- 0
  }

  list(p = p, Sta = Sta)
}