add checkpointing, to allow for greatly increased number of neighbors attended to
lucidrains committed Aug 25, 2021
1 parent 9e76767 commit 8dbf7da
Showing 3 changed files with 35 additions and 9 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -22,6 +22,7 @@ model = EnTransformer(
edge_dim = 4, # dimension of edge feature
neighbors = 64, # only do attention between each coordinate and its N nearest neighbors - set to 0 to turn off
talking_heads = True, # use Shazeer's talking heads https://arxiv.org/abs/2003.02436
checkpoint = True, # use checkpointing so one can increase depth at little memory cost (and increase neighbors attended to)
use_cross_product = True # use cross product vectors (idea by @MattMcPartlon)
)
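
As a minimal sketch (assumed tensor shapes, not part of this commit), the flag is simply passed at construction; checkpointing only takes effect in training mode, where it trades extra recomputation in the backward pass for lower activation memory:

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 8,            # a deeper stack becomes affordable
    neighbors = 64,
    checkpoint = True     # recompute block activations during the backward pass
)

feats = torch.randn(1, 1024, 512)   # per-node features
coors = torch.randn(1, 1024, 3)     # coordinates

model.train()
feats_out, coors_out = model(feats, coors)   # blocks run under checkpoint_sequential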

41 changes: 33 additions & 8 deletions en_transformer/en_transformer.py
@@ -1,6 +1,7 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint_sequential

from en_transformer.rotary import SinusoidalEmbeddings, apply_rotary_pos_emb
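
For context, a standalone sketch (toy Linear layers, not from this repository) of what torch.utils.checkpoint.checkpoint_sequential does: the module list is split into segments, intermediate activations inside each segment are discarded after the forward pass, and each segment is re-run during backward to recompute them:

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

layers = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(2, 64, requires_grad = True)

out = checkpoint_sequential(layers, 2, x)  # 8 layers split into 2 checkpointed segments
out.sum().backward()                       # segments are re-executed here to rebuild activations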

@@ -239,7 +240,7 @@ def forward(
nbhd_ranking = nbhd_ranking.masked_fill(self_mask, -1.)
nbhd_ranking = nbhd_ranking.masked_fill(adj_mat, 0.)

- if num_nn > 0:
+ if 0 < num_nn < n:
# make sure padding does not end up becoming neighbors
if exists(mask):
ranking_mask = mask[:, :, None] * mask[:, None, :]
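
The guard now skips neighbor pruning when the requested neighbor count already covers every node. A toy sketch of that idea (a hypothetical nearest_neighbor_indices helper, not the library's internal ranking code, which additionally masks self-distances and padding):

import torch

def nearest_neighbor_indices(coors, num_nn):
    # coors: (batch, n, 3)
    n = coors.shape[1]
    if not (0 < num_nn < n):
        return None                       # nothing to prune - attend to all nodes
    dist = torch.cdist(coors, coors)      # (batch, n, n) pairwise distances
    return dist.topk(num_nn, dim = -1, largest = False).indices  # (batch, n, num_nn)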
@@ -402,6 +403,18 @@ def forward(

# transformer

class Block(nn.Module):
def __init__(self, attn, ff):
super().__init__()
self.attn = attn
self.ff = ff

def forward(self, inp, coor_changes = None):
feats, coors, mask, edges, adj_mat = inp
feats, coors = self.attn(feats, coors, edges = edges, mask = mask, adj_mat = adj_mat)
feats, coors = self.ff(feats, coors)
return (feats, coors, mask, edges, adj_mat)
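
The reason attention and feedforward are folded into a single Block with one tuple in and one tuple out is that checkpoint_sequential drives each module as block(inp), so the mask, edges and adjacency matrix have to ride along inside that tuple. A toy analog with Linear stand-ins (not the real layers), just to show the calling convention:

import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.attn = nn.Linear(dim, dim)   # stand-in for the attention residual
        self.ff = nn.Linear(dim, dim)     # stand-in for the feedforward residual

    def forward(self, inp):
        feats, mask = inp                 # unpack the carried state
        feats = self.ff(self.attn(feats))
        return (feats, mask)              # tuple in, tuple out

blocks = nn.ModuleList([ToyBlock(32) for _ in range(4)])

inp = (torch.randn(2, 16, 32), torch.ones(2, 16).bool())
for block in blocks:                      # same loop shape as the non-checkpointed path
    inp = block(inp)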

class EnTransformer(nn.Module):
def __init__(
self,
@@ -424,7 +437,8 @@ def __init__(
norm_rel_coors = True,
norm_coors_scale_init = 1.,
use_cross_product = False,
- talking_heads = False
+ talking_heads = False,
+ checkpoint = False
):
super().__init__()
assert dim_head >= 32, 'your dimension per head should be at least 32 for rotary embeddings to work well'
@@ -440,13 +454,14 @@ def __init__(
self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
adj_dim = adj_dim if exists(num_adj_degrees) else 0

self.checkpoint = checkpoint
self.layers = nn.ModuleList([])

for ind in range(depth):
- self.layers.append(nn.ModuleList([
+ self.layers.append(Block(
Residual(PreNorm(dim, EquivariantAttention(dim = dim, dim_head = dim_head, heads = heads, coors_hidden_dim = coors_hidden_dim, edge_dim = (edge_dim + adj_dim), neighbors = neighbors, only_sparse_neighbors = only_sparse_neighbors, valid_neighbor_radius = valid_neighbor_radius, init_eps = init_eps, rel_pos_emb = rel_pos_emb, norm_rel_coors = norm_rel_coors, norm_coors_scale_init = norm_coors_scale_init, use_cross_product = use_cross_product, talking_heads = talking_heads))),
Residual(PreNorm(dim, FeedForward(dim = dim)))
- ]))
+ ))

def forward(
self,
@@ -489,15 +504,25 @@ def forward(
adj_emb = self.adj_emb(adj_indices)
edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

assert not (return_coor_changes and self.training), 'you must be in eval mode in order to return coordinate changes'

# go through layers

coor_changes = [coors]
inp = (feats, coors, mask, edges, adj_mat)

# if in training mode and checkpointing is enabled, use checkpointing across blocks to save memory
if self.training and self.checkpoint:
inp = checkpoint_sequential(self.layers, 1, inp)
else:
# iterate through blocks
for layer in self.layers:
inp = layer(inp)
coor_changes.append(inp[1]) # append coordinates for visualization

- for attn, ff in self.layers:
- feats, coors = attn(feats, coors, edges = edges, mask = mask, adj_mat = adj_mat)
- coor_changes.append(coors)
+ # return

- feats, coors = ff(feats, coors)
+ feats, coors, *_ = inp

if return_coor_changes:
return feats, coors, coor_changes
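
A usage sketch (assumed shapes) of the constraint the new assert introduces: per-block coordinate trajectories are only collected on the non-checkpointed path, so return_coor_changes requires eval mode, while checkpoint = True only takes effect during training:

import torch
from en_transformer import EnTransformer

model = EnTransformer(dim = 64, depth = 6, neighbors = 32, checkpoint = True)

feats = torch.randn(1, 256, 64)
coors = torch.randn(1, 256, 3)

model.train()
feats_out, coors_out = model(feats, coors)   # checkpointed path, no trajectory collected

model.eval()
feats_out, coors_out, coor_changes = model(feats, coors, return_coor_changes = True)
# coor_changes holds the coordinates after each block, for visualization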
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'En-transformer',
packages = find_packages(),
- version = '0.3.9',
+ version = '0.4.0',
license='MIT',
description = 'E(n)-Equivariant Transformer',
author = 'Phil Wang',
