Skip to content

Commit

Permalink
add multi-label demos
Browse files Browse the repository at this point in the history
  • Loading branch information
WAMAWAMA committed Jan 23, 2023
1 parent b84b432 commit 16c8701
Show file tree
Hide file tree
Showing 33 changed files with 2,236 additions and 54 deletions.
5 changes: 5 additions & 0 deletions .idea/codeStyles/codeStyleConfig.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

48 changes: 34 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ pip install git+https://github.com/rwightman/pytorch-image-models.git

## 2. Update list
- 2022/11/11: The birthday of this code, version `v0.0.1`
- 2023/01/23: Add [demo](demo/multi_label) code of 6 multi-label network structures (Happy Chinese New Year 🎇)
- ...


Expand Down Expand Up @@ -305,11 +306,11 @@ if __name__ == '__main__':
label_category_dict = dict(shape=4, color=3, other=13)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of shape : torch.Size([2, 4])
# logits of color : torch.Size([2, 3])
# logits of other : torch.Size([2, 13])
Expand Down Expand Up @@ -370,11 +371,11 @@ if __name__ == '__main__':
label_category_dict = dict(organ=3)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
```
</details>
Expand Down Expand Up @@ -433,11 +434,11 @@ if __name__ == '__main__':
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])
```
Expand Down Expand Up @@ -516,13 +517,13 @@ if __name__ == '__main__':
seg_label_category_dict=seg_label_category_dict,
dim=3)
seg_logits, cls_logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()]
print('-'*30)
_ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# seg logits of organ : torch.Size([2, 3, 128, 128, 128])
# seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
# ------------------------------
Expand Down Expand Up @@ -599,11 +600,11 @@ if __name__ == '__main__':
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])
```
Expand Down Expand Up @@ -717,11 +718,11 @@ if __name__ == '__main__':
model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
with torch.no_grad():
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 256, 256])
# logits of tumor : torch.Size([2, 4, 256, 256])
```
Expand Down Expand Up @@ -836,11 +837,11 @@ if __name__ == '__main__':
model = TransUnet(in_channel=1, label_category_dict=label_category_dict, dim=3)
with torch.no_grad():
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 96])
# logits of tumor : torch.Size([2, 4, 128, 128, 96])
```
Expand All @@ -849,6 +850,25 @@ if __name__ == '__main__':





<details>
<summary> Demo: Multi-label network structure 🟢 </summary>

6 different novel multi-label network structures

|Network| Publication | Demo code | Paper link|
|---|---|---|---|
|CNNRNN|CVPR2016|[code](demo/multi_label/Demo_CVPR2016_MultiLabel_CNNRNN.py)|[link](http:https://openaccess.thecvf.com/content_cvpr_2016/html/Wang_CNN-RNN_A_Unified_CVPR_2016_paper.html)|
|ML-GCN|CVPR2019|[code](demo/multi_label/Demo_CVPR2019_MultiLabel_ML_GCN.py)|[link](https://arxiv.org/abs/1904.03582)|
|SSGRL|ICCV2019|[code](demo/multi_label/Demo_ICCV2019_MultiLabel_SSGRL.py)|[link](https://arxiv.org/abs/1908.07325)|
|C-tran|CVPR2021|[code](demo/multi_label/Demo_CVPR2021_MultiLabel_C_tran.py)|[link](https://arxiv.org/abs/2011.14027)|
|ML-decoder|arxiv2021|[code](demo/multi_label/Demo_Arxiv2021_MultiLabel_ML_decoder.py)|[link](http:https://arxiv.org/abs/2111.12933)|
|Q2L|arxiv2021|[code](demo/multi_label/Demo_ArXiv2021_MultiLabel_Query2Label.py)|[link](https://arxiv.org/abs/2107.10834)|

</details>


*Todo-demo list ( 🚧 under preparation and coming soon...) ↓


Expand Down
3 changes: 2 additions & 1 deletion demo/Demo0_VGG_SingleLabelClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def forward(self, x):

if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(is_malignant=4)
category_num = 1
label_category_dict = dict(is_malignant=category_num)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('single-label predicted logits')
Expand Down
4 changes: 2 additions & 2 deletions demo/Demo2_ResNet_MultiLabelClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ def forward(self, x):
label_category_dict = dict(shape=4, color=3, other=13)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of shape : torch.Size([2, 4])
# logits of color : torch.Size([2, 3])
# logits of other : torch.Size([2, 13])
4 changes: 2 additions & 2 deletions demo/Demo3_ResNetUnet_SingleLabelSegmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def forward(self, x):
label_category_dict = dict(organ=3)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
4 changes: 2 additions & 2 deletions demo/Demo4_ResNetUnet_MultiLabelSegmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def forward(self, x):
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])

Expand Down
4 changes: 2 additions & 2 deletions demo/Demo5_MultiTask_SegAndCls.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def forward(self, x):
seg_label_category_dict=seg_label_category_dict,
dim=3)
seg_logits, cls_logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()]
print('-'*30)
_ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# seg logits of organ : torch.Size([2, 3, 128, 128, 128])
# seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
# ------------------------------
Expand Down
4 changes: 2 additions & 2 deletions demo/Demo6_UnetwithFPN_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def forward(self, x):
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])

Expand Down
4 changes: 2 additions & 2 deletions demo/Demo7_2D_TransUnet_Segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def forward(self, x):
model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
with torch.no_grad():
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 256, 256])
# logits of tumor : torch.Size([2, 4, 256, 256])
4 changes: 2 additions & 2 deletions demo/Demo8_3D_TransUnet_Segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def forward(self, x):
model = TransUnet(in_channel=1, label_category_dict=label_category_dict, dim=3)
with torch.no_grad():
logits = model(x)
print('multi-label predicted logits')
print('multi_label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# multi_label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 96])
# logits of tumor : torch.Size([2, 4, 128, 128, 96])

141 changes: 141 additions & 0 deletions demo/Demo_BilinearPooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable

import torch.fft as afft


class CompactBilinearPooling(nn.Module):
"""
from https://github.com/DeepInsight-PCALab/CompactBilinearPooling-Pytorch
Compute compact bilinear pooling over two bottom inputs.
Args:
output_dim: output dimension for compact bilinear pooling.
sum_pool: (Optional) If True, sum the output along height and width
dimensions and return output shape [batch_size, output_dim].
Otherwise return [batch_size, height, width, output_dim].
Default: True.
rand_h_1: (Optional) an 1D numpy array containing indices in interval
`[0, output_dim)`. Automatically generated from `seed_h_1`
if is None.
rand_s_1: (Optional) an 1D numpy array of 1 and -1, having the same shape
as `rand_h_1`. Automatically generated from `seed_s_1` if is
None.
rand_h_2: (Optional) an 1D numpy array containing indices in interval
`[0, output_dim)`. Automatically generated from `seed_h_2`
if is None.
rand_s_2: (Optional) an 1D numpy array of 1 and -1, having the same shape
as `rand_h_2`. Automatically generated from `seed_s_2` if is
None.
"""

def __init__(self, input_dim1, input_dim2, output_dim,
sum_pool=True, cuda=True,
rand_h_1=None, rand_s_1=None, rand_h_2=None, rand_s_2=None):
super(CompactBilinearPooling, self).__init__()
self.input_dim1 = input_dim1
self.input_dim2 = input_dim2
self.output_dim = output_dim
self.sum_pool = sum_pool

if rand_h_1 is None:
np.random.seed(1)
rand_h_1 = np.random.randint(output_dim, size=self.input_dim1)
if rand_s_1 is None:
np.random.seed(3)
rand_s_1 = 2 * np.random.randint(2, size=self.input_dim1) - 1

self.sparse_sketch_matrix1 = Variable(self.generate_sketch_matrix(
rand_h_1, rand_s_1, self.output_dim))

if rand_h_2 is None:
np.random.seed(5)
rand_h_2 = np.random.randint(output_dim, size=self.input_dim2)
if rand_s_2 is None:
np.random.seed(7)
rand_s_2 = 2 * np.random.randint(2, size=self.input_dim2) - 1

self.sparse_sketch_matrix2 = Variable(self.generate_sketch_matrix(
rand_h_2, rand_s_2, self.output_dim))

if cuda:
self.sparse_sketch_matrix1 = self.sparse_sketch_matrix1.cuda()
self.sparse_sketch_matrix2 = self.sparse_sketch_matrix2.cuda()

def forward(self, bottom1, bottom2):
"""
bottom1: 1st input, 4D Tensor of shape [batch_size, input_dim1, height, width].
bottom2: 2nd input, 4D Tensor of shape [batch_size, input_dim2, height, width].
"""
assert bottom1.size(1) == self.input_dim1 and \
bottom2.size(1) == self.input_dim2

batch_size, _, height, width = bottom1.size()

bottom1_flat = bottom1.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim1)
bottom2_flat = bottom2.permute(0, 2, 3, 1).contiguous().view(-1, self.input_dim2)

sketch_1 = bottom1_flat.mm(self.sparse_sketch_matrix1)
sketch_2 = bottom2_flat.mm(self.sparse_sketch_matrix2)

fft1 = afft.fft(sketch_1)
fft2 = afft.fft(sketch_2)

fft_product = fft1 * fft2

cbp_flat = afft.ifft(fft_product).real

cbp = cbp_flat.view(batch_size, height, width, self.output_dim)

if self.sum_pool:
cbp = cbp.sum(dim=1).sum(dim=1)

return cbp

@staticmethod
def generate_sketch_matrix(rand_h, rand_s, output_dim):
"""
Return a sparse matrix used for tensor sketch operation in compact bilinear
pooling
Args:
rand_h: an 1D numpy array containing indices in interval `[0, output_dim)`.
rand_s: an 1D numpy array of 1 and -1, having the same shape as `rand_h`.
output_dim: the output dimensions of compact bilinear pooling.
Returns:
a sparse matrix of shape [input_dim, output_dim] for tensor sketch.
"""

# Generate a sparse matrix for tensor count sketch
rand_h = rand_h.astype(np.int64)
rand_s = rand_s.astype(np.float32)
assert(rand_h.ndim == 1 and rand_s.ndim ==
1 and len(rand_h) == len(rand_s))
assert(np.all(rand_h >= 0) and np.all(rand_h < output_dim))

input_dim = len(rand_h)
indices = np.concatenate((np.arange(input_dim)[..., np.newaxis],
rand_h[..., np.newaxis]), axis=1)
indices = torch.from_numpy(indices)
rand_s = torch.from_numpy(rand_s)
sparse_sketch_matrix = torch.sparse.FloatTensor(
indices.t(), rand_s, torch.Size([input_dim, output_dim]))
return sparse_sketch_matrix.to_dense()


if __name__ == '__main__':

bottom1 = Variable(torch.randn(3, 512, 14, 14))
bottom2 = Variable(torch.randn(3, 128, 14, 14))

layer = CompactBilinearPooling(512, 128, 512, cuda=False)
layer.train()

out = layer(bottom1, bottom2)
print(out.shape)




File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 16c8701

Please sign in to comment.