rnn and attention tsp model
mskimS2 committed Jan 13, 2023
1 parent cb06d17 commit df830e7
Showing 4 changed files with 387 additions and 0 deletions.
136 changes: 136 additions & 0 deletions models/layers.py
@@ -0,0 +1,136 @@

import torch
from torch import nn
from math import sqrt


class LinearEmbedding(nn.Module):
def __init__(self, input_size: int, embed_size: int):
super(LinearEmbedding, self).__init__()
self.embedding = nn.Linear(input_size, embed_size)

def forward(self, inputs):
return self.embedding(inputs)


# Glimpse using Dot-product attention
class Glimpse(nn.Module):
def __init__(self,
input_size,
hidden_size,
n_head):
super(Glimpse, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.n_head = n_head
self.single_dim = hidden_size // n_head
self.c_div = 1.0 / sqrt(self.single_dim)
self.W_q = nn.Linear(self.input_size, self.hidden_size)
self.W_k = nn.Linear(self.input_size, self.hidden_size)
self.W_v = nn.Linear(self.input_size, self.hidden_size)
self.W_out = nn.Linear(self.hidden_size, self.input_size)

def forward(self, query, target, mask=None):
"""
Parameters
----------
- query : FloatTensor with shape [batch_size x input_size]
- target : FloatTensor with shape [batch_size x seq_len x input_size]
        - mask : BoolTensor with shape [batch_size x seq_len]
"""
batch_size, seq_len, _ = target.shape

q_c = self.W_q(query).reshape(batch_size, self.n_head, self.single_dim)
k = self.W_k(target).reshape(batch_size, seq_len, self.n_head,
self.single_dim).permute(0, 2, 1, 3).contiguous()
v = self.W_v(target).reshape(batch_size, seq_len, self.n_head,
self.single_dim).permute(0, 2, 1, 3).contiguous()
qk = torch.einsum("ijl,ijkl->ijk", [q_c, k]) * self.c_div

if mask is not None:
_mask = mask.unsqueeze(1).repeat(1, self.n_head, 1)
qk[_mask] = -100000.0

alpha = torch.softmax(qk, -1)
h = torch.einsum("ijk,ijkl->ijl", alpha, v)

if self.n_head == 1:
ret = h.reshape(batch_size, -1)
return alpha.squeeze(1), ret
else:
ret = self.W_out(h.reshape(batch_size, -1))
return alpha, ret


# Pointer using Dot-product attention
class Pointer(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
C: float = 10
):

super(Pointer, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.C = C

self.W_q = nn.Linear(self.input_size, self.hidden_size)
self.W_k = nn.Linear(self.input_size, self.hidden_size)
self.W_v = nn.Linear(self.input_size, self.hidden_size)

def forward(self, query, target, mask=None):
"""
Parameters
----------
query : FloatTensor [batch_size x input_size]
target : FloatTensor [batch_size x seq_len x input_size]
mask : BoolTensor [batch_size x seq_len]
"""
batch_size, seq_len, _ = target.shape

q_c = self.W_q(query) # batch_size x hidden_size
k = self.W_k(target) # batch_size x seq_len x hidden_size
v = self.W_v(target) # batch_size x seq_len x hidden_size
qk = torch.einsum("ik,ijk->ij", [q_c, k]) # batch_size x seq_len
qk = self.C * torch.tanh(qk)

if mask is not None:
_mask = mask.clone()
qk[_mask] = -100000.0

alpha = torch.softmax(qk, dim=-1)
        ret = torch.einsum("ij,ijk->ik", [alpha, v])  # attention-weighted value: batch_size x hidden_size

return alpha, ret


class Attention(nn.Module):
""" Bahanadu Attention
"""

def __init__(self, hidden_size: int, C: float = 10):
super(Attention, self).__init__()
self.C = C
self.W_q = nn.Linear(hidden_size, hidden_size)
self.W_k = nn.Linear(hidden_size, hidden_size)
self.W_v = nn.Linear(hidden_size, 1)

def forward(self, query, target):
"""
Args:
- query: [batch_size x hidden_size]
- target: [batch_size x seq_len x hidden_size]
"""

batch_size, seq_len, _ = target.shape
query = self.W_q(query)
# [batch_size x seq_len x hidden_size]
query = query.unsqueeze(1).repeat(1, seq_len, 1)
target = self.W_k(target) # [batch_size x seq_len x hidden_size]
logits = self.W_v(torch.tanh(query + target)).squeeze(-1)
logits = self.C * torch.tanh(logits)

return target, logits
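For reference, a minimal shape check of the three layers above (a sketch, not part of the commit; the sizes are illustrative):

import torch
from models.layers import Glimpse, Pointer, Attention

batch_size, seq_len, hidden = 8, 20, 128
query = torch.rand(batch_size, hidden)
target = torch.rand(batch_size, seq_len, hidden)
mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)  # no city visited yet

alpha, glimpsed = Glimpse(hidden, hidden, n_head=4)(query, target, mask)
# alpha: [8, 4, 20] attention weights, glimpsed: [8, 128] context vector

probs, ctx = Pointer(hidden, hidden, C=10)(query, target, mask)
# probs: [8, 20] distribution over cities, ctx: [8, 128] weighted value

keys, logits = Attention(hidden, C=10)(query, target)
# keys: [8, 20, 128] projected targets, logits: [8, 20] tanh-clipped scores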
38 changes: 38 additions & 0 deletions models/solver.py
@@ -0,0 +1,38 @@
import torch
from torch import nn


class TSPSolver(nn.Module):
def __init__(self, actor: nn.Module, device: str = 'cpu'):
super(TSPSolver, self).__init__()
self.actor = actor
self.device = device

def compute_reward(self, sample_solution):
        # sample_solution: FloatTensor (batch_size x seq_len x 2)
        # holding the city coordinates in visiting order

batch_size, seq_len, _ = sample_solution.size()
        tour_length = torch.zeros(batch_size, device=self.device)

# tour (0 -> 1 -> ... -> n-1 -> n) reward
for i in range(seq_len - 1):
tour_length += torch.norm(
sample_solution[:, i, :] - sample_solution[:, i + 1, :], dim=-1
)
# last tour (n -> 0) reward
tour_length += torch.norm(
sample_solution[:, seq_len - 1, :] - sample_solution[:, 0, :], dim=-1
)

return tour_length

    def forward(self, inputs: torch.Tensor):
        # inputs shape: (batch_size, seq_len, 2)
probs, actions = self.actor(inputs)
R = self.compute_reward(
inputs.gather(1, actions.unsqueeze(2).repeat(1, 1, 2))
)

return R, probs, actions
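A quick sanity check for compute_reward (a sketch, not part of the commit): walking the four corners of the unit square in order gives a tour length of 4. The None actor below is only a placeholder so the reward can be exercised on its own.

import torch
from models.solver import TSPSolver

solver = TSPSolver(actor=None)  # placeholder actor; only compute_reward is used here
square = torch.tensor([[[0., 0.], [0., 1.], [1., 1.], [1., 0.]]])  # (1, 4, 2)
print(solver.compute_reward(square))  # tensor([4.])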
121 changes: 121 additions & 0 deletions models/tsp_attention.py
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
from torch.distributions import Categorical

from models.layers import Glimpse, LinearEmbedding, Pointer


class AttentionBlock(nn.Module):
def __init__(
self,
embed_dim: int,
n_heads: int,
feed_forward_hidden: int = 512,
bn: bool = False
):
super(AttentionBlock, self).__init__()
self.mha = torch.nn.MultiheadAttention(embed_dim, n_heads)
self.embed = nn.Sequential(
nn.Linear(embed_dim, feed_forward_hidden),
nn.ReLU(),
nn.Linear(feed_forward_hidden, embed_dim)
)

def forward(self, x):
        # nn.MultiheadAttention expects (seq_len, batch_size, embed_dim) by default,
        # so permute before and after: https://pytorch.org/docs/stable/nn.html#multiheadattention
x = x.permute(1, 0, 2)
x = x + self.mha(x, x, x)[0]
x = x.permute(1, 0, 2)
x = x + self.embed(x)

return x


class AttentionModule(nn.Sequential):
def __init__(
self,
embed_dim: int,
n_heads: int,
feed_forward_hidden: int = 512,
        n_self_attentions: int = 2,
bn: bool = False
):
super(AttentionModule, self).__init__(
*(AttentionBlock(embed_dim, n_heads, feed_forward_hidden, bn) for _ in range(n_self_attentions))
)


class AttentionTSP(nn.Module):
def __init__(
self,
pos_size: int,
embed_size: int,
hidden_size: int,
seq_len: int,
n_head: int = 4,
C: float = 10
):
super(AttentionTSP, self).__init__()

self.embedding_size = embed_size
self.hidden_size = hidden_size
self.seq_len = seq_len
self.n_head = n_head
self.C = C

self.embedding = LinearEmbedding(pos_size, embed_size)
self.mha = AttentionModule(embed_size, n_head)

self.init_w = nn.Parameter(torch.Tensor(pos_size * self.embedding_size))
self.init_w.data.uniform_(-1, 1)
self.glimpse = Glimpse(self.embedding_size,
self.hidden_size, self.n_head)
        self.pointer = Pointer(self.embedding_size,
                               self.hidden_size, self.C)

self.h_context_embed = nn.Linear(
self.embedding_size, self.embedding_size)
        # in_features is 2 * embedding_size when pos_size == 2, matching the
        # concatenation of the first and most recently chosen city embeddings
        self.v_weight_embed = nn.Linear(
            self.embedding_size * pos_size, self.embedding_size)

def forward(self, inputs):
# inputs: [batch_size x seq_len x 2]

batch_size = inputs.shape[0]
seq_len = inputs.shape[1]

embedded = self.embedding(inputs)
h = self.mha(embedded)
h_mean = h.mean(dim=1)
h_bar = self.h_context_embed(h_mean)
h_rest = self.v_weight_embed(self.init_w)
query = h_bar + h_rest

# init query
prev_chosen_indices = []
prev_chosen_logprobs = []
first_chosen_hs = None
        mask = torch.zeros(batch_size, self.seq_len,
                           dtype=torch.bool, device=inputs.device)

for index in range(self.seq_len):
_, n_query = self.glimpse(query, h, mask)
prob, _ = self.pointer(n_query, h, mask)
cat = Categorical(prob)
chosen = cat.sample()
logprobs = cat.log_prob(chosen)
prev_chosen_indices.append(chosen)
prev_chosen_logprobs.append(logprobs)

            mask[torch.arange(batch_size), chosen] = True
cc = chosen.unsqueeze(1).unsqueeze(2).repeat(
1, 1, self.embedding_size
)
if first_chosen_hs is None:
first_chosen_hs = h.gather(1, cc).squeeze(1)
chosen_hs = h.gather(1, cc).squeeze(1)
h_rest = self.v_weight_embed(
torch.cat([first_chosen_hs, chosen_hs], dim=-1))
query = h_bar + h_rest

return torch.stack(prev_chosen_logprobs, 1), torch.stack(prev_chosen_indices, 1)
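A minimal forward pass through AttentionTSP (a sketch, not part of the commit; the sizes are illustrative and assume 2-D city coordinates, i.e. pos_size=2):

import torch
from models.tsp_attention import AttentionTSP

model = AttentionTSP(pos_size=2, embed_size=128, hidden_size=128, seq_len=20, n_head=4)
coords = torch.rand(4, 20, 2)       # 4 instances with 20 cities each
log_probs, tour = model(coords)
print(log_probs.shape, tour.shape)  # torch.Size([4, 20]) torch.Size([4, 20])

Each row of tour is a permutation of city indices and log_probs holds the per-step log-probabilities of those choices.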
92 changes: 92 additions & 0 deletions models/tsp_rnn.py
@@ -0,0 +1,92 @@
import math
import torch
import numpy as np
from torch import nn
from torch.distributions import Categorical

from models.layers import Attention, LinearEmbedding


class RNNTSP(nn.Module):
def __init__(
self,
pos_size: int,
embed_size: int,
hidden_size: int,
seq_len: int,
n_glimpses: int,
tanh_exploration: float
):
super(RNNTSP, self).__init__()

self.embed_size = embed_size
self.hidden_size = hidden_size
self.n_glimpses = n_glimpses
self.seq_len = seq_len

self.embedding = LinearEmbedding(pos_size, embed_size)
self.encoder = nn.LSTM(embed_size, hidden_size, batch_first=True)
self.decoder = nn.LSTM(embed_size, hidden_size, batch_first=True)
self.pointer = Attention(hidden_size, C=tanh_exploration)
self.glimpse = Attention(hidden_size)

self.dec_sos = nn.Parameter(torch.FloatTensor(embed_size))
self.dec_sos.data.uniform_(-(1. / math.sqrt(embed_size)),
1. / math.sqrt(embed_size))

def apply_mask_to_logits(self, logits, mask, idxs):
batch_size = logits.size(0)
clone_mask = mask.clone()
if idxs is not None:
            clone_mask[torch.arange(batch_size), idxs] = True
logits[clone_mask] = -np.inf

return logits, clone_mask

def forward(self, inputs):
# inputs: (batch_size x seq_len x 2)

batch_size = inputs.shape[0]
seq_len = inputs.shape[1]

embedded = self.embedding(inputs)
encoder_outputs, (hidden, context) = self.encoder(embedded)

prev_chosen_logprobs = []
        prev_chosen_indices = []
        mask = torch.zeros(batch_size, self.seq_len,
                           dtype=torch.bool, device=inputs.device)

decoder_input = self.dec_sos
decoder_input = decoder_input.unsqueeze(0).repeat(batch_size, 1)
for index in range(seq_len):
_, (hidden, context) = self.decoder(
decoder_input.unsqueeze(1), (hidden, context)
)

query = hidden.squeeze(0)
for _ in range(self.n_glimpses):
ref, logits = self.glimpse(query, encoder_outputs)
                _mask = mask.clone()
                logits[_mask] = -np.inf
query = torch.matmul(
ref.transpose(-1, -2),
torch.softmax(logits, dim=-1).unsqueeze(-1)
).squeeze(-1)

_, logits = self.pointer(query, encoder_outputs)

_mask = mask.clone()
logits[_mask] = -np.inf
probs = torch.softmax(logits, dim=-1)
cat = Categorical(probs)
chosen = cat.sample()
            mask[torch.arange(batch_size), chosen] = True
log_probs = cat.log_prob(chosen)
            decoder_input = embedded.gather(
                1, chosen[:, None, None].repeat(1, 1, self.embed_size)
            ).squeeze(1)  # embedding of the city chosen at this step
prev_chosen_logprobs.append(log_probs)
            prev_chosen_indices.append(chosen)

        return torch.stack(prev_chosen_logprobs, 1), torch.stack(prev_chosen_indices, 1)
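The RNN variant runs the same way (again a sketch with illustrative sizes and 2-D coordinates):

import torch
from models.tsp_rnn import RNNTSP

model = RNNTSP(pos_size=2, embed_size=128, hidden_size=128, seq_len=20,
               n_glimpses=1, tanh_exploration=10)
coords = torch.rand(4, 20, 2)
log_probs, tour = model(coords)
print(log_probs.shape, tour.shape)  # torch.Size([4, 20]) torch.Size([4, 20])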
