dictionary lookup
urialon committed Dec 14, 2021
1 parent fc772ad commit 54ffbac
Showing 13 changed files with 1,049 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -20,6 +20,11 @@ and also in [this repository](gatv2_conv_DGL.py).

[https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/gnn/keras/layers/GATv2.md](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/gnn/keras/layers/GATv2.md)

# DictionaryLookup

The code for reproducing the DictionaryLookup experiments can be found in the [dictionary_lookup](dictionary_lookup/README.md) directory.


The rest of the code for reproducing the experiments in the paper will be made publicly available.

# Citation
130 changes: 130 additions & 0 deletions dictionary_lookup/.gitignore
@@ -0,0 +1,130 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
.idea
50 changes: 50 additions & 0 deletions dictionary_lookup/README.md
@@ -0,0 +1,50 @@
# DictionaryLookup Benchmark

The code in this directory reproduces the experiments of
Section 4.1 in the paper, for the "DictionaryLookup" problem.


# The DictionaryLookup problem
![alt text](./images/fig2.png "Figure 2 from the paper")

## Requirements

### Dependencies
This project is based on PyTorch 1.7.1 and the [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) library.
First, install PyTorch from the official website: [https://pytorch.org/](https://pytorch.org/).
PyTorch Geometric requires manual installation, so we recommend following the instructions at [https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).

The `requirements.txt` file lists the additional requirements.

Finally, run the following to install them and make sure that all dependencies are satisfied:
```setup
pip install -r requirements.txt
```
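
As an optional sanity check after installation, the sketch below reports which top-level packages cannot be found in the current environment. The helper name `missing_packages` is hypothetical and not part of this repository; the package names `torch` and `torch_geometric` are the import names of the two libraries mentioned above.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# An empty list means both core dependencies can be located.
print(missing_packages(["torch", "torch_geometric"]))
```

This only checks that the packages can be located; it does not check their versions.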

## Reproducing Experiments

To see the flags available for running a single experiment from the paper, run:

```
python main.py --help
```
For example, to train a GATv2 with size=10 and num_heads=1, run:
```
python main.py --task DICTIONARY --size 10 --num_heads 1 --type GATv2 --eval_every 10
```

Alternatively, to train a GAT with size=10 and num_heads=8, run:
```
python main.py --task DICTIONARY --size 10 --num_heads 8 --type GAT --eval_every 10
```
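
The `--type` value has to match a member name of the `GNN_TYPE` enum in `common.py`. The sketch below shows one way such flags could be parsed, assuming `main.py` uses `argparse` with `GNN_TYPE.from_string` as the converter. `main.py` is not shown here, so `build_parser`, the `dest` name and the defaults are assumptions made for illustration only.

```python
import argparse
from enum import Enum, auto


class GNN_TYPE(Enum):
    # Same member names as the enum in common.py.
    GCN = auto()
    GIN = auto()
    GAT = auto()
    GATv2 = auto()

    @staticmethod
    def from_string(s):
        try:
            return GNN_TYPE[s]
        except KeyError:
            # argparse turns a ValueError from a type converter into a usage error.
            raise ValueError(f"unknown GNN type: {s!r}")


def build_parser():
    # Hypothetical parser; the real main.py may define its flags differently.
    parser = argparse.ArgumentParser()
    parser.add_argument("--type", dest="gnn_type",
                        type=GNN_TYPE.from_string, default=GNN_TYPE.GAT)
    parser.add_argument("--size", type=int, default=10)
    parser.add_argument("--num_heads", type=int, default=1)
    parser.add_argument("--eval_every", type=int, default=10)
    return parser


args = build_parser().parse_args(["--type", "GATv2", "--size", "10", "--num_heads", "1"])
print(args.gnn_type, args.size, args.num_heads)
```

With this wiring, a name that is not an enum member makes `argparse` exit with a usage error instead of failing later during model construction.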

## Experiment with other GNN types
To experiment with other GNN types:
* Add the new GNN type to the `GNN_TYPE` enum [here](common.py#L30), for example: `MY_NEW_TYPE = auto()`
* Add another `elif self is GNN_TYPE.MY_NEW_TYPE:` branch to `get_layer` that instantiates the new GNN type object [here](common.py#L57)
* Use the new type as a flag for the `main.py` file:
```
python main.py --type MY_NEW_TYPE ...
```
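
The steps above can be sketched as follows. To keep the example runnable without PyTorch, `MyNewLayer` is a hypothetical stand-in for a real GNN layer class, and the existing branches of `get_layer` are omitted; only the enum and method names come from `common.py`.

```python
from enum import Enum, auto


class MyNewLayer:
    """Hypothetical placeholder for a real GNN layer (e.g. a torch_geometric conv)."""

    def __init__(self, in_dim, out_dim):
        self.in_dim = in_dim
        self.out_dim = out_dim


class GNN_TYPE(Enum):
    GAT = auto()
    MY_NEW_TYPE = auto()  # step 1: register the new type

    @staticmethod
    def from_string(s):
        try:
            return GNN_TYPE[s]
        except KeyError:
            raise ValueError(f"unknown GNN type: {s!r}")

    def get_layer(self, in_dim, out_dim, num_heads):
        if self is GNN_TYPE.GAT:
            raise NotImplementedError("existing branches are omitted in this sketch")
        elif self is GNN_TYPE.MY_NEW_TYPE:
            # step 2: instantiate the new layer
            return MyNewLayer(in_dim, out_dim)
        raise ValueError(f"unsupported GNN type: {self}")


# step 3: the string passed via --type resolves to the new member
layer = GNN_TYPE.from_string("MY_NEW_TYPE").get_layer(in_dim=16, out_dim=32, num_heads=1)
print(type(layer).__name__, layer.in_dim, layer.out_dim)
```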

123 changes: 123 additions & 0 deletions dictionary_lookup/common.py
@@ -0,0 +1,123 @@
from enum import Enum, auto

from tasks.dictionary_lookup import DictionaryLookupDataset
from gnns.gat2 import GAT2Conv

from torch import nn
from torch_geometric.nn import GCNConv, GINConv, GATConv

class Task(Enum):
    DICTIONARY = auto()

    @staticmethod
    def from_string(s):
        try:
            return Task[s]
        except KeyError:
            raise ValueError(f'Unknown task: {s}')

    def get_dataset(self, size, train_fraction, unseen_combs):
        if self is Task.DICTIONARY:
            dataset = DictionaryLookupDataset(size)
        else:
            # Fail with a clear error instead of calling a method on None
            raise ValueError(f'Unsupported task: {self}')

        return dataset.generate_data(train_fraction, unseen_combs)


class GNN_TYPE(Enum):
    GCN = auto()
    GIN = auto()
    GAT = auto()
    GATv2 = auto()

    @staticmethod
    def from_string(s):
        try:
            return GNN_TYPE[s]
        except KeyError:
            raise ValueError(f'Unknown GNN type: {s}')

    def get_layer(self, in_dim, out_dim, num_heads):
        if self is GNN_TYPE.GCN:
            return GCNConv(
                in_channels=in_dim,
                out_channels=out_dim)
        elif self is GNN_TYPE.GIN:
            return GINConv(nn.Sequential(
                nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(),
                nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()))
        elif self is GNN_TYPE.GAT:
            # Each head outputs out_dim // num_heads channels; the heads are concatenated,
            # yielding a vector of size out_dim (when out_dim is divisible by num_heads)
            return GATConv(in_dim, out_dim // num_heads, heads=num_heads, add_self_loops=False)
        elif self is GNN_TYPE.GATv2:
            return GAT2Conv(in_dim, out_dim // num_heads, heads=num_heads, bias=False,
                            share_weights=True, add_self_loops=False)
        else:
            # Fail loudly instead of silently returning None
            raise ValueError(f'Unsupported GNN type: {self}')



class StoppingCriterion(object):
    def __init__(self, stop):
        self.stop = stop
        # Losses are compared negated, so that "larger is better" holds for every metric
        self.best_train_loss = -float('inf')
        self.best_train_node_acc = 0
        self.best_train_graph_acc = 0
        self.best_test_node_acc = 0
        self.best_test_graph_acc = 0
        self.best_epoch = 0

        self.name = stop.name

    def new_best_str(self):
        return f' (new best {self.name})'

    def is_met(self, train_loss, train_node_acc, train_graph_acc, test_node_acc, test_graph_acc, stopping_threshold):
        # Returns (improved, new_value): whether the tracked metric beat the best
        # value so far by more than stopping_threshold, and the metric's new value
        if self.stop is STOP.TRAIN_NODE:
            new_value = train_node_acc
            old_value = self.best_train_node_acc
        elif self.stop is STOP.TRAIN_GRAPH:
            new_value = train_graph_acc
            old_value = self.best_train_graph_acc
        elif self.stop is STOP.TEST_NODE:
            new_value = test_node_acc
            old_value = self.best_test_node_acc
        elif self.stop is STOP.TEST_GRAPH:
            new_value = test_graph_acc
            old_value = self.best_test_graph_acc
        elif self.stop is STOP.TRAIN_LOSS:
            new_value = -train_loss
            old_value = self.best_train_loss
        else:
            raise ValueError

        return new_value > (old_value + stopping_threshold), new_value

    def __repr__(self):
        return str(self.stop)

    def update_best(self, train_node_acc, train_graph_acc, test_node_acc, test_graph_acc, epoch):
        # Note: best_train_loss is not updated here, so unless the caller sets it,
        # STOP.TRAIN_LOSS always compares against the initial -inf
        self.best_train_node_acc = train_node_acc
        self.best_train_graph_acc = train_graph_acc
        self.best_test_node_acc = test_node_acc
        self.best_test_graph_acc = test_graph_acc
        self.best_epoch = epoch

    def print_best(self):
        print(f'Best epoch: {self.best_epoch}')
        print(f'Best train node acc: {self.best_train_node_acc}')
        print(f'Best train graph acc: {self.best_train_graph_acc}')
        print(f'Best test node acc: {self.best_test_node_acc}')
        print(f'Best test graph acc: {self.best_test_graph_acc}')


class STOP(Enum):
    TRAIN_NODE = auto()
    TRAIN_GRAPH = auto()
    TEST_NODE = auto()
    TEST_GRAPH = auto()
    TRAIN_LOSS = auto()

    @staticmethod
    def from_string(s):
        try:
            return STOP[s]
        except KeyError:
            raise ValueError()
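
A note on the `GAT` and `GATv2` branches of `get_layer` in `common.py`: each head receives `out_dim // num_heads` output channels, and the heads are concatenated. Because this is integer division, the concatenated width equals `out_dim` only when `out_dim` is divisible by `num_heads`. The sketch below works through the arithmetic; `concat_output_dim` is a hypothetical helper for illustration and does not exist in the repository.

```python
def concat_output_dim(out_dim, num_heads):
    """Width of the concatenated multi-head output for a requested out_dim."""
    per_head = out_dim // num_heads  # integer division, as in get_layer
    return per_head * num_heads


print(concat_output_dim(128, 8))  # 128: 16 channels per head, nothing lost
print(concat_output_dim(10, 8))   # 8: 1 channel per head, 2 channels lost
```

When combining a hidden dimension with a head count, picking a dimension that is a multiple of `num_heads` avoids this silent shrinkage.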