crnn.py
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
import unidecode
from torch.utils.data import DataLoader
from tqdm import tqdm

all_characters = string.printable
n_characters = len(all_characters)
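# string.printable is the 100 printable ASCII characters (digits, letters,
# punctuation, whitespace); each character of the training text is encoded
# as its index in this string.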
### Copied from the AllenNLP library ==> https://allennlp.org/
class InputVariationalDropout(torch.nn.Dropout):
    """
    Apply the dropout technique in Gal and Ghahramani, "Dropout as a Bayesian Approximation:
    Representing Model Uncertainty in Deep Learning" (https://arxiv.org/abs/1506.02142) to a
    3D tensor.

    This module accepts a 3D tensor of shape ``(batch_size, num_timesteps, embedding_dim)``
    and samples a single dropout mask of shape ``(batch_size, embedding_dim)`` and applies
    it to every time step.
    """

    def forward(self, input_tensor):
        # pylint: disable=arguments-differ
        """
        Apply dropout to the input tensor.

        Parameters
        ----------
        input_tensor : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_timesteps, embedding_dim)``

        Returns
        -------
        output : ``torch.FloatTensor``
            A tensor of shape ``(batch_size, num_timesteps, embedding_dim)`` with dropout applied.
        """
        # Sample one mask per sequence and broadcast it across all timesteps.
        ones = input_tensor.data.new_ones(input_tensor.shape[0], input_tensor.shape[-1])
        dropout_mask = torch.nn.functional.dropout(ones, self.p, self.training, inplace=False)
        if self.inplace:
            input_tensor *= dropout_mask.unsqueeze(1)
            return None
        return dropout_mask.unsqueeze(1) * input_tensor
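

def _variational_dropout_demo():
    # Hypothetical sanity-check helper (not part of the original file): unlike
    # plain nn.Dropout, the sampled mask is shared across timesteps, so whole
    # embedding dimensions are dropped for the entire sequence.
    vd = InputVariationalDropout(0.5)
    vd.train()  # dropout is only active in training mode
    x = torch.ones(2, 4, 8)  # (batch_size, num_timesteps, embedding_dim)
    y = vd(x)
    assert torch.equal(y[:, 0], y[:, 1])  # identical mask at every timestep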
class RNN(nn.Module):
    def __init__(self, hidden_size, n_layers=1, dropout=0.5, rnn_cell=nn.RNN):
        """
        Create the network.
        """
        super(RNN, self).__init__()
        self.n_char = n_characters
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout = InputVariationalDropout(dropout, inplace=False)
        # (batch, chunk_len) -> (batch, chunk_len, hidden_size)
        self.embed = nn.Embedding(self.n_char, hidden_size)
        # (batch, chunk_len, hidden_size) -> (batch, chunk_len, hidden_size)
        self.rnns = nn.ModuleList(
            [rnn_cell(hidden_size, hidden_size, batch_first=True) for _ in range(n_layers)]
        )
        # (batch, chunk_len, hidden_size) -> (batch, chunk_len, output_size)
        self.predict = nn.Linear(hidden_size, self.n_char)

    def forward(self, input):
        """
        Batched forward: input is (batch, chunk_len).
        """
        output = self.embed(input)
        for rnn in self.rnns:
            do = self.dropout(output)
            output, _ = rnn(do)
        output = self.predict(self.dropout(output))
        return output

    def forward_seq(self, input, hiddens=None):
        """
        Unbatched forward: input is (chunk_len,); the hidden state of each
        layer is returned so generation can resume one character at a time.
        """
        output = self.embed(input)
        output = output.unsqueeze(0)
        new_hiddens = []
        if hiddens is None:
            hiddens = [None for _ in self.rnns]
        for rnn, hidden in zip(self.rnns, hiddens):
            output, hidden = rnn(output, hidden)
            new_hiddens.append(hidden)
        output = self.predict(output)
        return output, new_hiddens
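

def _rnn_shape_demo():
    # Hypothetical helper (not part of the original file) sketching the two
    # forward paths: `forward` consumes a whole batch of chunks at once, while
    # `forward_seq` consumes one character and threads the per-layer state.
    model = RNN(hidden_size=32, n_layers=2, rnn_cell=nn.LSTM)
    batch = torch.randint(0, model.n_char, (4, 10))  # (batch, chunk_len)
    assert model(batch).shape == (4, 10, model.n_char)
    step = torch.randint(0, model.n_char, (1,))  # a single character index
    out, hiddens = model.forward_seq(step)
    out, hiddens = model.forward_seq(step, hiddens)  # resume from the state
    assert out.shape == (1, 1, model.n_char)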
class CharRNN():
    def __init__(self, filename, load=None, device='cpu', hidden_size=512, n_layers=3, rnn_cell=nn.LSTM):
        self.model = RNN(hidden_size, n_layers=n_layers, rnn_cell=rnn_cell)
        self.device = device
        self.model.to(device)
        with open(filename) as fp:
            self.file = unidecode.unidecode(fp.read())  # clean text => ASCII only
        self.file_len = len(self.file)
        self.tensor_file = self.file2tensor()
        self.checkpoint_file = "charnn.chkpt"
        print('file_len =', self.file_len)
        if load:
            self.load(load)

    def save(self, path):
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        self.model.load_state_dict(torch.load(path, map_location=self.device))
    #### GENERATION #####
    def generate(self, prime_str='.', predict_len=100, temperature=0.8):
        prime_input = self.char_tensor(prime_str).squeeze(0)
        hidden = None
        predicted = prime_str
        # Use the priming string to "build up" the hidden state
        for p in range(len(prime_str) - 1):
            _, hidden = self.model.forward_seq(prime_input[p].unsqueeze(0), hidden)
        for p in range(predict_len):
            output, hidden = self.model.forward_seq(prime_input[-1].unsqueeze(0), hidden)
            # Sample from the network output as a multinomial distribution
            output_dist = output.data.view(-1).div(temperature).exp()
            top_i = torch.multinomial(output_dist, 1)[0].item()
            # Add the predicted character to the string and use it as next input
            predicted_char = all_characters[top_i]
            predicted += predicted_char
            prime_input = torch.cat([prime_input, self.char_tensor(predicted_char).squeeze(0)])
        return predicted
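
    @staticmethod
    def _temperature_demo():
        # Hypothetical helper (not part of the original file) sketching the
        # sampling rule used in generate(): logits are divided by the
        # temperature before exponentiating, so low temperatures sharpen the
        # distribution and high temperatures flatten it.
        logits = torch.tensor([2.0, 1.0, 0.1])
        for temperature in (0.5, 1.0, 2.0):
            output_dist = logits.div(temperature).exp()  # unnormalized weights
            top_i = torch.multinomial(output_dist, 1)[0]  # multinomial normalizes
            assert 0 <= top_i < 3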
    ########## DATA ##########
    def file2tensor(self):
        """Map the whole file to a tensor of character indices (longs)."""
        return torch.LongTensor([all_characters.index(x) for x in self.file]).to(self.device)

    def training_set_tensor(self, chunk_len):
        """Cut the encoded file into chunks of chunk_len + 1 characters."""
        remains = self.tensor_file.size(0) % (chunk_len + 1)
        if remains:  # drop the tail that does not fill a whole chunk
            return self.tensor_file[:-remains].view(-1, chunk_len + 1)
        return self.tensor_file.view(-1, chunk_len + 1)
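
    @staticmethod
    def _chunking_demo():
        # Hypothetical helper (not part of the original file) sketching the
        # chunking above: rows are chunk_len + 1 long so that the input
        # (t[:, :-1]) and the target (t[:, 1:]) are the same text shifted
        # by one character.
        chunk_len = 4
        encoded = torch.arange(13)  # stand-in for tensor_file
        remains = encoded.size(0) % (chunk_len + 1)  # 13 % 5 == 3
        view = encoded[:-remains].view(-1, chunk_len + 1)
        assert view.shape == (2, 5)  # two chunks of 5; the last 3 chars are dropped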
    def char_tensor(self, s):
        """Turn a string into a (1, len(s)) tensor of character indices."""
        tensor = torch.zeros(1, len(s), device=self.device).long()
        for c in range(len(s)):
            tensor[0, c] = all_characters.index(s[c])
        return tensor
    #### Training ####
    def train_one(self, inp, target):
        """
        Train on one chunk.
        """
        self.model.train()
        # reset gradients
        self.model_optimizer.zero_grad()
        # predict output
        output = self.model(inp)
        # compute loss over every position of every chunk
        loss = F.cross_entropy(output.view(output.size(0) * output.size(1), -1), target.view(-1))
        # compute gradients and backpropagate
        loss.backward()
        self.model_optimizer.step()
        self.model.eval()
        return loss.item()
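
    @staticmethod
    def _loss_shape_demo():
        # Hypothetical helper (not part of the original file) sketching the
        # reshape in train_one: F.cross_entropy expects (N, C) logits and (N,)
        # targets, so the (batch, chunk_len, n_char) output is flattened over
        # batch * chunk_len positions.
        logits = torch.randn(16, 64, n_characters)
        target = torch.randint(0, n_characters, (16, 64))
        loss = F.cross_entropy(logits.view(16 * 64, -1), target.view(-1))
        assert loss.dim() == 0  # a single scalar averaged over all positions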
    def train(self, iterations=1, chunk_len=110, batch_size=16, print_each=100):
        self.model_optimizer = torch.optim.Adam(self.model.parameters())
        train_file = self.training_set_tensor(chunk_len)
        data = DataLoader(train_file, batch_size=batch_size, shuffle=True)
        iters = 0
        with tqdm(total=iterations, desc=f"training - chunks of len {chunk_len}") as pbar:
            while iters < iterations:
                for t in data:
                    # input is the chunk minus its last char; target is the chunk shifted by one
                    tr, te = t[:, :-1].contiguous(), t[:, 1:].contiguous()
                    loss = self.train_one(tr, te)  # train on one chunk
                    if iters % print_each == 0:
                        self.save(self.checkpoint_file)
                        print("-" * 25)
                        print(f"Generated text at iter {iters}")
                        print("-" * 25)
                        print(self.generate(temperature=0.8))
                        print("-" * 25)
                        print(f"model checkpointed in {self.checkpoint_file}")
                        print("")
                    iters += 1
                    pbar.update(1)
                    pbar.set_postfix({"loss": loss})
                    if iters >= iterations:
                        break
if __name__ == "__main__":
    crnn = CharRNN("input.txt", device="cpu", rnn_cell=nn.LSTM)
    crnn.load("charnn.chkpt")  # resume from an existing checkpoint
    crnn.train(200, batch_size=256, chunk_len=64)  # train for 200 iterations
    crnn.save("ss_512")