Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xunzheng committed Mar 15, 2020
1 parent 3997ab4 commit d6df4c5
Show file tree
Hide file tree
Showing 8 changed files with 15 additions and 8 deletions.
4 changes: 4 additions & 0 deletions bin/notears_nonlinear
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
#!/usr/bin/env python3
from notears import nonlinear, utils
import torch
import numpy as np
import argparse


def main(args):
torch.set_default_dtype(torch.double)
np.set_printoptions(precision=3)

X = np.loadtxt(args.X_path, delimiter=',')
n, d = X.shape
model = nonlinear.NotearsMLP(dims=[d, args.hidden, 1], bias=True)
Expand Down
Empty file added notears/__init__.py
Empty file.
8 changes: 6 additions & 2 deletions src/lbfgsb_scipy.py → notears/lbfgsb_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@


class LBFGSBScipy(torch.optim.Optimizer):
"""Wrap L-BFGS-B algorithm, using scipy routines."""
"""Wrap L-BFGS-B algorithm, using scipy routines.
Courtesy: Arthur Mensch's gist
https://gist.github.com/arthurmensch/c55ac413868550f89225a0b9212aa4cd
"""

def __init__(self, params):
defaults = dict()
Expand Down Expand Up @@ -81,7 +85,7 @@ def wrapped_closure(flat_params):

bounds = self._gather_flat_bounds()

# How can it work without getting the final solution..?
# Magic
sol = sopt.minimize(wrapped_closure,
initial_params,
method='L-BFGS-B',
Expand Down
2 changes: 1 addition & 1 deletion src/linear.py → notears/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _func(w):


if __name__ == '__main__':
import src.utils as ut
import notears.utils as ut
ut.set_random_seed(1)

n, d, s0, graph_type, sem_type = 100, 20, 20, 'ER', 'gauss'
Expand Down
File renamed without changes.
7 changes: 3 additions & 4 deletions src/nonlinear.py → notears/nonlinear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.locally_connected import LocallyConnected
from src.lbfgsb_scipy import LBFGSBScipy
from notears.locally_connected import LocallyConnected
from notears.lbfgsb_scipy import LBFGSBScipy
import torch
import torch.nn as nn
import numpy as np
Expand Down Expand Up @@ -211,9 +211,8 @@ def main():
torch.set_default_dtype(torch.double)
np.set_printoptions(precision=3)

import src.utils as ut
import notears.utils as ut
ut.set_random_seed(123)
torch.manual_seed(123)

n, d, s0, graph_type, sem_type = 200, 5, 9, 'ER', 'mim'
B_true = ut.simulate_dag(d, s0, graph_type)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@
scripts=['bin/notears_linear',
'bin/notears_nonlinear'],
packages=['notears'],
package_dir={'notears': 'src'},
package_dir={'notears': 'notears'},
install_requires=['numpy', 'scipy', 'python-igraph'],
)

0 comments on commit d6df4c5

Please sign in to comment.