diff --git a/bin/notears_nonlinear b/bin/notears_nonlinear index 9ba035f..2a00e81 100644 --- a/bin/notears_nonlinear +++ b/bin/notears_nonlinear @@ -1,10 +1,14 @@ #!/usr/bin/env python3 from notears import nonlinear, utils +import torch import numpy as np import argparse def main(args): + torch.set_default_dtype(torch.double) + np.set_printoptions(precision=3) + X = np.loadtxt(args.X_path, delimiter=',') n, d = X.shape model = nonlinear.NotearsMLP(dims=[d, args.hidden, 1], bias=True) diff --git a/notears/__init__.py b/notears/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/lbfgsb_scipy.py b/notears/lbfgsb_scipy.py similarity index 95% rename from src/lbfgsb_scipy.py rename to notears/lbfgsb_scipy.py index cc43d5f..eb3d9b4 100644 --- a/src/lbfgsb_scipy.py +++ b/notears/lbfgsb_scipy.py @@ -3,7 +3,11 @@ class LBFGSBScipy(torch.optim.Optimizer): - """Wrap L-BFGS-B algorithm, using scipy routines.""" + """Wrap L-BFGS-B algorithm, using scipy routines. + + Courtesy: Arthur Mensch's gist + https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd + """ def __init__(self, params): defaults = dict() @@ -81,7 +85,7 @@ def wrapped_closure(flat_params): bounds = self._gather_flat_bounds() - # How can it work without getting the final solution..? + # Magic sol = sopt.minimize(wrapped_closure, initial_params, method='L-BFGS-B', diff --git a/src/linear.py b/notears/linear.py similarity index 99% rename from src/linear.py rename to notears/linear.py index e1a38f3..6ccaea7 100644 --- a/src/linear.py +++ b/notears/linear.py @@ -84,7 +84,7 @@ def _func(w): if __name__ == '__main__': - import src.utils as ut + import notears.utils as ut ut.set_random_seed(1) n, d, s0, graph_type, sem_type = 100, 20, 20, 'ER', 'gauss' diff --git a/src/locally_connected.py b/notears/locally_connected.py similarity index 100% rename from src/locally_connected.py rename to notears/locally_connected.py diff --git a/src/nonlinear.py b/notears/nonlinear.py similarity index 98% rename from src/nonlinear.py rename to notears/nonlinear.py index 5fcbf06..188b439 100644 --- a/src/nonlinear.py +++ b/notears/nonlinear.py @@ -1,5 +1,5 @@ -from src.locally_connected import LocallyConnected -from src.lbfgsb_scipy import LBFGSBScipy +from notears.locally_connected import LocallyConnected +from notears.lbfgsb_scipy import LBFGSBScipy import torch import torch.nn as nn import numpy as np @@ -211,9 +211,8 @@ def main(): torch.set_default_dtype(torch.double) np.set_printoptions(precision=3) - import src.utils as ut + import notears.utils as ut ut.set_random_seed(123) - torch.manual_seed(123) n, d, s0, graph_type, sem_type = 200, 5, 9, 'ER', 'mim' B_true = ut.simulate_dag(d, s0, graph_type) diff --git a/src/utils.py b/notears/utils.py similarity index 100% rename from src/utils.py rename to notears/utils.py diff --git a/setup.py b/setup.py index 9909ce0..aef3628 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,6 @@ scripts=['bin/notears_linear', 'bin/notears_nonlinear'], packages=['notears'], - package_dir={'notears': 'src'}, + package_dir={'notears': 'notears'}, install_requires=['numpy', 'scipy', 'python-igraph'], )