Commit: showing 8 changed files with 551 additions and 9 deletions.

@@ -0,0 +1,30 @@
#!/usr/bin/env python3
from notears import nonlinear, utils
import numpy as np
import argparse


def main(args):
    X = np.loadtxt(args.X_path, delimiter=',')
    n, d = X.shape
    model = nonlinear.NotearsMLP(dims=[d, args.hidden, 1], bias=True)
    W_est = nonlinear.notears_nonlinear(model, X, lambda1=args.lambda1, lambda2=args.lambda2)
    assert utils.is_dag(W_est)
    np.savetxt(args.W_path, W_est, delimiter=',')


def parse_args():
    parser = argparse.ArgumentParser(description='Run NOTEARS algorithm')
    parser.add_argument('X_path', type=str, help='n by p data matrix in csv format')
    parser.add_argument('--hidden', type=int, default=10, help='Number of hidden units')
    parser.add_argument('--lambda1', type=float, default=0.01, help='L1 regularization parameter')
    parser.add_argument('--lambda2', type=float, default=0.01, help='L2 regularization parameter')
    parser.add_argument('--W_path', type=str, default='W_est.csv', help='p by p weighted adjacency matrix of estimated DAG in csv format')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    main(args)
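
As a quick smoke test of the new script, a minimal sketch that generates a synthetic nonlinear SEM dataset and writes it to CSV. It assumes the set_random_seed, simulate_dag, and simulate_nonlinear_sem helpers in notears.utils keep their usual signatures; the file names and hyperparameter values are illustrative only.

import numpy as np
from notears import utils

utils.set_random_seed(1)
n, d, s0 = 200, 5, 9
B_true = utils.simulate_dag(d, s0, 'ER')            # random Erdos-Renyi DAG
X = utils.simulate_nonlinear_sem(B_true, n, 'mlp')  # samples from a nonlinear SEM
np.savetxt('X.csv', X, delimiter=',')
# once the package is installed, run the script above on it, e.g.:
#   notears_nonlinear X.csv --lambda1 0.01 --lambda2 0.01 --W_path W_est.csv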

@@ -3,12 +3,12 @@

 setup(
     name='notears',
-    version='2.1',
+    version='3.0',
     description='Implementation of the NOTEARS algorithm',
     author='Xun Zheng',
     author_email='[email protected]',
     url='https://github.com/xunzheng/notears',
-    download_url='https://github.com/xunzheng/notears/archive/v2.1.zip',
+    download_url='https://github.com/xunzheng/notears/archive/v3.0.zip',
     license='Apache License 2.0',
     keywords='notears causal discovery bayesian network structure learning',
     classifiers=[

@@ -17,7 +17,8 @@
         'License :: OSI Approved :: Apache Software License',
         'Topic :: Scientific/Engineering :: Artificial Intelligence',
     ],
-    scripts=['bin/notears_linear_l1'],
+    scripts=['bin/notears_linear',
+             'bin/notears_nonlinear'],
     packages=['notears'],
     package_dir={'notears': 'src'},
     install_requires=['numpy', 'scipy', 'python-igraph'],
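
Since setuptools copies everything listed under scripts= into the environment's bin directory, installing this version (e.g. pip install . from the repository root) should put both notears_linear and notears_nonlinear on the PATH, replacing the old notears_linear_l1 entry point.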

@@ -0,0 +1,127 @@
import torch
import scipy.optimize as sopt


class LBFGSBScipy(torch.optim.Optimizer):
    """Wrap L-BFGS-B algorithm, using scipy routines."""

    def __init__(self, params):
        defaults = dict()
        super(LBFGSBScipy, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("LBFGSBScipy doesn't support per-parameter options"
                             " (parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel = sum([p.numel() for p in self._params])

    def _gather_flat_grad(self):
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _gather_flat_bounds(self):
        bounds = []
        for p in self._params:
            if hasattr(p, 'bounds'):
                b = p.bounds
            else:
                b = [(None, None)] * p.numel()
            bounds += b
        return bounds

    def _gather_flat_params(self):
        views = []
        for p in self._params:
            if p.data.is_sparse:
                view = p.data.to_dense().view(-1)
            else:
                view = p.data.view(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _distribute_flat_params(self, params):
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view_as to avoid deprecated pointwise semantics
            p.data = params[offset:offset + numel].view_as(p.data)
            offset += numel
        assert offset == self._numel

    def step(self, closure):
        """Performs a single optimization step.

        Arguments:
            closure (callable): A closure that reevaluates the model
                and returns the loss.
        """
        assert len(self.param_groups) == 1

        def wrapped_closure(flat_params):
            """closure must call zero_grad() and backward()"""
            flat_params = torch.from_numpy(flat_params)
            flat_params = flat_params.to(torch.get_default_dtype())
            self._distribute_flat_params(flat_params)
            loss = closure()
            loss = loss.item()
            flat_grad = self._gather_flat_grad().cpu().detach().numpy()
            return loss, flat_grad.astype('float64')

        initial_params = self._gather_flat_params()
        initial_params = initial_params.cpu().detach().numpy()

        bounds = self._gather_flat_bounds()
        # scipy returns the final iterate in sol.x; copy it back into the
        # model parameters below, since the last wrapped_closure call is not
        # guaranteed to have been evaluated at the solution.
        sol = sopt.minimize(wrapped_closure,
                            initial_params,
                            method='L-BFGS-B',
                            jac=True,
                            bounds=bounds)

        final_params = torch.from_numpy(sol.x)
        final_params = final_params.to(torch.get_default_dtype())
        self._distribute_flat_params(final_params)


def main():
    import torch.nn as nn
    # torch.set_default_dtype(torch.double)

    n, d, out, j = 10000, 3000, 10, 0
    input = torch.randn(n, d)
    w_true = torch.rand(d, out)
    w_true[j, :] = 0
    target = torch.matmul(input, w_true)
    linear = nn.Linear(d, out)
    linear.weight.bounds = [(0, None)] * d * out  # hack
    for m in range(out):
        linear.weight.bounds[m * d + j] = (0, 0)
    criterion = nn.MSELoss()
    optimizer = LBFGSBScipy(linear.parameters())
    print(list(linear.parameters()))

    def closure():
        optimizer.zero_grad()
        output = linear(input)
        loss = criterion(output, target)
        print('loss:', loss.item())
        loss.backward()
        return loss
    optimizer.step(closure)
    print(list(linear.parameters()))
    print(w_true.t())


if __name__ == '__main__':
    main()
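
Beyond the large demo in main(), the class relies on an informal contract: any parameter tensor may carry a bounds attribute holding one (low, high) tuple per element in flattened (row-major) order, and parameters without it are treated as unbounded. A minimal sketch of that contract on a toy problem (the variable names and target values are illustrative, not part of the API):

import torch
import torch.nn as nn

# fit w to a target under elementwise nonnegativity constraints
w = nn.Parameter(torch.randn(4))
w.bounds = [(0.0, None)] * w.numel()  # one (low, high) pair per element
target = torch.tensor([1.0, -2.0, 3.0, -4.0])

opt = LBFGSBScipy([w])

def closure():
    opt.zero_grad()
    loss = ((w - target) ** 2).sum()
    loss.backward()
    return loss

opt.step(closure)
print(w)  # entries with negative targets end up pinned at the 0 lower bound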

@@ -0,0 +1,92 @@
import torch
import torch.nn as nn
import math

class LocallyConnected(nn.Module):
    """Local linear layer, i.e. Conv1dLocal() with filter size 1.

    Args:
        num_linear: num of local linear layers, i.e. d
        input_features: m1
        output_features: m2
        bias: whether to include bias or not

    Shape:
        - Input: [n, d, m1]
        - Output: [n, d, m2]

    Attributes:
        weight: [d, m1, m2]
        bias: [d, m2]
    """

    def __init__(self, num_linear, input_features, output_features, bias=True):
        super(LocallyConnected, self).__init__()
        self.num_linear = num_linear
        self.input_features = input_features
        self.output_features = output_features

        self.weight = nn.Parameter(torch.Tensor(num_linear,
                                                input_features,
                                                output_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(num_linear, output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        self.reset_parameters()

    @torch.no_grad()
    def reset_parameters(self):
        k = 1.0 / self.input_features
        bound = math.sqrt(k)
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor):
        # [n, d, 1, m2] = [n, d, 1, m1] @ [1, d, m1, m2]
        out = torch.matmul(input.unsqueeze(dim=2), self.weight.unsqueeze(dim=0))
        out = out.squeeze(dim=2)
        if self.bias is not None:
            # [n, d, m2] += [d, m2]
            out += self.bias
        return out

    def extra_repr(self):
        # (Optional) Set extra information about this module. You can test
        # it by printing an object of this class.
        return 'num_linear={}, input_features={}, output_features={}, bias={}'.format(
            self.num_linear, self.input_features, self.output_features,
            self.bias is not None
        )

def main():
    n, d, m1, m2 = 2, 3, 5, 7

    # numpy
    import numpy as np
    input_numpy = np.random.randn(n, d, m1)
    weight = np.random.randn(d, m1, m2)
    output_numpy = np.zeros([n, d, m2])
    for j in range(d):
        # [n, m2] = [n, m1] @ [m1, m2]
        output_numpy[:, j, :] = input_numpy[:, j, :] @ weight[j, :, :]

    # torch
    torch.set_default_dtype(torch.double)
    input_torch = torch.from_numpy(input_numpy)
    locally_connected = LocallyConnected(d, m1, m2, bias=False)
    locally_connected.weight.data[:] = torch.from_numpy(weight)
    output_torch = locally_connected(input_torch)

    # compare
    print(torch.allclose(output_torch, torch.from_numpy(output_numpy)))


if __name__ == '__main__':
    main()
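
The broadcasted matmul in forward() is a per-feature batched product; as a sanity check, it agrees with the equivalent einsum formulation (a standalone sketch, independent of the class):

import torch

n, d, m1, m2 = 2, 3, 5, 7
x = torch.randn(n, d, m1)
w = torch.randn(d, m1, m2)

out_matmul = torch.matmul(x.unsqueeze(2), w.unsqueeze(0)).squeeze(2)
out_einsum = torch.einsum('ndi,dio->ndo', x, w)
print(torch.allclose(out_matmul, out_einsum))  # True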