-
Notifications
You must be signed in to change notification settings - Fork 63
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
26 changed files
with
994 additions
and
349 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import VGGEncoder | ||
from wama_modules.Head import ClassificationHead | ||
from wama_modules.BaseModule import GlobalMaxPool | ||
|
||
|
||
class Model(nn.Module):
    """Single-label classifier: VGG encoder -> global max pool -> classification head.

    `label_category_dict` maps each task name to its number of classes;
    `dim` selects 2D or 3D convolutions in the backbone.
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # Per-stage output widths of the VGG backbone.
        stage_channels = [64, 128, 256, 512]
        self.encoder = VGGEncoder(
            in_channel,
            stage_output_channels=stage_channels,
            blocks=[1, 2, 3, 4],
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # The head consumes the deepest (last-stage) feature map.
        self.cls_head = ClassificationHead(label_category_dict, stage_channels[-1])
        self.pooling = GlobalMaxPool()

    def forward(self, x):
        """Encode `x`, pool the deepest feature map, and return per-task logits."""
        feature_maps = self.encoder(x)
        pooled = self.pooling(feature_maps[-1])
        return self.cls_head(pooled)
|
||
|
||
if __name__ == '__main__':
    # Demo: batch of 2 single-channel 3D volumes.
    x = torch.ones([2, 1, 64, 64, 64])
    label_category_dict = dict(is_malignant=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('single-label predicted logits')
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)

    # output 👇
    # single-label predicted logits
    # logits of is_malignant : torch.Size([2, 4])
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import ResNetEncoder | ||
from wama_modules.Head import ClassificationHead | ||
from wama_modules.BaseModule import GlobalMaxPool | ||
|
||
|
||
class Model(nn.Module):
    """Single-label classifier built on a ResNet encoder with '131' bottleneck blocks.

    `label_category_dict` maps each task name to its class count;
    `dim` selects 2D or 3D convolutions.
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # Per-stage channel widths; middle channels mirror the outputs.
        stage_channels = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=stage_channels,
            stage_middle_channels=stage_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # The head reads the deepest feature map after pooling.
        self.cls_head = ClassificationHead(label_category_dict, stage_channels[-1])
        self.pooling = GlobalMaxPool()

    def forward(self, x):
        """Return per-task logits from the pooled last-stage feature map."""
        feature_maps = self.encoder(x)
        pooled = self.pooling(feature_maps[-1])
        return self.cls_head(pooled)
|
||
|
||
if __name__ == '__main__':
    # Demo: batch of 2 single-channel 3D volumes.
    x = torch.ones([2, 1, 64, 64, 64])
    label_category_dict = dict(is_malignant=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('single-label predicted logits')
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)

    # output 👇
    # single-label predicted logits
    # logits of is_malignant : torch.Size([2, 4])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import ResNetEncoder | ||
from wama_modules.Head import ClassificationHead | ||
from wama_modules.BaseModule import GlobalMaxPool | ||
|
||
|
||
class Model(nn.Module):
    """Multi-label classifier: ResNet encoder ('131' blocks) + one head per task.

    `label_category_dict` maps each task name to its class count;
    `dim` selects 2D or 3D convolutions.
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # Per-stage channel widths; middle channels mirror the outputs.
        stage_channels = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=stage_channels,
            stage_middle_channels=stage_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # One classification head covering all tasks in the dict.
        self.cls_head = ClassificationHead(label_category_dict, stage_channels[-1])

        self.pooling = GlobalMaxPool()

    def forward(self, x):
        """Return a dict of per-task logits from the pooled deepest feature map."""
        feature_maps = self.encoder(x)
        pooled = self.pooling(feature_maps[-1])
        return self.cls_head(pooled)
|
||
|
||
if __name__ == '__main__':
    # Demo: three classification tasks on one 3D input.
    x = torch.ones([2, 1, 64, 64, 64])
    label_category_dict = dict(shape=4, color=3, other=13)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)

    # out
    # multi-label predicted logits
    # logits of shape : torch.Size([2, 4])
    # logits of color : torch.Size([2, 3])
    # logits of other : torch.Size([2, 13])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import ResNetEncoder | ||
from wama_modules.Decoder import UNet_decoder | ||
from wama_modules.Head import SegmentationHead | ||
from wama_modules.utils import resizeTensor | ||
|
||
|
||
class Model(nn.Module):
    """Segmentation network: ResNet encoder -> UNet decoder -> segmentation head.

    `label_category_dict` maps each segmentation task to its class count;
    `dim` selects 2D or 3D convolutions.
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # Encoder stage widths (deepest last).
        enc_channels = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=enc_channels,
            stage_middle_channels=enc_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # Decoder widths (highest resolution first); skip-connect all but the first.
        dec_channels = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=enc_channels,
            skip_connection=[False, True, True],
            out_channels_list=dec_channels,
            dim=dim)
        # Head operates on the finest decoder feature map.
        self.seg_head = SegmentationHead(
            label_category_dict,
            dec_channels[0],
            dim=dim)

    def forward(self, x):
        """Return per-task segmentation logits at the input's spatial size."""
        encoder_features = self.encoder(x)
        decoder_features = self.decoder(encoder_features)
        # Upsample the finest decoder map back to the input resolution.
        full_res = resizeTensor(decoder_features[0], size=x.shape[2:])
        return self.seg_head(full_res)
