This is an implementation of the following paper:
[1] Bello K., Aragam B., Ravikumar P. (2022). DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization. NeurIPS'22.
If you find this code useful, please consider citing:
```bibtex
@inproceedings{bello2022dagma,
    author = {Bello, Kevin and Aragam, Bryon and Ravikumar, Pradeep},
    booktitle = {Advances in Neural Information Processing Systems},
    title = {{DAGMA: Learning DAGs via M-matrices and a Log-Determinant Acyclicity Characterization}},
    year = {2022}
}
```
We propose a new acyclicity characterization of DAGs via a log-det function for learning DAGs from observational data. Like previously proposed acyclicity functions (e.g., NOTEARS), our characterization is exact and differentiable. However, compared to existing characterizations, our log-det function: (1) is better at detecting large cycles; (2) has better-behaved gradients; and (3) runs about an order of magnitude faster in practice. These advantages of our log-det formulation, together with a path-following scheme, lead to significant improvements in structure recovery accuracy (e.g., SHD).
Let $W \in \mathbb{R}^{d \times d}$ be the weighted adjacency matrix of a graph on $d$ nodes. For a scalar $s > 0$, the log-det function takes the form:

$$h^{s}(W) = -\log \det (sI - W \circ W) + d \log s,$$

where $I$ is the $d \times d$ identity matrix and $\circ$ denotes the element-wise (Hadamard) product. Whenever $sI - W \circ W$ is an M-matrix, we have $h^{s}(W) \geq 0$, with $h^{s}(W) = 0$ if and only if $W$ is the weighted adjacency matrix of a DAG.
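As a concrete reference, here is a minimal NumPy sketch of this function (the helper name `h_logdet` is ours, not part of the repo; `slogdet` is used for numerical stability):

```python
import numpy as np

def h_logdet(W: np.ndarray, s: float = 1.0) -> float:
    """Log-det acyclicity function h^s(W) = -log det(sI - W∘W) + d log s."""
    d = W.shape[0]
    M = s * np.eye(d) - W * W          # sI - W∘W (Hadamard product)
    sign, logabsdet = np.linalg.slogdet(M)
    assert sign > 0, "sI - W∘W must have positive determinant (M-matrix domain)"
    return -logabsdet + d * np.log(s)

# h is zero exactly on DAGs: a single edge 1 -> 2 gives h = 0,
# while a 2-cycle gives h > 0.
W_dag = np.array([[0.0, 0.5], [0.0, 0.0]])
W_cyc = np.array([[0.0, 0.5], [0.5, 0.0]])
print(h_logdet(W_dag))   # 0.0
print(h_logdet(W_cyc))   # > 0
```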
Given the exact differentiable characterization of a DAG, we are interested in solving the following optimization problem:

$$\min_{W} \; Q(W; \mathbf{X}) \quad \text{subject to} \quad h^{s}(W) = 0,$$

where $Q$ is a score function (e.g., least squares or logistic loss) that measures how well $W$ fits the data $\mathbf{X}$. To solve this constrained problem, we follow a path-following scheme and solve a sequence of unconstrained problems:

$$\hat{W}^{(t+1)} = \arg\min_{W} \; \mu^{(t)} Q(W; \mathbf{X}) + h^{s}(W),$$

where $\mu^{(t)} \to 0$ as $t$ increases.
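For intuition, the sketch below implements this path-following loop for the linear least-squares score using plain gradient descent. It is a didactic simplification (no L1 penalty, fixed step size, and the function name `fit_dagma_sketch` is ours), not the optimizer implemented in `dagma_linear.py`:

```python
import numpy as np

def fit_dagma_sketch(X, s=1.0, mu=1.0, mu_factor=0.1, n_outer=5,
                     n_inner=2000, lr=3e-4):
    """Path-following sketch: minimize mu * Q(W;X) + h^s(W) for a
    decreasing sequence of mu values, warm-starting each solve."""
    n, d = X.shape
    W = np.zeros((d, d))               # start inside the M-matrix domain
    for _ in range(n_outer):
        for _ in range(n_inner):
            M = s * np.eye(d) - W * W
            # gradient of Q(W;X) = (1/2n) ||X - XW||_F^2
            grad_Q = -X.T @ (X - X @ W) / n
            # gradient of h^s(W) = 2 (sI - W∘W)^{-T} ∘ W
            grad_h = 2 * np.linalg.inv(M).T * W
            W -= lr * (mu * grad_Q + grad_h)
        mu *= mu_factor                # anneal mu -> 0
    return W
```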
Let us give an illustration of how DAGMA works in a two-node graph (see Figure 1 in [1] for more details). Here $Q$ is the population least squares loss. Below, four plots (omitted here; see Figure 1 in [1]) each illustrate the solution to an unconstrained problem for a different value of $\mu$.
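To reproduce the flavor of that figure numerically, one can evaluate the unconstrained objective on a grid over the two edge weights. This small sketch (our own, with an arbitrary ground-truth edge $1 \to 2$ of weight 0.7) locates the minimizer for a few values of $\mu$; as $\mu$ shrinks, the reverse edge $w_{21}$ is driven to zero:

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
x1 = rng.normal(size=n)
x2 = 0.7 * x1 + rng.normal(size=n)   # ground truth: 1 -> 2 with weight 0.7
X = np.column_stack([x1, x2])

def objective(a, b, mu, s=1.0):
    """mu * Q(W;X) + h^s(W) for the two-node W = [[0, a], [b, 0]]."""
    W = np.array([[0.0, a], [b, 0.0]])
    M = s * np.eye(2) - W * W
    if np.linalg.det(M) <= 0:        # outside the M-matrix domain
        return np.inf
    Q = 0.5 / n * np.linalg.norm(X - X @ W) ** 2
    h = -np.log(np.linalg.det(M))    # + d*log(s), which is 0 for s = 1
    return mu * Q + h

grid = np.linspace(-0.99, 0.99, 199)
for mu in [1.0, 0.1, 0.01]:
    vals = np.array([[objective(a, b, mu) for b in grid] for a in grid])
    i, j = np.unravel_index(vals.argmin(), vals.shape)
    print(f"mu={mu:5.2f}: argmin (w12, w21) ~ ({grid[i]:+.2f}, {grid[j]:+.2f})")
```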
- Python 3.6+
- `numpy`
- `scipy`
- `python-igraph`
- `torch`: only used for nonlinear models.
- `dagma_linear.py`: implementation of DAGMA for linear models with L1 regularization (supports L2 and logistic losses).
- `dagma_nonlinear.py`: implementation of DAGMA for nonlinear models using an MLP.
- `locally_connected.py`: special layer structure used by the MLP.
- `utils.py`: graph simulation, data simulation, and accuracy evaluation.
Use `requirements.txt` to install the dependencies (we recommend using a virtualenv or conda environment).
The simplest way to try out DAGMA is to run a simple example:
```bash
$ git clone https://github.com/kevinsbello/dagma.git
$ cd dagma/
$ pip3 install -r requirements.txt
$ python3 dagma_linear.py
```
The above runs the L1-regularized DAGMA on a randomly generated 20-node Erdos-Renyi graph with 500 samples. Within a few seconds, you should see an output like this:
```
{'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}
```
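Beyond the demo script, the pieces can also be combined programmatically. The sketch below is an assumption-laden illustration: it presumes `utils.py` exposes the NOTEARS-style helpers `set_random_seed`, `simulate_dag`, `simulate_parameter`, `simulate_linear_sem`, and `count_accuracy`, and that `dagma_linear.py` provides a `DAGMA_linear` class with a `fit` method; check the scripts for the exact interface.

```python
import numpy as np
import utils
from dagma_linear import DAGMA_linear  # class name assumed; see the script

# Simulate a 20-node ER graph with 20 edges and 500 Gaussian samples
# (helper names follow the NOTEARS-style utils.py described above).
utils.set_random_seed(1)
B_true = utils.simulate_dag(d=20, s0=20, graph_type='ER')
W_true = utils.simulate_parameter(B_true)
X = utils.simulate_linear_sem(W_true, n=500, sem_type='gauss')

# Fit linear DAGMA with an L2 loss and L1 regularization.
model = DAGMA_linear(loss_type='l2')
W_est = model.fit(X, lambda1=0.02)

# Threshold small weights and compare against the ground truth.
acc = utils.count_accuracy(B_true, np.abs(W_est) > 0.3)
print(acc)  # e.g. {'fdr': 0.0, 'tpr': 1.0, 'fpr': 0.0, 'shd': 0, 'nnz': 20}
```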
The data, ground truth graph, and the estimate will be stored in `X.csv`, `W_true.csv`, and `W_est.csv`, respectively.
We thank the authors of the NOTEARS repo for making their code available. Part of our code is based on their implementation, especially the `utils.py` file and some of the code for their nonlinear models.