import torch
from torchtext import data
import numpy as np
from torch.autograd import Variable


def nopeak_mask(size, opt):
    "Build the no-peek (subsequent-position) mask used by the decoder."
    np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    if opt.use_cond2dec:
        # When the conditions are fed to the decoder, extend the mask so that
        # condition positions attend to each other (and to the first target
        # position), while target positions attend to every condition plus
        # earlier target positions.
        cond_mask = np.zeros((1, opt.cond_dim, opt.cond_dim))
        cond_mask_upperright = np.ones((1, opt.cond_dim, size))
        cond_mask_upperright[:, :, 0] = 0
        cond_mask_lowerleft = np.zeros((1, size, opt.cond_dim))
        upper_mask = np.concatenate([cond_mask, cond_mask_upperright], axis=2)
        lower_mask = np.concatenate([cond_mask_lowerleft, np_mask], axis=2)
        np_mask = np.concatenate([upper_mask, lower_mask], axis=1)
    # Entries equal to 0 become True (attention allowed), 1 becomes False (blocked).
    np_mask = Variable(torch.from_numpy(np_mask) == 0)
    if opt.device == 0:
        np_mask = np_mask.cuda()
    return np_mask
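
# For intuition: without conditions, nopeak_mask produces the usual causal mask,
# i.e. position i may attend only to positions <= i. A minimal standalone sketch
# of the same construction (no `opt` needed, CPU only):
#
#   m = torch.from_numpy(np.triu(np.ones((1, 3, 3)), k=1).astype('uint8')) == 0
#   # tensor([[[ True, False, False],
#   #          [ True,  True, False],
#   #          [ True,  True,  True]]])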


def create_masks(src, trg, cond, opt):
    torch.set_printoptions(profile="full")  # debugging aid: print tensors in full
    # Source mask: hide padding; the prepended condition positions are always visible.
    src_mask = (src != opt.src_pad).unsqueeze(-2)
    cond_mask = torch.unsqueeze(cond, -2)
    cond_mask = torch.ones_like(cond_mask, dtype=torch.bool)
    src_mask = torch.cat([cond_mask, src_mask], dim=2)

    if trg is not None:
        # Target mask: hide padding, then combine with the no-peek mask.
        trg_mask = (trg != opt.trg_pad).unsqueeze(-2)
        if opt.use_cond2dec:
            trg_mask = torch.cat([cond_mask, trg_mask], dim=2)
        np_mask = nopeak_mask(trg.size(1), opt)
        if trg.is_cuda:
            np_mask = np_mask.cuda()
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
    return src_mask, trg_mask
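
# Resulting shapes (B = batch size, C = opt.cond_dim, S / T = source / target length):
#   src_mask: (B, 1, C + S)                 -- padding hidden, conditions always visible
#   trg_mask: (B, T, T) if use_cond2dec is False, else (B, C + T, C + T)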


# Patch on torchtext's batching process that makes it more efficient.
# From http://nlp.seas.harvard.edu/2018/04/03/attention.html#position-wise-feed-forward-networks
class MyIterator(data.Iterator):
    def create_batches(self):
        if self.train:
            def pool(d, random_shuffler):
                # Pull 100 batches worth of examples at a time, sort them by length,
                # split them into batches and yield those batches in shuffled order.
                for p in data.batch(d, self.batch_size * 100):
                    p_batch = data.batch(sorted(p, key=self.sort_key),
                                         self.batch_size, self.batch_size_fn)
                    for b in random_shuffler(list(p_batch)):
                        yield b
            self.batches = pool(self.data(), self.random_shuffler)
        else:
            self.batches = []
            for b in data.batch(self.data(), self.batch_size,
                                self.batch_size_fn):
                self.batches.append(sorted(b, key=self.sort_key))


global max_src_in_batch, max_tgt_in_batch


def batch_size_fn(new, count, sofar):
    "Keep augmenting batch and calculate total number of tokens + padding."
    global max_src_in_batch, max_tgt_in_batch
    if count == 1:
        max_src_in_batch = 0
        max_tgt_in_batch = 0
    max_src_in_batch = max(max_src_in_batch, len(new.src))
    max_tgt_in_batch = max(max_tgt_in_batch, len(new.trg) + 2)  # +2 for start/end tokens
    src_elements = count * max_src_in_batch
    tgt_elements = count * max_tgt_in_batch
    return max(src_elements, tgt_elements)
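
# A minimal usage sketch for wiring MyIterator together with the token-count based
# batch_size_fn above. The names `train_data`, `opt.batchsize` and the sort key are
# illustrative assumptions, not part of this file; it assumes the legacy torchtext
# data API that this module imports.
#
#   train_iter = MyIterator(train_data, batch_size=opt.batchsize, device=opt.device,
#                           repeat=False, sort_key=lambda x: (len(x.src), len(x.trg)),
#                           batch_size_fn=batch_size_fn, train=True, shuffle=True)
#   for batch in train_iter:
#       ...  # batch.src / batch.trg / conditions feed create_masks and the model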