Implementation of E(n)-Transformer, which incorporates attention mechanisms into Welling's E(n)-Equivariant Graph Neural Network

lucidrains/En-transformer

E(n)-Equivariant Transformer

Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant Graph Neural Network with attention.

Install

$ pip install En-transformer

Usage

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,                         # number of layers
    dim_head = 64,                     # dimension per attention head
    heads = 8,                         # number of attention heads
    edge_dim = 4,                      # dimension of edge features
    fourier_features = 2,              # number of Fourier features appended to the relative distance that is fed to the edge MLP
    num_nearest_neighbors = 64         # attend only to each coordinate's N nearest neighbors; set to 0 to turn off
)

feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)

mask = torch.ones(1, 1024).bool()

feats, coors = model(feats, coors, edges, mask = mask)  # (1, 1024, 512), (1, 1024, 3)
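
Since the point of the architecture is E(n)-equivariance, a quick sanity check is to rotate and translate the input coordinates and confirm that the output coordinates transform the same way while the output features stay unchanged. The following is a minimal sketch of such a check, not part of this repository: `random_rotation`, `check_equivariance` and `toy_layer` are hypothetical helper names, and `toy_layer` is only a stand-in so the snippet runs without the model.

```python
import torch

def random_rotation(dim=3, dtype=torch.float64):
    # QR decomposition of a random Gaussian matrix yields an orthogonal Q
    q, r = torch.linalg.qr(torch.randn(dim, dim, dtype=dtype))
    # normalize column signs, then flip one column if needed so det(Q) = +1
    q = q * torch.sign(torch.diagonal(r))
    if torch.det(q) < 0:
        q[:, 0] = -q[:, 0]
    return q

def check_equivariance(fn, feats, coors, atol=1e-5):
    """fn maps (feats, coors) -> (feats_out, coors_out).

    Returns (features_invariant, coordinates_equivariant) as two bools.
    """
    rot = random_rotation(coors.shape[-1], coors.dtype)
    shift = torch.randn(coors.shape[-1], dtype=coors.dtype)

    # transform the input first, then apply fn
    feats_a, coors_a = fn(feats, coors @ rot + shift)

    # apply fn first, then transform the output coordinates
    feats_b, coors_b = fn(feats, coors)
    coors_b = coors_b @ rot + shift

    return (
        torch.allclose(feats_a, feats_b, atol=atol),
        torch.allclose(coors_a, coors_b, atol=atol),
    )

def toy_layer(feats, coors):
    # hypothetical stand-in for a model: features depend only on pairwise
    # distances (invariant), coordinates move halfway toward the centroid (equivariant)
    mean_dist = torch.cdist(coors, coors).mean(dim=-1, keepdim=True)
    centroid = coors.mean(dim=1, keepdim=True)
    return feats + mean_dist, coors + 0.5 * (centroid - coors)

toy_feats = torch.randn(1, 8, 4, dtype=torch.float64)
toy_coors = torch.randn(1, 8, 3, dtype=torch.float64)
print(check_equivariance(toy_layer, toy_feats, toy_coors))
```

To point the same check at the model above, one could wrap the call as `lambda f, c: model(f, c, edges, mask = mask)`. In float32 the default `atol` here is probably too tight, so expect to loosen it.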

If you would like to attend only to sparse neighbors, as defined by an adjacency matrix (say, bonded atoms), set only_sparse_neighbors = True and pass in the N x N adjacency matrix as adj_mat.

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 1,
    heads = 4,
    dim_head = 32,
    fourier_features = 2,
    num_nearest_neighbors = 0,
    only_sparse_neighbors = True
)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

# naively assume a single chain of atoms
i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))

# adjacency matrix must be passed in
feats_out, coors_out = model(feats, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
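
Real molecules are rarely a single chain, so the adjacency matrix will often come from an explicit bond list instead. Below is a minimal sketch of that construction in plain Python. The helper `adjacency_from_bonds` and its `bonds` argument are hypothetical and not part of this library, and keeping the diagonal set (`include_self = True`) simply mirrors the chain example above, which marks each atom as adjacent to itself.

```python
def adjacency_from_bonds(num_atoms, bonds, include_self=True):
    """Build a symmetric N x N boolean adjacency matrix as nested lists.

    `bonds` is a hypothetical list of (i, j) atom index pairs.
    """
    adj = [[include_self and i == j for j in range(num_atoms)]
           for i in range(num_atoms)]
    for i, j in bonds:
        if not (0 <= i < num_atoms and 0 <= j < num_atoms):
            raise IndexError(f"bond ({i}, {j}) is out of range for {num_atoms} atoms")
        # bonds are undirected, so set both entries
        adj[i][j] = True
        adj[j][i] = True
    return adj

# a 5-atom chain 0-1-2-3-4 gives the same band pattern as the example above
chain = adjacency_from_bonds(5, [(k, k + 1) for k in range(4)])

# a branched layout: atom 2 is additionally bonded to atom 4
branched = adjacency_from_bonds(5, [(0, 1), (1, 2), (2, 3), (2, 4)])

for row in branched:
    print("".join("x" if connected else "." for connected in row))
```

Wrapping the result in `torch.tensor(...)` yields a boolean tensor, which can then be passed as `adj_mat` in the same way as the chain example.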

Example

To run a protein backbone coordinate denoising toy task, first install sidechainnet

$ pip install sidechainnet

Then

$ python denoise.py

Citations

@misc{satorras2021en,
    title         = {E(n) Equivariant Graph Neural Networks},
    author        = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year          = {2021},
    eprint        = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass  = {cs.LG}
}
