
Commit

simplified E(n)-transformers, where relative distance and positions are injected as rotation information
lucidrains committed May 15, 2021
1 parent 3c68fe6 commit a64a3ca
Showing 5 changed files with 111 additions and 140 deletions.
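
Note on the commit message: "injected as rotation information" refers to a rotary-style encoding. Rather than appending Fourier features of the relative distance to the edge MLP input (the fourier_features argument removed below), scalar positions or distances are turned into per-dimension rotation angles applied to the queries and keys, so the attention dot product depends on the relative offset. The sketch below only illustrates that general mechanism with the standard rotary formulation; the function names, shapes, and the exact way en_transformer feeds positions and distances into the rotations are assumptions, not the library's code.

import torch

def rotate_half(x):
    # pair up feature dimensions and rotate each pair by 90 degrees
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(x, scalars, base = 10000):
    # x:       (batch, seq, dim)  queries or keys
    # scalars: (batch, seq)       positions (or distances) to inject as rotations
    dim = x.shape[-1]
    inv_freq = 1. / (base ** (torch.arange(0, dim, 2, dtype = torch.float) / dim))
    angles = scalars[..., None] * inv_freq           # (batch, seq, dim / 2)
    angles = torch.cat((angles, angles), dim = -1)   # (batch, seq, dim)
    return x * angles.cos() + rotate_half(x) * angles.sin()

q = torch.randn(1, 16, 64)
positions = torch.arange(16).float()[None, :]        # sequence positions; distances would be injected analogously
q_rotated = apply_rotary(q, positions)
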
23 changes: 10 additions & 13 deletions README.md
@@ -16,12 +16,11 @@ from en_transformer import EnTransformer

model = EnTransformer(
dim = 512,
-depth = 4, # depth
-dim_head = 64, # dimension per head
-heads = 8, # number of heads
-edge_dim = 4, # dimension of edge features
-fourier_features = 2, # num fourier features to append to relative distance which goes into the edge MLP
-num_nearest_neighbors = 64 # only do attention between coordinates N nearest neighbors - set to 0 to turn off
+depth = 4, # depth
+dim_head = 64, # dimension per head
+heads = 8, # number of heads
+edge_dim = 4, # dimension of edge feature
+neighbors = 64 # only do attention between coordinates N nearest neighbors - set to 0 to turn off
)

feats = torch.randn(1, 1024, 512)
@@ -40,15 +39,15 @@ import torch
from en_transformer import EnTransformer

model = EnTransformer(
-num_tokens = 10,
-num_edge_tokens = 5,
+num_tokens = 10, # number of unique nodes, say atoms
+rel_pos_emb = True, # set this to true if your sequence is not an unordered set. it will accelerate convergence
+num_edge_tokens = 5, # number of unique edges, say bond types
dim = 128,
edge_dim = 16,
depth = 3,
heads = 4,
dim_head = 32,
-fourier_features = 2,
-num_nearest_neighbors = 8,
+neighbors = 8
)

atoms = torch.randint(0, 10, (1, 16)) # 10 different types of atoms
@@ -70,8 +69,7 @@ model = EnTransformer(
depth = 1,
heads = 4,
dim_head = 32,
-fourier_features = 2,
-num_nearest_neighbors = 0,
+neighbors = 0,
only_sparse_neighbors = True, # must be set to true
num_adj_degrees = 3, # the number of degrees to derive from 1st degree neighbors passed in
adj_dim = 8 # whether to pass the adjacency degree information as an edge embedding
@@ -103,7 +101,6 @@ model = EnTransformer(
heads = 4,
dim_head = 32,
edge_dim = 4,
-fourier_features = 2,
num_nearest_neighbors = 0,
only_sparse_neighbors = True
)
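
For reference, the first README example assembled with the renamed keyword (neighbors in place of num_nearest_neighbors, fourier_features dropped). The forward call and its return values sit in a collapsed part of the diff, so the (feats, coors, mask) signature and the two return tensors below are assumptions based on the surrounding examples, not something confirmed by this commit.

import torch
from en_transformer import EnTransformer

model = EnTransformer(
    dim = 512,
    depth = 4,
    dim_head = 64,
    heads = 8,
    neighbors = 64          # renamed from num_nearest_neighbors; 0 turns neighbor-restricted attention off
)

feats = torch.randn(1, 1024, 512)    # per-node features
coors = torch.randn(1, 1024, 3)      # per-node coordinates
mask = torch.ones(1, 1024).bool()    # assumed optional padding mask

feats_out, coors_out = model(feats, coors, mask = mask)  # assumed to return updated features and coordinates
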
22 changes: 9 additions & 13 deletions denoise.py
@@ -12,7 +12,7 @@
BATCH_SIZE = 1
GRADIENT_ACCUMULATE_EVERY = 16

-def cycle(loader, len_thres = 500):
+def cycle(loader, len_thres = 200):
while True:
for data in loader:
if data.seqs.shape[1] > len_thres:
@@ -21,14 +21,12 @@ def cycle(loader, len_thres = 500):

transformer = EnTransformer(
num_tokens = 21,
-dim = 16,
-dim_head = 32,
+dim = 32,
+dim_head = 64,
heads = 4,
depth = 4,
-norm_rel_coors = True,
-only_sparse_neighbors = True,
-num_adj_degrees = 3,
-adj_dim = 4
+rel_pos_emb = True, # there is inherent order in the sequence (backbone atoms of amino acid chain)
+neighbors = 16
)

data = scn.load(
@@ -53,9 +51,10 @@ def cycle(loader, len_thres = 500):
masks = masks.cuda().bool()

l = seqs.shape[1]
-coords = rearrange(coords, 'b (l s) c -> b l s c', s=14)
+coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)

-# keeping only the backbone coordinates
+
+# Keeping only the backbone coordinates
coords = coords[:, :, 0:3, :]
coords = rearrange(coords, 'b l s c -> b (l s) c')

@@ -64,10 +63,7 @@

noised_coords = coords + torch.randn_like(coords)

-i = torch.arange(seq.shape[-1]).to(seq)
-adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))
-
-feats, denoised_coords = transformer(seq, noised_coords, mask = masks, adj_mat = adj_mat)
+feats, denoised_coords = transformer(seq, noised_coords, mask = masks)

loss = F.mse_loss(denoised_coords[masks], coords[masks])
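
As an aside, the data shaping in the training loop above can be illustrated in isolation: sidechainnet packs 14 atom coordinates per residue, the script keeps only the first three atom slots (the backbone), corrupts them with unit Gaussian noise, and trains on a masked MSE against the clean coordinates. The snippet below is a self-contained sketch of just that shaping with random stand-in tensors (no sidechainnet download, no model); the tensor sizes are illustrative assumptions.

import torch
import torch.nn.functional as F
from einops import rearrange

batch, residues = 1, 200
coords = torch.randn(batch, residues * 14, 3)        # stand-in for scn all-atom coordinates
masks = torch.ones(batch, residues * 3).bool()       # stand-in mask over backbone atoms

coords = rearrange(coords, 'b (l s) c -> b l s c', s = 14)
coords = coords[:, :, 0:3, :]                        # keep the 3 backbone atoms per residue
coords = rearrange(coords, 'b l s c -> b (l s) c')   # (batch, residues * 3, 3)

noised_coords = coords + torch.randn_like(coords)    # corrupt with unit Gaussian noise

# the transformer would map (seq, noised_coords) to denoised coordinates here;
# a pass-through stands in for the model output in this sketch
denoised_coords = noised_coords

loss = F.mse_loss(denoised_coords[masks], coords[masks])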
