-
Notifications
You must be signed in to change notification settings - Fork 35
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
13 changed files
with
1,049 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
pip-wheel-metadata/ | ||
share/python-wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
MANIFEST | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.nox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
*.py,cover | ||
.hypothesis/ | ||
.pytest_cache/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
db.sqlite3 | ||
db.sqlite3-journal | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# IPython | ||
profile_default/ | ||
ipython_config.py | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# pipenv | ||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
# However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
# having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
# install all needed dependencies. | ||
#Pipfile.lock | ||
|
||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow | ||
__pypackages__/ | ||
|
||
# Celery stuff | ||
celerybeat-schedule | ||
celerybeat.pid | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# Environments | ||
.env | ||
.venv | ||
env/ | ||
venv/ | ||
ENV/ | ||
env.bak/ | ||
venv.bak/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# mypy | ||
.mypy_cache/ | ||
.dmypy.json | ||
dmypy.json | ||
|
||
# Pyre type checker | ||
.pyre/ | ||
.idea |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# DictionaryLookup Benchmark | ||
|
||
This repository can be used to reproduce the experiments of | ||
Section 4.1 in the paper, for the "DictionaryLookup" problem. | ||
|
||
|
||
# The DictionaryLookup problem | ||
![alt text](./images/fig2.png "Figure 2 from the paper") | ||
|
||
## Requirements | ||
|
||
### Dependencies | ||
This project is based on PyTorch 1.7.1 and the [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) library. | ||
First, install PyTorch from the official website: [https://pytorch.org/](https://pytorch.org/). | ||
PyTorch Geometric requires manual installation, and we thus recommend following the instructions at [https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).
|
||
The `requirements.txt` lists the additional requirements. | ||
|
||
Finally, run the following to verify that all dependencies are satisfied:
```setup | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Reproducing Experiments | ||
|
||
To run a single experiment from the paper, run: | ||
|
||
``` | ||
python main.py --help | ||
``` | ||
to see the available flags.
For example, to train a GATv2 with size=10 and num_heads=1, run: | ||
``` | ||
python main.py --task DICTIONARY --size 10 --num_heads 1 --type GAT2 --eval_every 10 | ||
``` | ||
|
||
Alternatively, to train a GAT with size=10 and num_heads=8, run: | ||
``` | ||
python main.py --task DICTIONARY --size 10 --num_heads 8 --type GAT --eval_every 10 | ||
``` | ||
|
||
## Experiment with other GNN types | ||
To experiment with other GNN types: | ||
* Add the new GNN type to the `GNN_TYPE` enum [here](common.py#L30), for example: `MY_NEW_TYPE = auto()` | ||
* Add another `elif self is GNN_TYPE.MY_NEW_TYPE:` to instantiate the new GNN type object [here](common.py#L57) | ||
* Use the new type as a flag for the `main.py` file: | ||
``` | ||
python main.py --type MY_NEW_TYPE ... | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from enum import Enum, auto | ||
|
||
from tasks.dictionary_lookup import DictionaryLookupDataset | ||
from gnns.gat2 import GAT2Conv | ||
|
||
from torch import nn | ||
from torch_geometric.nn import GCNConv, GINConv, GATConv | ||
|
||
class Task(Enum):
    """Benchmark tasks selectable via the --task flag.

    Currently only the DictionaryLookup problem from Section 4.1 of the
    paper is implemented.
    """
    DICTIONARY = auto()

    @staticmethod
    def from_string(s):
        """Parse a task name (e.g. 'DICTIONARY') into a Task member.

        Raises:
            ValueError: if *s* is not the name of a Task member.
        """
        try:
            return Task[s]
        except KeyError:
            raise ValueError(f'Unknown task: {s}') from None

    def get_dataset(self, size, train_fraction, unseen_combs):
        """Build this task's dataset and return its generated train/test split.

        Args:
            size: problem size forwarded to the dataset constructor.
            train_fraction: fraction of samples used for training.
            unseen_combs: forwarded to the dataset's generate_data
                (presumably controls held-out key/value combinations —
                TODO confirm against DictionaryLookupDataset).
        """
        if self is Task.DICTIONARY:
            dataset = DictionaryLookupDataset(size)
        else:
            # Fail loudly here instead of letting `None.generate_data(...)`
            # below raise a confusing AttributeError.
            raise ValueError(f'No dataset is defined for task {self}')

        return dataset.generate_data(train_fraction, unseen_combs)
|
||
|
||
class GNN_TYPE(Enum):
    """GNN layer types selectable via the --type flag."""
    GCN = auto()
    GIN = auto()
    GAT = auto()
    GATv2 = auto()

    @staticmethod
    def from_string(s):
        """Parse a type name (e.g. 'GAT') into a GNN_TYPE member.

        Raises:
            ValueError: if *s* is not the name of a GNN_TYPE member.
        """
        try:
            return GNN_TYPE[s]
        except KeyError:
            raise ValueError(f'Unknown GNN type: {s}') from None

    def get_layer(self, in_dim, out_dim, num_heads):
        """Instantiate a single message-passing layer of this type.

        Args:
            in_dim: input feature dimension.
            out_dim: output feature dimension of the layer.
            num_heads: number of attention heads (GAT/GATv2 only;
                out_dim is assumed to be divisible by num_heads).
        """
        if self is GNN_TYPE.GCN:
            return GCNConv(
                in_channels=in_dim,
                out_channels=out_dim)
        elif self is GNN_TYPE.GIN:
            return GINConv(nn.Sequential(nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(),
                                         nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()))
        elif self is GNN_TYPE.GAT:
            # The output will be the concatenation of the heads, yielding a vector of size out_dim
            return GATConv(in_dim, out_dim // num_heads, heads=num_heads, add_self_loops=False)
        elif self is GNN_TYPE.GATv2:
            return GAT2Conv(in_dim, out_dim // num_heads, heads=num_heads, bias=False, share_weights=True, add_self_loops=False)
        else:
            # Defensive: the original fell through and returned None for an
            # unhandled member, which would only fail later and far away.
            raise ValueError(f'No layer construction defined for {self}')
|
||
|
||
|
||
class StoppingCriterion(object):
    """Tracks the best metric values seen during training and decides when
    a new best has been reached, according to a configured STOP criterion.
    """

    def __init__(self, stop):
        """Args:
            stop: a STOP enum member selecting which metric drives stopping.
        """
        self.stop = stop
        # Loss is tracked negated so that "larger is better" holds uniformly
        # for every metric compared in is_met().
        self.best_train_loss = -float('inf')
        self.best_train_node_acc = 0
        self.best_train_graph_acc = 0
        self.best_test_node_acc = 0
        self.best_test_graph_acc = 0
        self.best_epoch = 0

        self.name = stop.name

    def new_best_str(self):
        """Suffix for log lines announcing a new best value."""
        return f' (new best {self.name})'

    def is_met(self, train_loss, train_node_acc, train_graph_acc, test_node_acc, test_graph_acc, stopping_threshold):
        """Return (improved, new_value): whether the tracked metric improved
        on its best by more than *stopping_threshold*, and its current value.
        """
        if self.stop is STOP.TRAIN_NODE:
            new_value = train_node_acc
            old_value = self.best_train_node_acc
        elif self.stop is STOP.TRAIN_GRAPH:
            new_value = train_graph_acc
            old_value = self.best_train_graph_acc
        elif self.stop is STOP.TEST_NODE:
            new_value = test_node_acc
            old_value = self.best_test_node_acc
        elif self.stop is STOP.TEST_GRAPH:
            new_value = test_graph_acc
            old_value = self.best_test_graph_acc
        elif self.stop is STOP.TRAIN_LOSS:
            # Negated: a lower loss is a higher (better) value.
            new_value = -train_loss
            old_value = self.best_train_loss
        else:
            raise ValueError(f'Unsupported stopping criterion: {self.stop}')

        return new_value > (old_value + stopping_threshold), new_value

    def __repr__(self):
        return str(self.stop)

    def update_best(self, train_node_acc, train_graph_acc, test_node_acc, test_graph_acc, epoch, train_loss=None):
        """Record the metrics of a new best epoch.

        Args:
            train_loss: optional; when provided, best_train_loss is also
                refreshed. The original code never updated best_train_loss,
                so with STOP.TRAIN_LOSS every epoch compared against -inf and
                was always reported as a new best. Passing the loss here
                fixes that; omitting it preserves the old behavior.
        """
        self.best_train_node_acc = train_node_acc
        self.best_train_graph_acc = train_graph_acc
        self.best_test_node_acc = test_node_acc
        self.best_test_graph_acc = test_graph_acc
        self.best_epoch = epoch
        if train_loss is not None:
            # Stored negated, matching the comparison in is_met().
            self.best_train_loss = -train_loss

    def print_best(self):
        """Print the best recorded epoch and its metrics."""
        print(f'Best epoch: {self.best_epoch}')
        print(f'Best train node acc: {self.best_train_node_acc}')
        print(f'Best train graph acc: {self.best_train_graph_acc}')
        print(f'Best test node acc: {self.best_test_node_acc}')
        print(f'Best test graph acc: {self.best_test_graph_acc}')
|
||
|
||
class STOP(Enum):
    """Metrics a StoppingCriterion can monitor, selectable via a flag."""
    TRAIN_NODE = auto()
    TRAIN_GRAPH = auto()
    TEST_NODE = auto()
    TEST_GRAPH = auto()
    TRAIN_LOSS = auto()

    @staticmethod
    def from_string(s):
        """Parse a criterion name (e.g. 'TRAIN_LOSS') into a STOP member.

        Raises:
            ValueError: if *s* is not the name of a STOP member.
        """
        try:
            return STOP[s]
        except KeyError:
            raise ValueError(f'Unknown stopping criterion: {s}') from None
Oops, something went wrong.