|
||
|
||
if __name__ == '__main__':
    # Demo: single segmentation task on a 3D volume.
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)

    # out
    # multi-label predicted logits
    # logits of organ : torch.Size([2, 3, 128, 128, 128])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import ResNetEncoder | ||
from wama_modules.Decoder import UNet_decoder | ||
from wama_modules.Head import SegmentationHead | ||
from wama_modules.utils import resizeTensor | ||
|
||
|
||
class Model(nn.Module):
    """Multi-task segmentation network: ResNet encoder -> UNet decoder -> seg head.

    `label_category_dict` maps each segmentation task to its class count;
    `dim` selects 2D or 3D convolutions.
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # Encoder stage widths (deepest last).
        enc_channels = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=enc_channels,
            stage_middle_channels=enc_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # Decoder widths (highest resolution first); skip-connect all but the first.
        dec_channels = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=enc_channels,
            skip_connection=[False, True, True],
            out_channels_list=dec_channels,
            dim=dim)
        # Head operates on the finest decoder feature map.
        self.seg_head = SegmentationHead(
            label_category_dict,
            dec_channels[0],
            dim=dim)

    def forward(self, x):
        """Return per-task segmentation logits at the input's spatial size."""
        encoder_features = self.encoder(x)
        decoder_features = self.decoder(encoder_features)
        # Upsample the finest decoder map back to the input resolution.
        full_res = resizeTensor(decoder_features[0], size=x.shape[2:])
        return self.seg_head(full_res)
|
||
|
||
if __name__ == '__main__':
    # Demo: two segmentation tasks on one 3D volume.
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3, tumor=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)

    # out
    # multi-label predicted logits
    # logits of organ : torch.Size([2, 3, 128, 128, 128])
    # logits of tumor : torch.Size([2, 4, 128, 128, 128])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import ResNetEncoder | ||
from wama_modules.Decoder import UNet_decoder | ||
from wama_modules.Head import SegmentationHead, ClassificationHead | ||
from wama_modules.utils import resizeTensor | ||
from wama_modules.BaseModule import GlobalMaxPool | ||
|
||
|
||
class Model(nn.Module):
    """Joint segmentation + classification network sharing one ResNet encoder.

    Segmentation goes through a UNet decoder and a per-task segmentation head;
    classification pools the deepest encoder feature map into per-task heads.
    `dim` selects 2D or 3D convolutions.
    """

    def __init__(self,
                 in_channel,
                 seg_label_category_dict,
                 cls_label_category_dict,
                 dim=2):
        super().__init__()
        # Shared encoder stage widths (deepest last).
        enc_channels = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=enc_channels,
            stage_middle_channels=enc_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # Decoder widths (highest resolution first).
        dec_channels = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=enc_channels,
            skip_connection=[False, True, True],
            out_channels_list=dec_channels,
            dim=dim)
        # Segmentation head consumes the finest decoder feature map.
        self.seg_head = SegmentationHead(
            seg_label_category_dict,
            dec_channels[0],
            dim=dim)
        # Classification head consumes the pooled deepest encoder feature map.
        self.cls_head = ClassificationHead(cls_label_category_dict, enc_channels[-1])

        self.pooling = GlobalMaxPool()

    def forward(self, x):
        """Return `(seg_logits, cls_logits)` dicts for all configured tasks."""
        encoder_features = self.encoder(x)
        decoder_features = self.decoder(encoder_features)
        # Segmentation branch: restore the finest map to input resolution.
        full_res = resizeTensor(decoder_features[0], size=x.shape[2:])
        seg_logits = self.seg_head(full_res)
        # Classification branch: pool the deepest encoder map.
        cls_logits = self.cls_head(self.pooling(encoder_features[-1]))
        return seg_logits, cls_logits
