-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
33 changed files
with
2,236 additions
and
54 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
|
||
import numpy as np | ||
import torch | ||
from torch import nn | ||
from torch.autograd import Variable | ||
|
||
import torch.fft as afft | ||
|
||
|
||
class CompactBilinearPooling(nn.Module):
    """
    from https://github.com/DeepInsight-PCALab/CompactBilinearPooling-Pytorch
    Compute compact bilinear pooling over two bottom inputs.

    Args:
        output_dim: output dimension for compact bilinear pooling.
        sum_pool: (Optional) If True, sum the output along height and width
                  dimensions and return output shape [batch_size, output_dim].
                  Otherwise return [batch_size, height, width, output_dim].
                  Default: True.
        rand_h_1: (Optional) an 1D numpy array containing indices in interval
                  `[0, output_dim)`. Automatically generated from a fixed seed
                  if None.
        rand_s_1: (Optional) an 1D numpy array of 1 and -1, having the same
                  shape as `rand_h_1`. Automatically generated from a fixed
                  seed if None.
        rand_h_2: (Optional) an 1D numpy array containing indices in interval
                  `[0, output_dim)`. Automatically generated from a fixed seed
                  if None.
        rand_s_2: (Optional) an 1D numpy array of 1 and -1, having the same
                  shape as `rand_h_2`. Automatically generated from a fixed
                  seed if None.
    """

    def __init__(self, input_dim1, input_dim2, output_dim,
                 sum_pool=True, cuda=True,
                 rand_h_1=None, rand_s_1=None, rand_h_2=None, rand_s_2=None):
        super(CompactBilinearPooling, self).__init__()
        self.input_dim1 = input_dim1
        self.input_dim2 = input_dim2
        self.output_dim = output_dim
        self.sum_pool = sum_pool

        # Use local RandomState objects instead of np.random.seed so that
        # constructing this module does not clobber NumPy's global RNG state.
        # RandomState(seed) yields the same MT19937 sequence as
        # np.random.seed(seed) followed by global draws, so the generated
        # sketch matrices are unchanged.
        if rand_h_1 is None:
            rand_h_1 = np.random.RandomState(1).randint(
                output_dim, size=self.input_dim1)
        if rand_s_1 is None:
            rand_s_1 = 2 * np.random.RandomState(3).randint(
                2, size=self.input_dim1) - 1

        if rand_h_2 is None:
            rand_h_2 = np.random.RandomState(5).randint(
                output_dim, size=self.input_dim2)
        if rand_s_2 is None:
            rand_s_2 = 2 * np.random.RandomState(7).randint(
                2, size=self.input_dim2) - 1

        # Register the (non-trainable) sketch matrices as buffers so they
        # move with .to(device)/.cuda() and are saved in state_dict.
        self.register_buffer(
            'sparse_sketch_matrix1',
            self.generate_sketch_matrix(rand_h_1, rand_s_1, self.output_dim))
        self.register_buffer(
            'sparse_sketch_matrix2',
            self.generate_sketch_matrix(rand_h_2, rand_s_2, self.output_dim))

        if cuda:
            self.sparse_sketch_matrix1 = self.sparse_sketch_matrix1.cuda()
            self.sparse_sketch_matrix2 = self.sparse_sketch_matrix2.cuda()

    def forward(self, bottom1, bottom2):
        """
        bottom1: 1st input, 4D Tensor of shape [batch_size, input_dim1, height, width].
        bottom2: 2nd input, 4D Tensor of shape [batch_size, input_dim2, height, width].

        Returns [batch_size, output_dim] if sum_pool else
        [batch_size, height, width, output_dim].
        """
        assert bottom1.size(1) == self.input_dim1 and \
            bottom2.size(1) == self.input_dim2

        batch_size, _, height, width = bottom1.size()

        # Flatten spatial positions: each row is one (batch, h, w) feature.
        bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
        bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2)

        # Count sketch projection of each input into output_dim.
        sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
        sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)

        # Circular convolution of the two sketches via FFT:
        # ifft(fft(a) * fft(b)) approximates the outer-product pooling.
        fft1 = afft.fft(sketch_1)
        fft2 = afft.fft(sketch_2)

        fft_product = fft1 * fft2

        cbp_flat = afft.ifft(fft_product).real

        cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

        if self.sum_pool:
            cbp = cbp.sum(dim=1).sum(dim=1)

        return cbp

    @staticmethod
    def generate_sketch_matrix(rand_h, rand_s, output_dim):
        """
        Return a dense matrix used for the tensor (count) sketch operation in
        compact bilinear pooling.

        Args:
            rand_h: an 1D numpy array containing indices in interval `[0, output_dim)`.
            rand_s: an 1D numpy array of 1 and -1, having the same shape as `rand_h`.
            output_dim: the output dimensions of compact bilinear pooling.
        Returns:
            a matrix of shape [input_dim, output_dim] for tensor sketch, with
            rand_s[i] at position (i, rand_h[i]) and zeros elsewhere.
        """
        # Generate a sparse matrix for tensor count sketch.
        rand_h = rand_h.astype(np.int64)
        rand_s = rand_s.astype(np.float32)
        assert (rand_h.ndim == 1 and rand_s.ndim ==
                1 and len(rand_h) == len(rand_s))
        assert (np.all(rand_h >= 0) and np.all(rand_h < output_dim))

        input_dim = len(rand_h)
        # COO indices: row i maps input feature i to output slot rand_h[i].
        indices = np.concatenate((np.arange(input_dim)[..., np.newaxis],
                                  rand_h[..., np.newaxis]), axis=1)
        indices = torch.from_numpy(indices)
        rand_s = torch.from_numpy(rand_s)
        # torch.sparse_coo_tensor replaces the deprecated
        # torch.sparse.FloatTensor constructor.
        sparse_sketch_matrix = torch.sparse_coo_tensor(
            indices.t(), rand_s, torch.Size([input_dim, output_dim]))
        return sparse_sketch_matrix.to_dense()
|
||
|
||
if __name__ == '__main__':
    # Smoke test: pool a 512-channel and a 128-channel 14x14 feature map
    # into a 512-dim descriptor per batch element.
    # torch.autograd.Variable is a deprecated no-op wrapper since PyTorch 0.4;
    # plain tensors behave identically.
    bottom1 = torch.randn(3, 512, 14, 14)
    bottom2 = torch.randn(3, 128, 14, 14)

    layer = CompactBilinearPooling(512, 128, 512, cuda=False)
    layer.train()

    out = layer(bottom1, bottom2)
    print(out.shape)  # expected: torch.Size([3, 512])
|
||
|
||
|
||
|
File renamed without changes.
File renamed without changes.
File renamed without changes.
Oops, something went wrong.