Self-Explaining Neural Networks
- python (>3.0)
- pytorch (=0.4.0)
- numpy
- matplotlib
- nltk (needed only for text applications)
- torchtext (needed only for text applications)
- shapely
- squarify
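Before installing, it can help to see which of the packages listed above are already importable. The following is a minimal sketch, not part of the repository. It assumes the usual import names (`torch` for pytorch, with the other packages sharing their pip names), and it does not check the PyTorch version.

```python
import importlib.util
import sys

# Import names for the packages listed above. Assumption: "pytorch" is
# imported as "torch"; the other packages share their pip names.
REQUIRED = ["torch", "numpy", "matplotlib", "shapely", "squarify"]
TEXT_ONLY = ["nltk", "torchtext"]  # needed only for the text applications


def missing_modules(names):
    """Return the names from `names` that cannot be found on the import path.

    find_spec only locates a module; it does not import it.
    """
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    print("Missing (required):", missing_modules(REQUIRED) or "none")
    print("Missing (text only):", missing_modules(TEXT_ONLY) or "none")
```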
It's highly recommended that the following steps be done inside a virtual environment (e.g., via virtualenv or anaconda).
First, install PyTorch.
Then install the remaining dependencies:
pip3 install -r requirements.txt
Finally, install this package:
git clone git@github.com:dmelis/SENN.git
cd SENN
pip3 install ./
To train models from scratch:
python scripts/main_mnist.py --train
To use pretrained models:
python scripts/main_mnist.py
- aggregators.py - defines the aggregation functions (g), which combine the concepts h(x) and parameters theta(x) into the model output
- conceptizers.py - defines the functions that encode inputs into concepts (h(x))
- parametrizers.py - defines the functions that generate parameters from inputs (theta(x))
- trainers.py - objectives, losses and training utilities
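To illustrate how the three model components fit together, here is a minimal pure-Python sketch of a forward pass, assuming a sum aggregator: the score for class j is the sum over concepts i of theta(x)[i][j] * h(x)[i]. The functions `conceptizer`, `parametrizer`, `sum_aggregator` and `senn_forward` are hypothetical toy stand-ins, not the classes defined in this repository, whose components are neural networks.

```python
# Hypothetical toy components; the repository's real modules are PyTorch networks.


def conceptizer(x):
    """h(x): here simply the identity, i.e. each input feature is one concept."""
    return list(x)


def parametrizer(x, weights):
    """theta(x): a k-by-c matrix (k concepts, c classes) that depends on x.

    Toy choice for illustration: theta[i][j] = weights[i][j] * (1 + x[i]).
    """
    return [[w * (1.0 + xi) for w in row] for xi, row in zip(x, weights)]


def sum_aggregator(theta, h):
    """g: the score for class j is sum_i theta[i][j] * h[i]."""
    num_classes = len(theta[0])
    return [sum(theta[i][j] * h[i] for i in range(len(h))) for j in range(num_classes)]


def senn_forward(x, weights):
    """Return the class scores together with the explanation (theta, h)."""
    h = conceptizer(x)
    theta = parametrizer(x, weights)
    return sum_aggregator(theta, h), theta, h


x = [1.0, 2.0]
weights = [[1.0, 0.0], [0.5, -1.0]]  # k = 2 concepts, c = 2 classes
scores, theta, h = senn_forward(x, weights)
print(scores)  # [5.0, -6.0]
```

Because the sum aggregator is linear in h(x) for a fixed theta(x), the entries of theta(x) can be read directly as the relevance of each concept to each class score.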