diff --git a/README.md b/README.md
index 255e58e..70017f5 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,11 @@ and also in [this repository](gatv2_conv_DGL.py).
 
 [https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/gnn/keras/layers/GATv2.md](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/gnn/keras/layers/GATv2.md)
 
+# DictionaryLookup
+
+The code for reproducing the DictionaryLookup experiments can be found in the [dictionary_lookup](dictionary_lookup/README.md) directory.
+
+
 The rest of the code for reproducing the experiments in the paper will be made publicly available.
 
 # Citation
diff --git a/dictionary_lookup/.gitignore b/dictionary_lookup/.gitignore
new file mode 100644
index 0000000..e17e6b7
--- /dev/null
+++ b/dictionary_lookup/.gitignore
@@ -0,0 +1,130 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+.idea
diff --git a/dictionary_lookup/README.md b/dictionary_lookup/README.md
new file mode 100644
index 0000000..96d0dff
--- /dev/null
+++ b/dictionary_lookup/README.md
@@ -0,0 +1,50 @@
+# DictionaryLookup Benchmark
+
+This directory can be used to reproduce the experiments of
+Section 4.1 in the paper, for the "DictionaryLookup" problem.
+
+
+# The DictionaryLookup problem
+![alt text](./images/fig2.png "Figure 2 from the paper")
+
+In this problem, each graph is a bipartite graph with `size` key nodes and `size` query nodes: every key node holds a (key, value) pair, every query node holds only a key, and each query node has to predict the value stored for its own key by attending to the key nodes it is connected to.
+
+## Requirements
+
+### Dependencies
+This project is based on PyTorch 1.7.1 and the [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) library.
+
+First, install PyTorch from the official website: [https://pytorch.org/](https://pytorch.org/).
+PyTorch Geometric requires manual installation; we therefore recommend following the instructions at [https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html).
+
+The `requirements.txt` file lists the additional requirements.
+Finally, run the following to install them and verify that all dependencies are satisfied:
+```setup
+pip install -r requirements.txt
+```
+
+## Reproducing Experiments
+
+To see all available flags, run:
+
+```
+python main.py --help
+```
+
+For example, to train a GATv2 with size=10 and num_heads=1, run:
+```
+python main.py --task DICTIONARY --size 10 --num_heads 1 --type GATv2 --eval_every 10
+```
+
+Alternatively, to train a GAT with size=10 and num_heads=8, run:
+```
+python main.py --task DICTIONARY --size 10 --num_heads 8 --type GAT --eval_every 10
+```
+
+## Experiment with other GNN types
+To experiment with other GNN types (see the example below):
+* Add the new GNN type to the `GNN_TYPE` enum [here](common.py#L30), for example: `MY_NEW_TYPE = auto()`
+* Add another `elif self is GNN_TYPE.MY_NEW_TYPE:` to instantiate the new GNN type object [here](common.py#L57)
+* Use the new type as a flag for the `main.py` file:
+```
+python main.py --type MY_NEW_TYPE ...
+```
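+
+For example, a hypothetical `MY_NEW_TYPE` backed by PyG's `SAGEConv` (any other `MessagePassing` layer can be wired in the same way) could look roughly like this sketch of `common.py`:
+```python
+from torch_geometric.nn import SAGEConv
+
+class GNN_TYPE(Enum):
+    GCN = auto()
+    GIN = auto()
+    GAT = auto()
+    GATv2 = auto()
+    MY_NEW_TYPE = auto()  # the new type
+
+    def get_layer(self, in_dim, out_dim, num_heads):
+        ...
+        elif self is GNN_TYPE.MY_NEW_TYPE:
+            return SAGEConv(in_channels=in_dim, out_channels=out_dim)
+```
+The new type can then be trained with `python main.py --task DICTIONARY --size 10 --type MY_NEW_TYPE --eval_every 10`.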
diff --git a/dictionary_lookup/common.py b/dictionary_lookup/common.py
new file mode 100644
index 0000000..8816614
--- /dev/null
+++ b/dictionary_lookup/common.py
@@ -0,0 +1,123 @@
+from enum import Enum, auto
+
+from tasks.dictionary_lookup import DictionaryLookupDataset
+from gnns.gat2 import GAT2Conv
+
+from torch import nn
+from torch_geometric.nn import GCNConv, GINConv, GATConv
+
+
+class Task(Enum):
+    DICTIONARY = auto()
+
+    @staticmethod
+    def from_string(s):
+        try:
+            return Task[s]
+        except KeyError:
+            raise ValueError()
+
+    def get_dataset(self, size, train_fraction, unseen_combs):
+        if self is Task.DICTIONARY:
+            dataset = DictionaryLookupDataset(size)
+        else:
+            dataset = None
+
+        return dataset.generate_data(train_fraction, unseen_combs)
+
+
+class GNN_TYPE(Enum):
+    GCN = auto()
+    GIN = auto()
+    GAT = auto()
+    GATv2 = auto()
+
+    @staticmethod
+    def from_string(s):
+        try:
+            return GNN_TYPE[s]
+        except KeyError:
+            raise ValueError()
+
+    def get_layer(self, in_dim, out_dim, num_heads):
+        if self is GNN_TYPE.GCN:
+            return GCNConv(
+                in_channels=in_dim,
+                out_channels=out_dim)
+        elif self is GNN_TYPE.GIN:
+            return GINConv(nn.Sequential(nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(),
+                                         nn.Linear(out_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU()))
+        elif self is GNN_TYPE.GAT:
+            # The output will be the concatenation of the heads, yielding a vector of size out_dim
+            return GATConv(in_dim, out_dim // num_heads, heads=num_heads, add_self_loops=False)
+        elif self is GNN_TYPE.GATv2:
+            return GAT2Conv(in_dim, out_dim // num_heads, heads=num_heads, bias=False, share_weights=True,
+                            add_self_loops=False)
+
+
+class StoppingCriterion(object):
+    def __init__(self, stop):
+        self.stop = stop
+        self.best_train_loss = -float('inf')
+        self.best_train_node_acc = 0
+        self.best_train_graph_acc = 0
+        self.best_test_node_acc = 0
+        self.best_test_graph_acc = 0
+        self.best_epoch = 0
+
+        self.name = stop.name
+
+    def new_best_str(self):
+        return f' (new best {self.name})'
+
+    def is_met(self, train_loss, train_node_acc, train_graph_acc, test_node_acc, test_graph_acc,
+               stopping_threshold):
+        if self.stop is STOP.TRAIN_NODE:
+            new_value = train_node_acc
+            old_value = self.best_train_node_acc
+        elif self.stop is STOP.TRAIN_GRAPH:
+            new_value = train_graph_acc
+            old_value = self.best_train_graph_acc
+        elif self.stop is STOP.TEST_NODE:
+            new_value = test_node_acc
+            old_value = self.best_test_node_acc
+        elif self.stop is STOP.TEST_GRAPH:
+            new_value = test_graph_acc
+            old_value = self.best_test_graph_acc
+        elif self.stop is STOP.TRAIN_LOSS:
+            new_value = -train_loss
+            old_value = self.best_train_loss
+        else:
+            raise ValueError
+
+        return new_value > (old_value + stopping_threshold), new_value
+
+    def __repr__(self):
+        return str(self.stop)
+
+    def update_best(self, train_node_acc, train_graph_acc, test_node_acc, test_graph_acc, epoch):
+        self.best_train_node_acc = train_node_acc
+        self.best_train_graph_acc = train_graph_acc
+        self.best_test_node_acc = test_node_acc
+        self.best_test_graph_acc = test_graph_acc
+        self.best_epoch = epoch
+
+    def print_best(self):
+        print(f'Best epoch: {self.best_epoch}')
+        print(f'Best train node acc: {self.best_train_node_acc}')
+        print(f'Best train graph acc: {self.best_train_graph_acc}')
+        print(f'Best test node acc: {self.best_test_node_acc}')
+        print(f'Best test graph acc: {self.best_test_graph_acc}')
+
+
+class STOP(Enum):
+    TRAIN_NODE = auto()
+    TRAIN_GRAPH = auto()
+    TEST_NODE = auto()
+    TEST_GRAPH = auto()
+    TRAIN_LOSS = auto()
+
+    @staticmethod
+    def from_string(s):
+        try:
+            return STOP[s]
+        except KeyError:
+            raise ValueError()
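+
+# A worked example of the head arithmetic in get_layer above (the sizes are hypothetical):
+# GNN_TYPE.GAT.get_layer(in_dim=128, out_dim=128, num_heads=8) builds GATConv(128, 16, heads=8);
+# each of the 8 heads produces 16 features, and their concatenation restores the 128-dimensional
+# output, so out_dim must be divisible by num_heads. The GATv2 branch follows the same scheme.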
diff --git a/dictionary_lookup/experiment.py b/dictionary_lookup/experiment.py
new file mode 100644
index 0000000..e09924a
--- /dev/null
+++ b/dictionary_lookup/experiment.py
@@ -0,0 +1,280 @@
+import os
+import torch
+import torch_scatter
+
+from torch_geometric.data import DataLoader
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from pathlib import Path
+
+import numpy as np
+import random
+from attrdict import AttrDict
+import seaborn as sb
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from common import StoppingCriterion
+from models.graph_model import GraphModel
+
+
+class Experiment():
+    def __init__(self, args):
+        self.task = args.task
+        gnn_type = args.type
+        self.size = args.size
+        num_layers = args.num_layers
+        self.dim = args.dim
+        self.unroll = args.unroll
+        self.train_fraction = args.train_fraction
+        self.max_epochs = args.max_epochs
+        self.batch_size = args.batch_size
+        self.accum_grad = args.accum_grad
+        self.eval_every = args.eval_every
+        self.loader_workers = args.loader_workers
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.stopping_criterion = StoppingCriterion(args.stop)
+        self.patience = args.patience
+        self.save_path = args.save
+        self.args = args
+
+        seed = 11
+        torch.manual_seed(seed)
+        np.random.seed(seed)
+        random.seed(seed)
+
+        self.X_train, self.X_test, dim0, out_dim, self.criterion = \
+            self.task.get_dataset(self.size, self.train_fraction, unseen_combs=True)
+
+        self.model = GraphModel(gnn_type=gnn_type, num_layers=num_layers, dim0=dim0, h_dim=self.dim, out_dim=out_dim,
+                                unroll=args.unroll,
+                                layer_norm=args.use_layer_norm,
+                                use_activation=args.use_activation,
+                                use_residual=args.use_residual,
+                                num_heads=args.num_heads,
+                                dropout=args.dropout,
+                                ).to(self.device)
+
+        print('Starting experiment')
+        self.print_args(args)
+        print(f'Training examples: {len(self.X_train)}, test examples: {len(self.X_test)}')
+
+    def print_args(self, args):
+        if type(args) is AttrDict:
+            for key, value in args.items():
+                print(f"{key}: {value}")
+        else:
+            for arg in vars(args):
+                print(f"{arg}: {getattr(args, arg)}")
+        print()
+
+    def run(self):
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
+        scheduler = ReduceLROnPlateau(optimizer, mode='max', threshold_mode='abs', factor=0.5, patience=10)
+        print('Starting training')
+
+        epochs_no_improve = 0
+        for epoch in range(1, (self.max_epochs // self.eval_every) + 1):
+            self.model.train()
+            loader = DataLoader(self.X_train * self.eval_every, batch_size=self.batch_size, shuffle=True,
+                                pin_memory=True, num_workers=self.loader_workers)
+
+            total_loss = 0
+            total_num_nodes = 0
+            train_per_node_correct = 0
+            total_num_graphs = 0
+            train_per_graph_correct = 0
+            optimizer.zero_grad()
+            for i, batch in enumerate(loader):
+                batch = batch.to(self.device)
+                out, targets_batch = self.model(batch)
+                loss = self.criterion(input=out, target=batch.y)
+                total_num_graphs += batch.num_graphs
+                total_num_nodes += targets_batch.size(0)
+                total_loss += (loss.item() * targets_batch.size(0))
+                _, train_per_node_pred = out.max(dim=1)
+                per_node_correct = train_per_node_pred.eq(batch.y)
+                train_per_node_correct += per_node_correct.sum().item()
+                train_per_graph_correct += torch_scatter.scatter_min(
+                    index=targets_batch, src=per_node_correct.double())[0].sum().item()
+
+                loss = loss / self.accum_grad
+                loss.backward()
+                if (i + 1) % self.accum_grad == 0:
+                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+                    optimizer.step()
+                    optimizer.zero_grad()
+
+            avg_training_loss = total_loss / total_num_nodes
+            train_per_node_acc = train_per_node_correct / total_num_nodes
+            train_per_graph_acc = train_per_graph_correct / total_num_graphs
+
+            test_node_acc, test_graph_acc = self.eval()
+            cur_lr = [g["lr"] for g in optimizer.param_groups]
+
+            stopping_threshold = 0.0001
+            should_stop, relevant_value = self.stopping_criterion.is_met(train_loss=avg_training_loss,
+                                                                         train_node_acc=train_per_node_acc,
+                                                                         train_graph_acc=train_per_graph_acc,
+                                                                         test_node_acc=test_node_acc,
+                                                                         test_graph_acc=test_graph_acc,
+                                                                         stopping_threshold=stopping_threshold)
+            if should_stop:
+                self.stopping_criterion.update_best(train_node_acc=train_per_node_acc,
+                                                    train_graph_acc=train_per_graph_acc,
+                                                    test_node_acc=test_node_acc,
+                                                    test_graph_acc=test_graph_acc,
+                                                    epoch=epoch * self.eval_every)
+                epochs_no_improve = 0
+                new_best_str = self.stopping_criterion.new_best_str()
+            else:
+                epochs_no_improve += 1
+                new_best_str = ''
+
+            scheduler.step(relevant_value)
+            print(
+                f'Epoch {epoch * self.eval_every}, LR: {cur_lr}: Train loss: {avg_training_loss:.7f}, '
+                f'Train-node acc: {train_per_node_acc:.4f}, '
+                f'Train-graph acc: {train_per_graph_acc:.4f}, '
+                f'Test-node acc: {test_node_acc:.4f}, '
+                f'Test-graph acc: {test_graph_acc:.4f} {new_best_str}')
+            if relevant_value == 1.0:
+                break
+            if epochs_no_improve >= self.patience:
+                print(
+                    f'{self.patience} * {self.eval_every} epochs without {self.stopping_criterion} improvement, '
+                    f'stopping.')
+                break
+
+        self.stopping_criterion.print_best()
+        if self.save_path is not None:
+            self.save_model()
+
+        return self.stopping_criterion
+
+    def eval(self):
+        self.model.eval()
+        with torch.no_grad():
+            loader = DataLoader(self.X_test, batch_size=self.batch_size, shuffle=False,
+                                pin_memory=True, num_workers=self.loader_workers)
+
+            total_num_nodes = 0
+            total_per_node_correct = 0
+            total_num_graphs = 0
+            total_per_graph_correct = 0
+
+            for batch in loader:
+                batch = batch.to(self.device)
+                out, targets_batch = self.model(batch)
+                _, pred = out.max(dim=1)
+
+                total_num_nodes += targets_batch.size(0)
+                total_num_graphs += batch.num_graphs
+
+                total_per_node_correct += pred.eq(batch.y).sum().item()
+                total_per_graph_correct += torch_scatter.scatter_min(
+                    index=targets_batch, src=pred.eq(batch.y).double())[0].sum().item()
+
+            per_node_acc = total_per_node_correct / total_num_nodes
+            per_graph_acc = total_per_graph_correct / total_num_graphs
+            return per_node_acc, per_graph_acc
+
+    def save_model(self):
+        print(f'Saving model to: {self.save_path}')
+        p = Path(self.save_path)
+        os.makedirs(p.parent, exist_ok=True)
+        torch.save({'model': self.model.state_dict(),
+                    'args': self.args}, self.save_path)
+        print('Saved model')
+
+    @staticmethod
+    def complete_missing_args(args, default_args):
+        for key, value in default_args.items():
+            if key not in args:
+                print(f"Missing key '{key}' in saved model, setting value: {value}")
+                setattr(args, key, value)
+
+        return args
+
+    @staticmethod
+    def load_model(path, default_args=None):
+        print(f'Loading model from: {path}')
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        checkpoint = torch.load(path, map_location=device)
+        saved_args = checkpoint['args']
+        if default_args is not None:
+            saved_args = Experiment.complete_missing_args(saved_args, default_args)
+        exp = Experiment(saved_args)
+        exp.device = device
+        exp.model.load_state_dict(checkpoint['model'])
+        exp.model.to(device)
+        print('Loaded model')
+        return exp
+
+    def plot_figures(self, example_num):
+        example = self.X_test[example_num].to(self.device)
+        example.batch = [0 for _ in range(example.num_nodes)]
+
+        pred, attention_per_edge, edge_index = self.model.attention_per_edge(example)
+        per_node_acc = pred.eq(example.y).detach().cpu().numpy()
+        print(f'Node accuracy: {per_node_acc} ({np.mean(per_node_acc):.2f})')
+
+        attention_per_key_query = self.get_attention_per_key_and_query(attention_per_edge, edge_index)
+        key_labels = [f'$k{i}$' for i in range(self.size)]
+
+        if self.model.layers[0].add_self_loops and self.args.include_self:
+            key_labels += ['self']
+        query_labels = [f'$q{i}$' for i in range(self.size)]
+        self.plot_heatmap_and_line(attention_per_key_query,
+                                   xticklabels=key_labels,
+                                   yticklabels=query_labels)
+
+    def plot_heatmap_and_line(self, data, xticklabels='auto', yticklabels='auto'):
+        if data.shape[0] > 1:
+            data = np.expand_dims(data[1], axis=0)
+        print('gat = np.array([')
+        for row in data[0]:
+            rowstr = ', '.join([f'{x:.2f}' for x in row])
+            print(f'    [{rowstr}],')
+        print('])')
+
+        size = 3
+        tik_label_size = 15
+        fig, axes = plt.subplots(2, 1, figsize=(size + 2, (size + 1) * 2))
+        plt.yticks(rotation=0)
+        for i in range(data.shape[0]):
+            cur_ax = axes[0]
+            if data.shape[0] > 1:
+                axes[i].set_title(f'Head #{i}')
+                cur_ax = axes[i]
+            ax = sb.heatmap(data[i], annot=True, fmt='.2f', cbar=False,
+                            xticklabels=xticklabels, yticklabels=yticklabels, ax=cur_ax)
+            ax.xaxis.tick_top()
+            ax.tick_params(labelsize=tik_label_size)
+            ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
+
+            cur_ax = axes[1]
+            d = {q: row for q, row in zip(yticklabels, data[i])}
+            df = pd.DataFrame(d, index=xticklabels)
+            ax = sb.lineplot(data=df, ax=cur_ax)
+            ax.tick_params(labelsize=tik_label_size)
+
+        plt.setp(ax.get_legend().get_texts(), fontsize=tik_label_size)
+        # plt.subplots_adjust(hspace=0.08, top=0.96, bottom=0.04)
+
+        plt.legend(bbox_to_anchor=(0.99, 0.45), loc='right', prop={'size': tik_label_size, }, labelspacing=0.1,
+                   borderaxespad=0.)
+
+        # plt.savefig('gatv2.pdf')
+        plt.show()
+
+    def get_attention_per_key_and_query(self, attention_per_edge, edge_index):
+        result_y_size = self.size + 1 if (self.model.layers[0].add_self_loops and self.args.include_self) else self.size
+        if len(attention_per_edge.shape) == 1:
+            attention_per_edge = np.expand_dims(attention_per_edge, axis=-1)
+        result = np.zeros(shape=(attention_per_edge.shape[-1], self.size, result_y_size))
+        for head_idx in range(attention_per_edge.shape[-1]):
+            for src, tgt, score in zip(edge_index[0], edge_index[1], attention_per_edge[:, head_idx]):
+                if tgt >= self.size:
+                    continue
+                if src < self.size:
+                    src = self.size * 2
+                result[head_idx, tgt, src - self.size] = score
+        return result
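+
+# Note on the per-graph accuracy computed in run() and eval() above: a graph counts as correct only
+# if *all* of its target nodes are predicted correctly, which is what the scatter_min reduction
+# implements. A small illustration with made-up values:
+#   correct  = torch.tensor([1., 1., 0., 1.])   # per-node correctness
+#   graph_id = torch.tensor([0, 0, 1, 1])       # which graph each node belongs to
+#   torch_scatter.scatter_min(src=correct, index=graph_id)[0]  # -> tensor([1., 0.])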
+ """ + _alpha: OptTensor + + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, heads: int = 1, concat: bool = True, + negative_slope: float = 0.2, dropout: float = 0., + add_self_loops: bool = True, bias: bool = True, + share_weights: bool = False, + **kwargs): + kwargs.setdefault('aggr', 'add') + super(GAT2Conv, self).__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + self.dropout = dropout + self.add_self_loops = add_self_loops + + if isinstance(in_channels, int): + self.lin_l = Linear(in_channels, heads * out_channels, bias=bias) + if share_weights: + self.lin_r = self.lin_l + else: + self.lin_r = Linear(in_channels, heads * out_channels, bias=bias) + else: + self.lin_l = Linear(in_channels[0], heads * out_channels, bias=bias) + self.lin_r = Linear(in_channels[1], heads * out_channels, bias=bias) + + self.att = Parameter(torch.Tensor(1, heads, out_channels)) + + if bias and concat: + self.bias = Parameter(torch.Tensor(heads * out_channels)) + elif bias and not concat: + self.bias = Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + self._alpha = None + + self.reset_parameters() + + def reset_parameters(self): + glorot(self.lin_l.weight) + glorot(self.lin_r.weight) + glorot(self.att) + zeros(self.bias) + + def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, + size: Size = None, return_attention_weights=None): + r""" + + Args: + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + H, C = self.heads, self.out_channels + + x_l: OptTensor = None + x_r: OptTensor = None + if isinstance(x, Tensor): + assert x.dim() == 2, 'Static graphs not supported in `GATConv`.' + x_l = x_r = self.lin_l(x).view(-1, H, C) + else: + x_l, x_r = x[0], x[1] + assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.' 
+ x_l = self.lin_l(x_l).view(-1, H, C) + if x_r is not None: + x_r = self.lin_r(x_r).view(-1, H, C) + + assert x_l is not None + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + num_nodes = x_l.size(0) + if x_r is not None: + num_nodes = min(num_nodes, x_r.size(0)) + if size is not None: + num_nodes = min(size[0], size[1]) + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) + elif isinstance(edge_index, SparseTensor): + edge_index = set_diag(edge_index) + + # propagate_type: (x: OptPairTensor, alpha: OptPairTensor) + out = self.propagate(edge_index, x=(x_l, x_r), size=size) + + alpha = self._alpha + self._alpha = None + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(dim=1) + + if self.bias is not None: + out += self.bias + + if isinstance(return_attention_weights, bool): + assert alpha is not None + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + elif isinstance(edge_index, SparseTensor): + return out, edge_index.set_value(alpha, layout='coo') + else: + return out + + def message(self, x_j: Tensor, x_i: Tensor, + index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: + x = x_i + x_j + x = F.leaky_relu(x, self.negative_slope) + alpha = (x * self.att).sum(dim=-1) + alpha = softmax(alpha, index, ptr, size_i) + self._alpha = alpha + alpha = F.dropout(alpha, p=self.dropout, training=self.training) + return x_j * alpha.unsqueeze(-1) + + def __repr__(self): + return '{}({}, {}, heads={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.heads) diff --git a/dictionary_lookup/images/fig2.png b/dictionary_lookup/images/fig2.png new file mode 100644 index 0000000..3935fd9 Binary files /dev/null and b/dictionary_lookup/images/fig2.png differ diff --git a/dictionary_lookup/main.py b/dictionary_lookup/main.py new file mode 100644 index 0000000..8543d28 --- /dev/null +++ b/dictionary_lookup/main.py @@ -0,0 +1,101 @@ +from argparse import ArgumentParser +from attrdict import AttrDict + +from experiment import Experiment +from common import Task, GNN_TYPE, STOP + +def get_fake_args( + task=Task.DICTIONARY, + type=GNN_TYPE.GAT, + dim=128, + num_heads=1, + size=10, + num_layers=1, + dropout=0.0, + train_fraction=0.8, + unseen_combs=True, + max_epochs=50000, + eval_every=100, + batch_size=1024, + accum_grad=1, + patience=20, + stop=STOP.TEST_NODE, + loader_workers=0, + use_layer_norm=False, + use_activation=False, + use_residual=False, + include_self=False, + unroll=False, + save=None, + load=None, +): + return AttrDict({ + 'task': task, + 'type': type, + 'dim': dim, + 'num_heads': num_heads, + 'size': size, + 'num_layers': num_layers, + 'dropout': dropout, + 'train_fraction': train_fraction, + 'unseen_combs': unseen_combs, + 'max_epochs': max_epochs, + 'eval_every': eval_every, + 'batch_size': batch_size, + 'accum_grad': accum_grad, + 'stop': stop, + 'patience': patience, + 'loader_workers': loader_workers, + 'use_layer_norm': use_layer_norm, + 'use_activation': use_activation, + 'use_residual': use_residual, + 'include_self': include_self, + 'unroll': unroll, + 'save': save, + 'load': load + }) + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument("--task", dest="task", default=Task.DICTIONARY, type=Task.from_string, choices=list(Task), + required=False) + parser.add_argument("--type", dest="type", default=GNN_TYPE.GAT, type=GNN_TYPE.from_string, choices=list(GNN_TYPE), + required=False) + 
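+
+# A minimal usage sketch (shapes follow the GATConv convention; all dimensions below are made up):
+#
+#   conv = GAT2Conv(in_channels=16, out_channels=32, heads=4, share_weights=True)
+#   x = torch.randn(5, 16)                                   # 5 nodes, 16 features each
+#   edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])  # 4 directed edges
+#   out = conv(x, edge_index)                                 # [5, 4 * 32], heads concatenated
+#   out, (edge_index, alpha) = conv(x, edge_index, return_attention_weights=True)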
parser.add_argument("--dim", dest="dim", default=128, type=int, required=False) + parser.add_argument("--num_heads", dest="num_heads", default=1, type=int, required=False) + parser.add_argument("--size", dest="size", default=10, type=int, required=False) + parser.add_argument("--num_layers", dest="num_layers", default=1, type=int, required=False) + parser.add_argument("--dropout", dest="dropout", default=0.0, type=float, required=False) + parser.add_argument("--train_fraction", dest="train_fraction", default=0.8, type=float, required=False) + # parser.add_argument('--unseen_combs', action='store_true') + parser.add_argument("--max_epochs", dest="max_epochs", default=50000, type=int, required=False) + parser.add_argument("--eval_every", dest="eval_every", default=100, type=int, required=False) + parser.add_argument("--batch_size", dest="batch_size", default=1024, type=int, required=False) + parser.add_argument("--accum_grad", dest="accum_grad", default=1, type=int, required=False) + parser.add_argument("--stop", dest="stop", default=STOP.TEST_NODE, type=STOP.from_string, choices=list(STOP), + required=False) + parser.add_argument("--save", dest="save", type=str, required=False) + parser.add_argument("--load", dest="load", type=str, required=False) + parser.add_argument("--plot", dest="plot", default=None, type=int, required=False, help='plots the attention for a specific example') + parser.add_argument("--patience", dest="patience", default=20, type=int, required=False) + parser.add_argument("--loader_workers", dest="loader_workers", default=0, type=int, required=False) + parser.add_argument('--use_layer_norm', action='store_true') + parser.add_argument('--use_activation', action='store_true') + parser.add_argument('--use_residual', action='store_true') + parser.add_argument('--include_self', action='store_true') + parser.add_argument('--unroll', action='store_true', help='use the same weights across GNN layers') + + args = parser.parse_args() + if args.load is None: + Experiment(args).run() + else: + exp = Experiment.load_model(args.load, get_fake_args()) + test_node_acc, test_graph_acc = exp.eval() + print( + f'Test-node acc: {test_node_acc:.4f}, ' + f'Test-graph acc: {test_graph_acc:.4f}') + if args.plot is not None: + exp.plot_figures(args.plot) + + diff --git a/dictionary_lookup/models/graph_model.py b/dictionary_lookup/models/graph_model.py new file mode 100644 index 0000000..d616b8e --- /dev/null +++ b/dictionary_lookup/models/graph_model.py @@ -0,0 +1,83 @@ +import torch +from torch import nn +from torch.nn import functional as F +import torch_geometric + + +class GraphModel(torch.nn.Module): + def __init__(self, gnn_type, num_layers, dim0, h_dim, out_dim, + unroll, layer_norm, use_activation, use_residual, num_heads, dropout): + super(GraphModel, self).__init__() + self.gnn_type = gnn_type + self.unroll = unroll + self.use_layer_norm = layer_norm + self.use_activation = use_activation + self.use_residual = use_residual + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + self.num_layers = num_layers + self.layer0_keys = nn.Embedding(num_embeddings=dim0, embedding_dim=h_dim) + self.layer0_values = nn.Embedding(num_embeddings=dim0, embedding_dim=h_dim) + + self.layer0_ff = nn.Sequential(nn.ReLU()) + + self.layers = nn.ModuleList() + self.layer_norms = nn.ModuleList() + self.dropout = nn.Dropout(p=dropout) + if unroll: + self.layers.append(gnn_type.get_layer( + in_dim=h_dim, + out_dim=h_dim, num_heads=num_heads)) + else: + for i in range(num_layers): + 
diff --git a/dictionary_lookup/models/graph_model.py b/dictionary_lookup/models/graph_model.py
new file mode 100644
index 0000000..d616b8e
--- /dev/null
+++ b/dictionary_lookup/models/graph_model.py
@@ -0,0 +1,83 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+import torch_geometric
+
+
+class GraphModel(torch.nn.Module):
+    def __init__(self, gnn_type, num_layers, dim0, h_dim, out_dim,
+                 unroll, layer_norm, use_activation, use_residual, num_heads, dropout):
+        super(GraphModel, self).__init__()
+        self.gnn_type = gnn_type
+        self.unroll = unroll
+        self.use_layer_norm = layer_norm
+        self.use_activation = use_activation
+        self.use_residual = use_residual
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        self.num_layers = num_layers
+        self.layer0_keys = nn.Embedding(num_embeddings=dim0, embedding_dim=h_dim)
+        self.layer0_values = nn.Embedding(num_embeddings=dim0, embedding_dim=h_dim)
+
+        self.layer0_ff = nn.Sequential(nn.ReLU())
+
+        self.layers = nn.ModuleList()
+        self.layer_norms = nn.ModuleList()
+        self.dropout = nn.Dropout(p=dropout)
+        if unroll:
+            self.layers.append(gnn_type.get_layer(
+                in_dim=h_dim,
+                out_dim=h_dim, num_heads=num_heads))
+        else:
+            for i in range(num_layers):
+                self.layers.append(gnn_type.get_layer(
+                    in_dim=h_dim,
+                    out_dim=h_dim, num_heads=num_heads))
+        if self.use_layer_norm:
+            for i in range(num_layers):
+                self.layer_norms.append(nn.LayerNorm(h_dim))
+
+        self.out_dim = out_dim
+        self.out_layer = nn.Linear(in_features=h_dim, out_features=out_dim, bias=False)
+
+    def forward(self, data, return_attention_weights=None):
+        x, edge_index, batch, target_mask = data.x, data.edge_index, data.batch, data.target_mask
+
+        x_key, x_val = x[:, 0], x[:, 1]
+        x_key_embed = self.layer0_keys(x_key)
+        x_val_embed = self.layer0_values(x_val)
+        x = x_key_embed + x_val_embed
+        x = self.layer0_ff(x)
+
+        for i in range(self.num_layers):
+            if self.unroll:
+                layer = self.layers[0]
+            else:
+                layer = self.layers[i]
+
+            new_x = layer(x, edge_index, return_attention_weights=return_attention_weights)
+            if return_attention_weights is True:
+                new_x, attention_weights = new_x
+            if self.use_activation:
+                new_x = F.relu(new_x)
+            if self.use_residual:
+                x = x + new_x
+            else:
+                x = new_x
+            if self.use_layer_norm:
+                x = self.layer_norms[i](x)
+            # apply the configured dropout module (respects p=dropout and eval mode)
+            x = self.dropout(x)
+
+        target_nodes = x[target_mask]
+        logits = self.out_layer(target_nodes)
+        if return_attention_weights is True:
+            return logits, attention_weights
+        else:
+            targets_batch = batch[target_mask]
+            return logits, targets_batch
+
+    def attention_per_edge(self, example):
+        logits, (edge_index, alpha) = self.forward(example, return_attention_weights=True)
+        _, pred = logits.max(dim=1)
+
+        return pred, alpha.cpu().detach().numpy(), edge_index.cpu().detach().numpy()
diff --git a/dictionary_lookup/requirements.txt b/dictionary_lookup/requirements.txt
new file mode 100644
index 0000000..904247d
--- /dev/null
+++ b/dictionary_lookup/requirements.txt
@@ -0,0 +1,8 @@
+attrdict==2.0.1
+torch>=1.7.1
+torch-geometric>=1.7.0
+torch-scatter>=2.0.4
+torch-sparse>=0.6.0
+torchvision
+scikit-learn
+seaborn
diff --git a/dictionary_lookup/saved_models/gat_128_s8.pt b/dictionary_lookup/saved_models/gat_128_s8.pt
new file mode 100644
index 0000000..daf4ffa
Binary files /dev/null and b/dictionary_lookup/saved_models/gat_128_s8.pt differ
diff --git a/dictionary_lookup/saved_models/gatv2_128_s8.pt b/dictionary_lookup/saved_models/gatv2_128_s8.pt
new file mode 100644
index 0000000..872b910
Binary files /dev/null and b/dictionary_lookup/saved_models/gatv2_128_s8.pt differ
diff --git a/dictionary_lookup/tasks/dictionary_lookup.py b/dictionary_lookup/tasks/dictionary_lookup.py
new file mode 100644
index 0000000..677f416
--- /dev/null
+++ b/dictionary_lookup/tasks/dictionary_lookup.py
@@ -0,0 +1,104 @@
+import random
+
+import numpy as np
+import itertools
+import math
+import torch
+import torch_geometric
+
+from torch.nn import functional as F
+from torch_geometric.data import Data
+from sklearn.model_selection import train_test_split
+
+import common
+
+
+class DictionaryLookupDataset(object):
+    def __init__(self, size):
+        super().__init__()
+        self.size = size
+        self.edges, self.empty_id = self.init_edges()
+        self.criterion = F.cross_entropy
+
+    def init_edges(self):
+        targets = range(0, self.size)
+        sources = range(self.size, self.size * 2)
+        next_unused_id = self.size
+        all_pairs = itertools.product(sources, targets)
+        edges = [list(i) for i in zip(*all_pairs)]
+
+        return edges, next_unused_id
+
+    def create_empty_graph(self, add_self_loops=False):
+        edge_index = torch.tensor(self.edges, requires_grad=False, dtype=torch.long)
+        if add_self_loops:
+            edge_index, _ = torch_geometric.utils.add_remaining_self_loops(edge_index=edge_index)
+        return edge_index
+
+    def get_combinations(self):
+        # returns: an iterable of [permutation(size)]
+        # number of combinations: size!
+
+        max_examples = 32000  # takes effect from size=8 upwards, because 8! = 40320
+
+        if math.factorial(self.size) > max_examples:
+            permutations = [np.random.permutation(range(self.size)) for _ in range(max_examples)]
+        else:
+            permutations = itertools.permutations(range(self.size))
+
+        return permutations
+
+    def generate_data(self, train_fraction, unseen_combs):
+        data_list = []
+
+        for perm in self.get_combinations():
+            edge_index = self.create_empty_graph(add_self_loops=False)
+            nodes = torch.tensor(self.get_nodes_features(perm), dtype=torch.long, requires_grad=False)
+            target_mask = torch.tensor([True] * self.size + [False] * self.size,
+                                       dtype=torch.bool, requires_grad=False)
+            labels = torch.tensor(perm, dtype=torch.long, requires_grad=False)
+            data_list.append(Data(x=nodes, edge_index=edge_index, target_mask=target_mask, y=labels))
+
+        dim0, out_dim = self.get_dims()
+        if unseen_combs:
+            X_train, X_test = self.unseen_combs_train_test_split(data_list, train_fraction=train_fraction, shuffle=True)
+        else:
+            X_train, X_test = train_test_split(data_list, train_size=train_fraction, shuffle=True)
+
+        return X_train, X_test, dim0, out_dim, self.criterion
+
+    def get_nodes_features(self, perm):
+        # perm: a list of indices
+
+        # The first self.size nodes are the query nodes, with features (key, empty_id)
+        # The next self.size nodes are the key nodes, with features (key, value),
+        # where the values are ordered according to perm
+        nodes = [(key, self.empty_id) for key in range(self.size)]
+
+        for key, val in zip(range(self.size), perm):
+            nodes.append((key, val))
+
+        return nodes
+
+    def get_dims(self):
+        # get input and output dims
+        in_dim = self.size + 1
+        out_dim = self.size
+        return in_dim, out_dim
+
+    def unseen_combs_train_test_split(self, data_list, train_fraction, shuffle=True):
+        per_position_fraction = train_fraction ** (1 / self.size)
+        num_training_pairs = int(per_position_fraction * (self.size ** 2))
+        allowed_positions = set(random.sample(
+            list(itertools.product(range(self.size), range(self.size))), num_training_pairs))
+        train = []
+        test = []
+        for example in data_list:
+            if all([(i, label.item()) in allowed_positions for i, label in enumerate(example.y)]):
+                train.append(example)
+            else:
+                test.append(example)
+        if shuffle:
+            random.shuffle(train)
+        return train, test
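+
+# A minimal end-to-end sketch of how this dataset is consumed (mirrors main.py / experiment.py;
+# the sizes below are made up):
+#
+#   from common import Task, GNN_TYPE
+#   from models.graph_model import GraphModel
+#   from torch_geometric.data import DataLoader
+#
+#   X_train, X_test, dim0, out_dim, criterion = Task.DICTIONARY.get_dataset(
+#       size=4, train_fraction=0.8, unseen_combs=True)
+#   model = GraphModel(gnn_type=GNN_TYPE.GATv2, num_layers=1, dim0=dim0, h_dim=64, out_dim=out_dim,
+#                      unroll=False, layer_norm=False, use_activation=False, use_residual=False,
+#                      num_heads=1, dropout=0.0)
+#   batch = next(iter(DataLoader(X_train, batch_size=8)))
+#   logits, targets_batch = model(batch)   # one row of logits per target (query) node
+#   loss = criterion(input=logits, target=batch.y)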