Skip to content

Commit

Permalink
ready go
Browse files Browse the repository at this point in the history
  • Loading branch information
WAMAWAMA committed Nov 11, 2022
1 parent 9566e5d commit b311aeb
Show file tree
Hide file tree
Showing 26 changed files with 994 additions and 349 deletions.
479 changes: 445 additions & 34 deletions README.md

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions demo/Demo0_VGG_SingleLabelClassification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import VGGEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool


class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
f_channel_list = [64, 128, 256, 512]
self.encoder = VGGEncoder(
in_channel,
stage_output_channels=f_channel_list,
blocks=[1, 2, 3, 4],
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# cls head
self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
self.pooling = GlobalMaxPool()

def forward(self, x):
f = self.encoder(x)
logits = self.cls_head(self.pooling(f[-1]))
return logits


if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(is_malignant=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('single-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# single-label predicted logits
# logits of is_malignant : torch.Size([2, 4])
Empty file.
Empty file.
41 changes: 41 additions & 0 deletions demo/Demo1_ResNet_SingleLabelClassification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool


class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=f_channel_list,
stage_middle_channels=f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# cls head
self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
self.pooling = GlobalMaxPool()

def forward(self, x):
f = self.encoder(x)
logits = self.cls_head(self.pooling(f[-1]))
return logits


if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(is_malignant=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('single-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# single-label predicted logits
# logits of is_malignant : torch.Size([2, 4])
44 changes: 44 additions & 0 deletions demo/Demo2_ResNet_MultiLabelClassification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool


class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=f_channel_list,
stage_middle_channels=f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# cls head
self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])

self.pooling = GlobalMaxPool()

def forward(self, x):
f = self.encoder(x)
logits = self.cls_head(self.pooling(f[-1]))
return logits


if __name__ == '__main__':
x = torch.ones([2, 1, 64, 64, 64])
label_category_dict = dict(shape=4, color=3, other=13)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of shape : torch.Size([2, 4])
# logits of color : torch.Size([2, 3])
# logits of other : torch.Size([2, 13])
53 changes: 53 additions & 0 deletions demo/Demo3_ResNetUnet_SingleLabelSegmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor


class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list,
skip_connection=[False, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)

def forward(self, x):
multi_scale_f1 = self.encoder(x)
multi_scale_f2 = self.decoder(multi_scale_f1)
f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
logits = self.seg_head(f_for_seg)
return logits


if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
label_category_dict = dict(organ=3)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
56 changes: 56 additions & 0 deletions demo/Demo4_ResNetUnet_MultiLabelSegmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor


class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list,
skip_connection=[False, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)

def forward(self, x):
multi_scale_f1 = self.encoder(x)
multi_scale_f2 = self.decoder(multi_scale_f1)
f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
logits = self.seg_head(f_for_seg)
return logits


if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])


81 changes: 81 additions & 0 deletions demo/Demo5_MultiTask_SegAndCls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead, ClassificationHead
from wama_modules.utils import resizeTensor
from wama_modules.BaseModule import GlobalMaxPool


class Model(nn.Module):
def __init__(self,
in_channel,
seg_label_category_dict,
cls_label_category_dict,
dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)
# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=Encoder_f_channel_list,
skip_connection=[False, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
seg_label_category_dict,
Decoder_f_channel_list[0],
dim=dim)
# cls head
self.cls_head = ClassificationHead(cls_label_category_dict, Encoder_f_channel_list[-1])

# pooling
self.pooling = GlobalMaxPool()

def forward(self, x):
# get encoder features
multi_scale_encoder = self.encoder(x)
# get decoder features
multi_scale_decoder = self.decoder(multi_scale_encoder)
# perform segmentation
f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
seg_logits = self.seg_head(f_for_seg)
# perform classification
cls_logits = self.cls_head(self.pooling(multi_scale_encoder[-1]))
return seg_logits, cls_logits

if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
seg_label_category_dict = dict(organ=3, tumor=2)
cls_label_category_dict = dict(shape=4, color=3, other=13)
model = Model(
in_channel=1,
cls_label_category_dict=cls_label_category_dict,
seg_label_category_dict=seg_label_category_dict,
dim=3)
seg_logits, cls_logits = model(x)
print('multi-label predicted logits')
_ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()]
print('-'*30)
_ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()]

# out
# multi-label predicted logits
# seg logits of organ : torch.Size([2, 3, 128, 128, 128])
# seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
# ------------------------------
# cls logits of shape : torch.Size([2, 4])
# cls logits of color : torch.Size([2, 3])
# cls logits of other : torch.Size([2, 13])


68 changes: 68 additions & 0 deletions demo/Demo6_UnetwithFPN_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from wama_modules.Neck import FPN


class Model(nn.Module):
def __init__(self, in_channel, label_category_dict, dim=2):
super().__init__()
# encoder
Encoder_f_channel_list = [64, 128, 256, 512]
self.encoder = ResNetEncoder(
in_channel,
stage_output_channels=Encoder_f_channel_list,
stage_middle_channels=Encoder_f_channel_list,
blocks=[1, 2, 3, 4],
type='131',
downsample_ration=[0.5, 0.5, 0.5, 0.5],
dim=dim)

# neck
FPN_output_channel = 256
FPN_channels = [FPN_output_channel]*len(Encoder_f_channel_list)
self.neck = FPN(in_channels_list=Encoder_f_channel_list,
c1=FPN_output_channel//2,
c2=FPN_output_channel,
mode='AddSmall2Big',
dim=dim,)

# decoder
Decoder_f_channel_list = [32, 64, 128]
self.decoder = UNet_decoder(
in_channels_list=FPN_channels,
skip_connection=[True, True, True],
out_channels_list=Decoder_f_channel_list,
dim=dim)
# seg head
self.seg_head = SegmentationHead(
label_category_dict,
Decoder_f_channel_list[0],
dim=dim)

def forward(self, x):
multi_scale_encoder = self.encoder(x)
multi_scale_neck = self.neck(multi_scale_encoder)
multi_scale_decoder = self.decoder(multi_scale_neck)
f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
logits = self.seg_head(f_for_seg)
return logits


if __name__ == '__main__':
x = torch.ones([2, 1, 128, 128, 128])
label_category_dict = dict(organ=3, tumor=4)
model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
logits = model(x)
print('multi-label predicted logits')
_ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])


Empty file.
Empty file.
Empty file.
Loading

0 comments on commit b311aeb

Please sign in to comment.