Skip to content

Commit

Permalink
allow for automatic derivation of N-th degree adjacent neighbors and …
Browse files Browse the repository at this point in the history
…optionally to embed the adjacency as part of the edge
  • Loading branch information
lucidrains committed Mar 28, 2021
1 parent 1228d12 commit 3fd51cc
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
11 changes: 5 additions & 6 deletions denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ def cycle(loader, len_thres = 500):
heads = 4,
depth = 4,
norm_rel_coors = True,
num_nearest_neighbors = 0,
only_sparse_neighbors = True
only_sparse_neighbors = True,
num_adj_degrees = 3,
adj_dim = 4
)

data = scn.load(
Expand All @@ -52,10 +53,9 @@ def cycle(loader, len_thres = 500):
masks = masks.cuda().bool()

l = seqs.shape[1]
coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

# keeping only the backbone coordinates
coords = rearrange(coords, 'b (l s) c -> b l s c', s=14)

# Keeping only the backbone coordinates
coords = coords[:, :, 0:3, :]
coords = rearrange(coords, 'b l s c -> b (l s) c')

Expand All @@ -66,7 +66,6 @@ def cycle(loader, len_thres = 500):

i = torch.arange(seq.shape[-1]).to(seq)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
adj_mat = (adj_mat.float() @ adj_mat.float()) > 0 # get second degree neighbors as well

feats, denoised_coords = transformer(seq, noised_coords, mask = masks, adj_mat = adj_mat)

Expand Down
30 changes: 29 additions & 1 deletion en_transformer/en_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,17 +327,25 @@ def __init__(
fourier_features = 4,
num_nearest_neighbors = 0,
only_sparse_neighbors = False,
num_adj_degrees = None,
adj_dim = 0,
valid_neighbor_radius = float('inf'),
norm_rel_coors = False,
init_eps = 1e-3
):
super().__init__()
assert not (exists(num_adj_degrees) and num_adj_degrees < 1), 'make sure adjacent degrees is greater than 1'

self.token_emb = nn.Embedding(num_tokens, dim) if exists(num_tokens) else None

self.num_adj_degrees = num_adj_degrees
self.adj_emb = nn.Embedding(num_adj_degrees + 1, adj_dim) if exists(num_adj_degrees) and adj_dim > 0 else None
adj_dim = adj_dim if exists(num_adj_degrees) else 0

self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Residual(PreNorm(dim, EquivariantAttention(dim = dim, dim_head = dim_head, heads = heads, m_dim = m_dim, edge_dim = edge_dim, fourier_features = fourier_features, norm_rel_coors = norm_rel_coors, num_nearest_neighbors = num_nearest_neighbors, only_sparse_neighbors = only_sparse_neighbors, valid_neighbor_radius = valid_neighbor_radius, init_eps = init_eps))),
Residual(PreNorm(dim, EquivariantAttention(dim = dim, dim_head = dim_head, heads = heads, m_dim = m_dim, edge_dim = (edge_dim + adj_dim), fourier_features = fourier_features, norm_rel_coors = norm_rel_coors, num_nearest_neighbors = num_nearest_neighbors, only_sparse_neighbors = only_sparse_neighbors, valid_neighbor_radius = valid_neighbor_radius, init_eps = init_eps))),
Residual(PreNorm(dim, FeedForward(dim = dim)))
]))

Expand All @@ -351,9 +359,29 @@ def forward(
mask = None,
adj_mat = None
):
b = feats.shape[0]

if exists(self.token_emb):
feats = self.token_emb(feats)

if exists(self.num_adj_degrees):
if len(adj_mat.shape) == 2:
adj_mat = repeat(adj_mat.clone(), 'i j -> b i j', b = b)

adj_indices = adj_mat.clone().long()

for ind in range(self.num_adj_degrees - 1):
degree = ind + 2

next_degree_adj_mat = (adj_mat.float() @ adj_mat.float()) > 0
next_degree_mask = (next_degree_adj_mat.float() - adj_mat.float()).bool()
adj_indices.masked_fill_(next_degree_mask, degree)
adj_mat = next_degree_adj_mat.clone()

if exists(self.adj_emb):
adj_emb = self.adj_emb(adj_indices)
edges = torch.cat((edges, adj_emb), dim = -1) if exists(edges) else adj_emb

# main network

for attn, ff in self.layers:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'En-transformer',
packages = find_packages(),
version = '0.1.2',
version = '0.1.4',
license='MIT',
description = 'E(n)-Equivariant Transformer',
author = 'Phil Wang',
Expand Down

0 comments on commit 3fd51cc

Please sign in to comment.