|
||
if __name__ == '__main__':
    # Demo: two segmentation tasks plus three classification tasks.
    x = torch.ones([2, 1, 128, 128, 128])
    seg_label_category_dict = dict(organ=3, tumor=2)
    cls_label_category_dict = dict(shape=4, color=3, other=13)
    model = Model(
        in_channel=1,
        cls_label_category_dict=cls_label_category_dict,
        seg_label_category_dict=seg_label_category_dict,
        dim=3)
    seg_logits, cls_logits = model(x)
    print('multi-label predicted logits')
    for key in seg_logits:
        print('seg logits of ', key, ':', seg_logits[key].shape)
    print('-'*30)
    for key in cls_logits:
        print('cls logits of ', key, ':', cls_logits[key].shape)

    # out
    # multi-label predicted logits
    # seg logits of organ : torch.Size([2, 3, 128, 128, 128])
    # seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
    # ------------------------------
    # cls logits of shape : torch.Size([2, 4])
    # cls logits of color : torch.Size([2, 3])
    # cls logits of other : torch.Size([2, 13])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
import torch.nn as nn | ||
from wama_modules.Encoder import ResNetEncoder | ||
from wama_modules.Decoder import UNet_decoder | ||
from wama_modules.Head import SegmentationHead | ||
from wama_modules.utils import resizeTensor | ||
from wama_modules.Neck import FPN | ||
|
||
|
||
class Model(nn.Module):
    """Segmentation network with an FPN neck between encoder and decoder.

    Pipeline: ResNet encoder -> FPN (uniform channel width) -> UNet decoder
    -> segmentation head. `dim` selects 2D or 3D convolutions.
    """

    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # Encoder stage widths (deepest last).
        enc_channels = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=enc_channels,
            stage_middle_channels=enc_channels,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)

        # Neck: every FPN output level has the same channel width.
        fpn_width = 256
        fpn_channels = [fpn_width] * len(enc_channels)
        self.neck = FPN(
            in_channels_list=enc_channels,
            c1=fpn_width // 2,
            c2=fpn_width,
            mode='AddSmall2Big',
            dim=dim,)

        # Decoder widths (highest resolution first); all levels skip-connected.
        dec_channels = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=fpn_channels,
            skip_connection=[True, True, True],
            out_channels_list=dec_channels,
            dim=dim)
        # Head operates on the finest decoder feature map.
        self.seg_head = SegmentationHead(
            label_category_dict,
            dec_channels[0],
            dim=dim)

    def forward(self, x):
        """Return per-task segmentation logits at the input's spatial size."""
        encoder_features = self.encoder(x)
        neck_features = self.neck(encoder_features)
        decoder_features = self.decoder(neck_features)
        # Upsample the finest decoder map back to the input resolution.
        full_res = resizeTensor(decoder_features[0], size=x.shape[2:])
        return self.seg_head(full_res)
|
||
|
||
if __name__ == '__main__':
    # Demo: two segmentation tasks through the FPN-augmented network.
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3, tumor=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    for key in logits:
        print('logits of ', key, ':', logits[key].shape)

    # out
    # multi-label predicted logits
    # logits of organ : torch.Size([2, 3, 128, 128, 128])
    # logits of tumor : torch.Size([2, 4, 128, 128, 128])
Empty file.
Empty file.
Empty file.
Oops, something went wrong.