Showing 4 changed files with 387 additions and 0 deletions.
@@ -0,0 +1,136 @@
import torch
from torch import nn
from math import sqrt


class LinearEmbedding(nn.Module):
    def __init__(self, input_size: int, embed_size: int):
        super(LinearEmbedding, self).__init__()
        self.embedding = nn.Linear(input_size, embed_size)

    def forward(self, inputs):
        return self.embedding(inputs)


# Glimpse using dot-product attention
class Glimpse(nn.Module):
    def __init__(self, input_size, hidden_size, n_head):
        super(Glimpse, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.n_head = n_head
        self.single_dim = hidden_size // n_head
        self.c_div = 1.0 / sqrt(self.single_dim)
        self.W_q = nn.Linear(self.input_size, self.hidden_size)
        self.W_k = nn.Linear(self.input_size, self.hidden_size)
        self.W_v = nn.Linear(self.input_size, self.hidden_size)
        self.W_out = nn.Linear(self.hidden_size, self.input_size)

    def forward(self, query, target, mask=None):
        """
        Parameters
        ----------
        query : FloatTensor with shape [batch_size x input_size]
        target : FloatTensor with shape [batch_size x seq_len x input_size]
        mask : BoolTensor with shape [batch_size x seq_len];
            True marks positions that must not be attended to.
        """
        batch_size, seq_len, _ = target.shape

        # Project the query and the targets per head.
        q_c = self.W_q(query).reshape(batch_size, self.n_head, self.single_dim)
        k = self.W_k(target).reshape(batch_size, seq_len, self.n_head,
                                     self.single_dim).permute(0, 2, 1, 3).contiguous()
        v = self.W_v(target).reshape(batch_size, seq_len, self.n_head,
                                     self.single_dim).permute(0, 2, 1, 3).contiguous()
        # Scaled dot-product scores: [batch_size x n_head x seq_len]
        qk = torch.einsum("ijl,ijkl->ijk", [q_c, k]) * self.c_div

        if mask is not None:
            _mask = mask.unsqueeze(1).repeat(1, self.n_head, 1)
            qk[_mask] = -100000.0

        alpha = torch.softmax(qk, -1)
        # Weighted sum of values per head: [batch_size x n_head x single_dim]
        h = torch.einsum("ijk,ijkl->ijl", alpha, v)

        if self.n_head == 1:
            ret = h.reshape(batch_size, -1)
            return alpha.squeeze(1), ret
        else:
            ret = self.W_out(h.reshape(batch_size, -1))
            return alpha, ret
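
As a quick shape check, a minimal usage sketch for Glimpse; the batch size, sequence length, and dimensions below are illustrative assumptions, not values from the commit:

glimpse = Glimpse(input_size=128, hidden_size=128, n_head=4)
query = torch.rand(32, 128)               # [batch_size x input_size]
target = torch.rand(32, 20, 128)          # [batch_size x seq_len x input_size]
mask = torch.zeros(32, 20, dtype=torch.bool)
alpha, context = glimpse(query, target, mask)
print(alpha.shape, context.shape)         # torch.Size([32, 4, 20]) torch.Size([32, 128])
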
# Pointer using dot-product attention
class Pointer(nn.Module):
    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        C: float = 10
    ):
        super(Pointer, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.C = C

        self.W_q = nn.Linear(self.input_size, self.hidden_size)
        self.W_k = nn.Linear(self.input_size, self.hidden_size)
        self.W_v = nn.Linear(self.input_size, self.hidden_size)

    def forward(self, query, target, mask=None):
        """
        Parameters
        ----------
        query : FloatTensor [batch_size x input_size]
        target : FloatTensor [batch_size x seq_len x input_size]
        mask : BoolTensor [batch_size x seq_len];
            True marks positions that must not be selected.
        """
        batch_size, seq_len, _ = target.shape

        q_c = self.W_q(query)   # batch_size x hidden_size
        k = self.W_k(target)    # batch_size x seq_len x hidden_size
        v = self.W_v(target)    # batch_size x seq_len x hidden_size
        qk = torch.einsum("ik,ijk->ij", [q_c, k])  # batch_size x seq_len
        # Clip the logits to [-C, C] before masking and normalising.
        qk = self.C * torch.tanh(qk)

        if mask is not None:
            qk[mask] = -100000.0

        alpha = torch.softmax(qk, dim=-1)
        # Context vector: attention-weighted sum of the values,
        # [batch_size x hidden_size]
        ret = torch.einsum("ij,ijk->ik", [alpha, v])

        return alpha, ret
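
A similar hedged sketch for Pointer, masking one position to show that it receives (almost) zero probability; all sizes are illustrative:

pointer = Pointer(input_size=128, hidden_size=128, C=10)
query = torch.rand(32, 128)
target = torch.rand(32, 20, 128)
mask = torch.zeros(32, 20, dtype=torch.bool)
mask[:, 0] = True                          # e.g. city 0 already visited
prob, context = pointer(query, target, mask)
print(prob.shape, context.shape)           # torch.Size([32, 20]) torch.Size([32, 128])
print(prob[:, 0].max())                    # ~0: the masked position is suppressed
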
class Attention(nn.Module):
    """Bahdanau (additive) attention."""

    def __init__(self, hidden_size: int, C: float = 10):
        super(Attention, self).__init__()
        self.C = C
        self.W_q = nn.Linear(hidden_size, hidden_size)
        self.W_k = nn.Linear(hidden_size, hidden_size)
        self.W_v = nn.Linear(hidden_size, 1)

    def forward(self, query, target):
        """
        Args:
        - query: [batch_size x hidden_size]
        - target: [batch_size x seq_len x hidden_size]
        """
        batch_size, seq_len, _ = target.shape
        query = self.W_q(query)
        # [batch_size x seq_len x hidden_size]
        query = query.unsqueeze(1).repeat(1, seq_len, 1)
        target = self.W_k(target)  # [batch_size x seq_len x hidden_size]
        logits = self.W_v(torch.tanh(query + target)).squeeze(-1)
        logits = self.C * torch.tanh(logits)

        return target, logits
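
And a minimal sketch for the additive Attention layer used by the RNN model below; sizes are again illustrative:

attention = Attention(hidden_size=128, C=10)
query = torch.rand(32, 128)
target = torch.rand(32, 20, 128)
ref, logits = attention(query, target)
print(ref.shape, logits.shape)             # torch.Size([32, 20, 128]) torch.Size([32, 20])
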
@@ -0,0 +1,38 @@
import torch
from torch import nn


class TSPSolver(nn.Module):
    def __init__(self, actor: nn.Module, device: str = 'cpu'):
        super(TSPSolver, self).__init__()
        self.actor = actor
        self.device = device

    def compute_reward(self, sample_solution):
        # sample_solution: FloatTensor (batch_size x seq_len x 2),
        # the city coordinates in visiting order.

        batch_size, seq_len, _ = sample_solution.size()
        tour_length = torch.zeros(batch_size, device=self.device)

        # tour (0 -> 1 -> ... -> n-1) length
        for i in range(seq_len - 1):
            tour_length += torch.norm(
                sample_solution[:, i, :] - sample_solution[:, i + 1, :], dim=-1
            )
        # closing edge (n-1 -> 0)
        tour_length += torch.norm(
            sample_solution[:, seq_len - 1, :] - sample_solution[:, 0, :], dim=-1
        )

        return tour_length

    def forward(self, inputs: torch.Tensor):
        # inputs shape: (batch_size, seq_len, 2)
        probs, actions = self.actor(inputs)
        R = self.compute_reward(
            inputs.gather(1, actions.unsqueeze(2).repeat(1, 1, 2))
        )

        return R, probs, actions
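
A minimal sketch of how TSPSolver might be driven. GreedyIdentityActor is a hypothetical stand-in (not part of the commit) that simply visits cities in input order, so the example stays self-contained:

# Hypothetical actor for illustration: visits cities in input order.
class GreedyIdentityActor(nn.Module):
    def forward(self, inputs):
        batch_size, seq_len, _ = inputs.shape
        actions = torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1)
        log_probs = torch.zeros(batch_size, seq_len)   # placeholder log-probs
        return log_probs, actions


solver = TSPSolver(actor=GreedyIdentityActor())
inputs = torch.rand(8, 10, 2)              # 8 instances of 10 random cities
R, probs, actions = solver(inputs)
print(R.shape)                              # torch.Size([8]); tour lengths to minimise
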
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
from torch.distributions import Categorical

from models.layers import Glimpse, LinearEmbedding, Pointer


class AttentionBlock(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        feed_forward_hidden: int = 512,
        bn: bool = False
    ):
        super(AttentionBlock, self).__init__()
        self.mha = torch.nn.MultiheadAttention(embed_dim, n_heads)
        self.embed = nn.Sequential(
            nn.Linear(embed_dim, feed_forward_hidden),
            nn.ReLU(),
            nn.Linear(feed_forward_hidden, embed_dim)
        )

    def forward(self, x):
        # nn.MultiheadAttention expects (seq_len, batch_size, embed_dim),
        # so permute to that layout first and back afterwards.
        # https://pytorch.org/docs/stable/nn.html#multiheadattention
        x = x.permute(1, 0, 2)
        x = x + self.mha(x, x, x)[0]
        x = x.permute(1, 0, 2)
        x = x + self.embed(x)

        return x


class AttentionModule(nn.Sequential):
    def __init__(
        self,
        embed_dim: int,
        n_heads: int,
        feed_forward_hidden: int = 512,
        n_self_attentions: int = 2,
        bn: bool = False
    ):
        super(AttentionModule, self).__init__(
            *(AttentionBlock(embed_dim, n_heads, feed_forward_hidden, bn)
              for _ in range(n_self_attentions))
        )
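
A quick shape check for the self-attention encoder stack; all sizes are illustrative:

encoder = AttentionModule(embed_dim=128, n_heads=4, n_self_attentions=2)
x = torch.rand(32, 20, 128)                # [batch_size x seq_len x embed_dim]
print(encoder(x).shape)                     # torch.Size([32, 20, 128])
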
class AttentionTSP(nn.Module):
    def __init__(
        self,
        pos_size: int,
        embed_size: int,
        hidden_size: int,
        seq_len: int,
        n_head: int = 4,
        C: float = 10
    ):
        super(AttentionTSP, self).__init__()

        self.embedding_size = embed_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.n_head = n_head
        self.C = C

        self.embedding = LinearEmbedding(pos_size, embed_size)
        self.mha = AttentionModule(embed_size, n_head)

        self.init_w = nn.Parameter(torch.Tensor(pos_size * self.embedding_size))
        self.init_w.data.uniform_(-1, 1)
        self.glimpse = Glimpse(self.embedding_size,
                               self.hidden_size, self.n_head)
        self.pointer = Pointer(self.embedding_size,
                               self.hidden_size, self.C)

        self.h_context_embed = nn.Linear(
            self.embedding_size, self.embedding_size)
        self.v_weight_embed = nn.Linear(
            self.embedding_size * pos_size, self.embedding_size)

    def forward(self, inputs):
        # inputs: [batch_size x seq_len x 2]

        batch_size = inputs.shape[0]

        # Encode all cities with self-attention and build the initial query
        # from the mean of the encodings plus a learned placeholder.
        embedded = self.embedding(inputs)
        h = self.mha(embedded)
        h_mean = h.mean(dim=1)
        h_bar = self.h_context_embed(h_mean)
        h_rest = self.v_weight_embed(self.init_w)
        query = h_bar + h_rest

        # Decode one city per step, masking cities that were already visited.
        prev_chosen_indices = []
        prev_chosen_logprobs = []
        first_chosen_hs = None
        mask = torch.zeros(batch_size, self.seq_len, dtype=torch.bool,
                           device=inputs.device)

        for index in range(self.seq_len):
            _, n_query = self.glimpse(query, h, mask)
            prob, _ = self.pointer(n_query, h, mask)
            cat = Categorical(prob)
            chosen = cat.sample()
            logprobs = cat.log_prob(chosen)
            prev_chosen_indices.append(chosen)
            prev_chosen_logprobs.append(logprobs)

            mask[torch.arange(batch_size), chosen] = True
            cc = chosen.unsqueeze(1).unsqueeze(2).repeat(
                1, 1, self.embedding_size
            )
            if first_chosen_hs is None:
                first_chosen_hs = h.gather(1, cc).squeeze(1)
            chosen_hs = h.gather(1, cc).squeeze(1)
            # The next query conditions on the first and the last chosen city.
            h_rest = self.v_weight_embed(
                torch.cat([first_chosen_hs, chosen_hs], dim=-1))
            query = h_bar + h_rest

        return torch.stack(prev_chosen_logprobs, 1), torch.stack(prev_chosen_indices, 1)
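
A hedged forward-pass sketch for AttentionTSP; the hyper-parameters are illustrative, since the commit does not include a training script:

model = AttentionTSP(pos_size=2, embed_size=128, hidden_size=128, seq_len=20)
inputs = torch.rand(32, 20, 2)              # 20-city instances with 2-D coordinates
log_probs, tours = model(inputs)
print(log_probs.shape, tours.shape)         # torch.Size([32, 20]) torch.Size([32, 20])
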
@@ -0,0 +1,92 @@
import math
import torch
import numpy as np
from torch import nn
from torch.distributions import Categorical

from models.layers import Attention, LinearEmbedding


class RNNTSP(nn.Module):
    def __init__(
        self,
        pos_size: int,
        embed_size: int,
        hidden_size: int,
        seq_len: int,
        n_glimpses: int,
        tanh_exploration: float
    ):
        super(RNNTSP, self).__init__()

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.n_glimpses = n_glimpses
        self.seq_len = seq_len

        self.embedding = LinearEmbedding(pos_size, embed_size)
        self.encoder = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.decoder = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.pointer = Attention(hidden_size, C=tanh_exploration)
        self.glimpse = Attention(hidden_size)

        # Learned start-of-sequence token fed to the decoder at step 0.
        self.dec_sos = nn.Parameter(torch.FloatTensor(embed_size))
        self.dec_sos.data.uniform_(-(1. / math.sqrt(embed_size)),
                                   1. / math.sqrt(embed_size))

    def apply_mask_to_logits(self, logits, mask, idxs):
        batch_size = logits.size(0)
        clone_mask = mask.clone()
        if idxs is not None:
            clone_mask[torch.arange(batch_size), idxs.data] = True
        logits[clone_mask] = -np.inf

        return logits, clone_mask

    def forward(self, inputs):
        # inputs: (batch_size x seq_len x 2)

        batch_size = inputs.shape[0]
        seq_len = inputs.shape[1]

        embedded = self.embedding(inputs)
        encoder_outputs, (hidden, context) = self.encoder(embedded)

        prev_chosen_logprobs = []
        prev_chosen_indices = []
        mask = torch.zeros(batch_size, self.seq_len, dtype=torch.bool,
                           device=inputs.device)

        decoder_input = self.dec_sos
        decoder_input = decoder_input.unsqueeze(0).repeat(batch_size, 1)
        for index in range(seq_len):
            _, (hidden, context) = self.decoder(
                decoder_input.unsqueeze(1), (hidden, context)
            )

            # Refine the query with n_glimpses rounds of attention over the
            # encoder outputs before pointing.
            query = hidden.squeeze(0)
            for _ in range(self.n_glimpses):
                ref, logits = self.glimpse(query, encoder_outputs)
                _mask = mask.clone()
                logits[_mask] = -np.inf
                query = torch.matmul(
                    ref.transpose(-1, -2),
                    torch.softmax(logits, dim=-1).unsqueeze(-1)
                ).squeeze(-1)

            _, logits = self.pointer(query, encoder_outputs)

            _mask = mask.clone()
            logits[_mask] = -np.inf
            probs = torch.softmax(logits, dim=-1)
            cat = Categorical(probs)
            chosen = cat.sample()
            mask[torch.arange(batch_size), chosen] = True
            log_probs = cat.log_prob(chosen)
            # Feed the embedding of the chosen city to the decoder next step.
            decoder_input = embedded.gather(
                1, chosen[:, None, None].repeat(1, 1, self.embed_size)
            ).squeeze(1)
            prev_chosen_logprobs.append(log_probs)
            prev_chosen_indices.append(chosen)

        return torch.stack(prev_chosen_logprobs, 1), torch.stack(prev_chosen_indices, 1)
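
A matching hedged sketch for RNNTSP; the hyper-parameters are illustrative:

model = RNNTSP(pos_size=2, embed_size=128, hidden_size=128,
               seq_len=20, n_glimpses=1, tanh_exploration=10)
inputs = torch.rand(32, 20, 2)
log_probs, tours = model(inputs)
print(log_probs.shape, tours.shape)         # torch.Size([32, 20]) torch.Size([32, 20])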