re-organize noisy_env_eos_separated
wedddy0707 committed Feb 6, 2021
1 parent 5026d71 commit bc7b275
Showing 4 changed files with 45 additions and 258 deletions.
29 changes: 0 additions & 29 deletions noisy_env_eos_separated/channel.py

This file was deleted.

75 changes: 17 additions & 58 deletions noisy_env_eos_separated/reinforce_wrappers.py
@@ -11,7 +11,6 @@

from egg.core.baselines import MeanBaseline

from rnn import RnnEncoder # Note that this 'rnn' is not the file in EGG
from util import find_lengths # Note that this 'util' is not the file in EGG


@@ -67,19 +66,14 @@ def __init__(
def reset_parameters(self):
nn.init.normal_(self.sos_embedding, 0.0, 0.01)

def add_noise(self, x):
e = torch.randn_like(x).to(x.device)
return x + self.noise_loc + e * self.noise_scale

def forward(self, x):
prev_hidden = [self.agent(x)]
prev_hidden.extend([torch.zeros_like(prev_hidden[0])
for _ in range(self.num_layers - 1)])

prev_h = [self.agent(x)]
prev_h.extend([
torch.zeros_like(prev_h[0]) for _ in range(self.num_layers - 1)
])
prev_c = [
torch.zeros_like(
prev_hidden[0]) for _ in range(
self.num_layers)] # only used for LSTM
torch.zeros_like(prev_h[0]) for _ in range(self.num_layers)
] # only used for LSTM

input = torch.stack([self.sos_embedding] * x.size(0))

@@ -92,18 +86,18 @@ def forward(self, x):

for step in range(self.max_len):
for i, layer in enumerate(self.cells):
if self.training:
if isinstance(layer, nn.LSTMCell):
prev_c[i] = self.add_noise(prev_c[i])
else:
prev_hidden[i] = self.add_noise(prev_hidden[i])

e_t = float(self.training) * (
self.noise_loc +
self.noise_scale * torch.randn_like(prev_h[0]).to(prev_h[0])
)
if isinstance(layer, nn.LSTMCell):
h_t, c_t = layer(input, (prev_hidden[i], prev_c[i]))
h_t, c_t = layer(input, (prev_h[i], prev_c[i]))
c_t = c_t + e_t
prev_c[i] = c_t
else:
h_t = layer(input, prev_hidden[i])
prev_hidden[i] = h_t
h_t = layer(input, prev_h[i])
h_t = h_t + e_t
prev_h[i] = h_t
input = h_t

symb_probs = F.softmax(self.output_symbol(h_t), dim=1)
@@ -130,44 +124,10 @@ def forward(self, x):
symb_entropy = torch.stack(symb_entropy).permute(1, 0)
stop_entropy = torch.stack(stop_entropy).permute(1, 0)

sequence = (symb_seq, stop_seq)
logits = (symb_logits, stop_logits)
entropy = (symb_entropy, stop_entropy)

return sequence, logits, entropy


class RnnReceiverDeterministic(nn.Module):
def __init__(self,
agent,
vocab_size,
embed_dim,
hidden_size,
cell='rnn',
num_layers=1,
noise_loc=0.0,
noise_scale=0.0,
):
super(RnnReceiverDeterministic, self).__init__()
self.agent = agent
self.encoder = RnnEncoder(
vocab_size,
embed_dim,
hidden_size,
cell,
num_layers,
noise_loc=noise_loc,
noise_scale=noise_scale,
)

def forward(self, message, input=None, lengths=None):
encoded = self.encoder(message)
agent_output = self.agent(encoded, input)

logits = torch.zeros(agent_output.size(0)).to(agent_output.device)
entropy = logits

return agent_output, logits, entropy
return symb_seq, stop_seq, logits, entropy


class SenderReceiverRnnReinforce(nn.Module):
@@ -203,8 +163,7 @@ def forward(self, sender_input, labels, receiver_input=None):
######################################
# Forward Propagation through Sender #
######################################
seq_s, logprob_s, entropy_s = self.sender(sender_input)
symb_seq_s, stop_seq_s = seq_s
symb_seq_s, stop_seq_s, logprob_s, entropy_s = self.sender(sender_input)
symb_logprob_s, stop_logprob_s = logprob_s
symb_entropy_s, stop_entropy_s = entropy_s

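The rewritten forward pass above drops the add_noise helper and its if self.training: guard in favour of a single term gated by float(self.training), which is 1.0 in training mode and 0.0 in eval mode. Below is a minimal, self-contained sketch of that gating pattern; the NoisyGRUCell wrapper and its sizes are illustrative only, and just the e_t expression mirrors the diff.

import torch
import torch.nn as nn

class NoisyGRUCell(nn.Module):
    """Illustrative wrapper: additive Gaussian noise on the hidden state,
    applied only while the module is in training mode."""

    def __init__(self, input_size, hidden_size, noise_loc=0.0, noise_scale=0.1):
        super().__init__()
        self.cell = nn.GRUCell(input_size, hidden_size)
        self.noise_loc = noise_loc
        self.noise_scale = noise_scale

    def forward(self, x, h):
        # float(self.training) gates the noise, so no explicit branch is needed.
        e_t = float(self.training) * (
            self.noise_loc + self.noise_scale * torch.randn_like(h)
        )
        return self.cell(x, h) + e_t

Calling .eval() on the wrapper makes e_t identically zero, matching the behaviour of the deleted guard.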
153 changes: 0 additions & 153 deletions noisy_env_eos_separated/rnn.py

This file was deleted.

46 changes: 28 additions & 18 deletions noisy_env_eos_separated/train.py
@@ -14,11 +14,16 @@
from egg.zoo.channel.archs import Sender, Receiver
from egg.zoo.channel.train import loss

from channel import Channel
from util import find_lengths
from reinforce_wrappers import RnnSenderReinforce
from reinforce_wrappers import RnnReceiverDeterministic
from reinforce_wrappers import SenderReceiverRnnReinforce
from util import find_lengths

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from common import Channel # noqa: E402
from common import RnnReceiverDeterministic # noqa: E402


def get_params(params):
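train.py now reaches Channel and RnnReceiverDeterministic through a shared common package instead of the local channel.py and reinforce_wrappers.py modules, by appending the parent of this directory to sys.path. A small sketch of what the two dirname calls resolve to; the /repo checkout path is an assumption, and only the dirname(dirname(__file__)) pattern comes from the diff.

import os

# Hypothetical checkout location of this file.
train_py = "/repo/noisy_env_eos_separated/train.py"

# dirname applied twice climbs from the file to the repository root.
repo_root = os.path.dirname(os.path.dirname(train_py))
assert repo_root == "/repo"

# After sys.path.append(repo_root), "import common" is resolved against
# /repo/common/, a sibling package whose contents are not part of this commit.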
@@ -156,23 +161,26 @@ def suffix_test(game, n_features, device):

with torch.no_grad():
inputs = torch.eye(n_features).to(device)
seq, _, _ = game.sender(inputs)
messages, stop_seq = seq
lengths = find_lengths(stop_seq)
sender_output = game.sender(inputs)
messages = sender_output[0]
stop_seq = sender_output[1]

for i, m, ln in zip(inputs, messages, lengths):
i_symbol = i.argmax().item()
for m_idx in range(ln.item()):
p = torch.stack([m[0:m_idx + 1]])
o, _, _ = game.receiver(p, None, torch.tensor([m_idx + 1]))
o_symbol = o.argmax().item()
for i, m, stop in zip(inputs, messages, stop_seq):
for m_idx in range(m.size(0)):
prefix = m[0:m_idx + 1]
o = game.receiver(torch.stack([prefix]), lengths=torch.tensor([m_idx + 1]))
o = o[0]

dump_message = (
f'input: {i_symbol} -> '
f'message: {",".join([str(p[0,i].item()) for i in range(m_idx + 1)])} -> '
f'output: {o_symbol}')
f'input: {i.argmax().item()} -> '
f'prefix: {",".join([str(prefix[i].item()) for i in range(prefix.size(0))])} -> '
f'output: {o.argmax().item()}'
)
print(dump_message, flush=True)

if stop[m_idx].item():
break

game.train(mode=train_state)


@@ -182,8 +190,9 @@ def dump(game, n_features, device):
train_state = game.training # saved so we can restore it afterwards
game.eval()

seq, _, _ = game.sender(inputs)
messages, stop_seq = seq
sender_output = game.sender(inputs)
messages = sender_output[0]
stop_seq = sender_output[1]
lengths = find_lengths(stop_seq)
outputs, _, _ = game.receiver(messages, None, lengths)

@@ -205,7 +214,8 @@
dump_message = (
f'input: {i_symbol.item()} -> '
f'message: {",".join([str(m[i].item()) for i in range(ln.item())])} -> '
f'output: {o_symbol.item()}')
f'output: {o_symbol.item()}'
)
print(dump_message, flush=True)

uniform_acc /= n_features
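Both suffix_test and dump wrap their evaluation in the same save/eval/restore steps around game.training. The helper below packages that pattern as a context manager; it is not part of the repository, only a sketch of what the two functions do by hand.

import contextlib
import torch.nn as nn

@contextlib.contextmanager
def eval_mode(module: nn.Module):
    # Temporarily switch the module to eval mode, then restore the
    # training flag it had on entry (hypothetical helper).
    was_training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(mode=was_training)

# Usage sketch:
#     with eval_mode(game):
#         ...  # run the suffix_test / dump style evaluation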
