
Commit

add demo of 2D/3D transUnet
WAMAWAMA committed Nov 11, 2022
1 parent 27810d3 commit a0a0d54
Showing 10 changed files with 437 additions and 9 deletions.
233 changes: 230 additions & 3 deletions README.md
@@ -611,22 +611,249 @@ if __name__ == '__main__':



<details>
<summary> Demo7: Build a 2D TransUnet for Segmentation </summary>

From the paper: *TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation*

Proposed by [Jieneng Chen](https://github.com/Beckschen)

[[paper]](https://arxiv.org/pdf/2102.04306.pdf)
[[official code]](https://github.com/Beckschen/TransUNet)
[Structure of TransUnet] 👇

![transunet](images/transUnet.png)

Demo code 👇
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from transformers import ViTModel
from wama_modules.utils import load_weights, tmp_class


class TransUNet(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()

        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)

        # neck
        neck_out_channel = 768
        transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
        configuration = transformer.config
        self.trans_downsample_size = configuration.image_size = [8, 8]
        configuration.patch_size = [1, 1]
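        # with a 1x1 patch size, the 8x8 input feature map becomes 8*8 = 64 tokens (plus one class token)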
        configuration.num_channels = Encoder_f_channel_list[-1]
        configuration.encoder_stride = 1  # only used by the MAE decoder, otherwise this parameter is unused
        self.neck = ViTModel(configuration, add_pooling_layer=False)

        pretrained_weights = transformer.state_dict()
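        # the embedding layers no longer match the 224x224 RGB checkpoint (more position tokens, 512-channel 1x1 patches),
        # so keep the neck's own freshly initialized weights for them and load only the transformer blocks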
        pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
            'embeddings.position_embeddings']
        pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.weight']
        pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.bias']
        self.neck = load_weights(self.neck, pretrained_weights)  # reload pretrained weights

        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
            skip_connection=[True, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)

        # seg head
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        # encoder forward
        multi_scale_encoder = self.encoder(x)

        # neck forward
        f_neck = self.neck(resizeTensor(multi_scale_encoder[-1], size=self.trans_downsample_size))
        f_neck = f_neck.last_hidden_state
        f_neck = f_neck[:, 1:]  # remove class token
        f_neck = f_neck.permute(0, 2, 1)
        f_neck = f_neck.reshape(
            f_neck.shape[0],
            f_neck.shape[1],
            self.trans_downsample_size[0],
            self.trans_downsample_size[1]
        )  # reshape the token sequence back into a 2D feature map
        f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
        multi_scale_encoder[-1] = f_neck

        # decoder forward
        multi_scale_decoder = self.decoder(multi_scale_encoder)
        f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])

        # seg_head forward
        logits = self.seg_head(f_for_seg)
        return logits


if __name__ == '__main__':
    x = torch.ones([2, 1, 256, 256])
    label_category_dict = dict(organ=3, tumor=4)
    model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
    with torch.no_grad():
        logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 256, 256])
# logits of tumor : torch.Size([2, 4, 256, 256])
```
</details>

<details>
<summary> Demo8: Build a 3D TransUnet for Segmentation </summary>

*The original TransUNet only accepts 2D input. To build a 3D TransUnet, we can use torch's `tensor.reshape` to temporarily flatten the 3D feature map into a 2D feature map before the transformer, and then reshape it back to 3D afterwards. You can find this process in the `neck forward` part of the code below.
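
A minimal sketch of that 3D↔2D reshape round trip (the shapes here are illustrative only, chosen to match the `[8, 8, 4]` neck size used in the demo) 👇
```python
import torch

f_3d = torch.randn(2, 512, 8, 8, 4)     # a 3D feature map: [batch, channel, 8, 8, 4]
f_2d = f_3d.reshape(2, 512, 8, 8 * 4)   # merge the last two spatial axes -> a 2D [8, 32] feature map
# ... the 2D feature map can now go through a 2D-only module such as ViT ...
f_back = f_2d.reshape(2, 512, 8, 8, 4)  # split the merged axis again -> back to the original 3D shape
print(torch.equal(f_3d, f_back))        # True: the reshape round trip is lossless
```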

```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from transformers import ViTModel
from wama_modules.utils import load_weights, tmp_class


class TransUnet(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()

        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)

        # neck
        neck_out_channel = 768
        transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
        configuration = transformer.config
        self.trans_size_3D = [8, 8, 4]
        self.trans_size = configuration.image_size = [
            self.trans_size_3D[0], self.trans_size_3D[1]*self.trans_size_3D[2]
        ]
        configuration.patch_size = [1, 1]
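        # with a 1x1 patch size, the flattened 8x32 feature map becomes 8*32 = 256 tokens (plus one class token)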
        configuration.num_channels = Encoder_f_channel_list[-1]
        configuration.encoder_stride = 1  # only used by the MAE decoder, otherwise this parameter is unused
        self.neck = ViTModel(configuration, add_pooling_layer=False)

        pretrained_weights = transformer.state_dict()
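        # the embedding layers no longer match the 224x224 RGB checkpoint (more position tokens, 512-channel 1x1 patches),
        # so keep the neck's own freshly initialized weights for them and load only the transformer blocks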
        pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
            'embeddings.position_embeddings']
        pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.weight']
        pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.bias']
        self.neck = load_weights(self.neck, pretrained_weights)  # reload pretrained weights

        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
            skip_connection=[True, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)

        # seg head
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        # encoder forward
        multi_scale_encoder = self.encoder(x)

        # neck forward
        neck_input = resizeTensor(multi_scale_encoder[-1], size=self.trans_size_3D)
        neck_input = neck_input.reshape(neck_input.shape[0], neck_input.shape[1], *self.trans_size)  # 3D to 2D
        f_neck = self.neck(neck_input)
        f_neck = f_neck.last_hidden_state
        f_neck = f_neck[:, 1:]  # remove class token
        f_neck = f_neck.permute(0, 2, 1)
        f_neck = f_neck.reshape(
            f_neck.shape[0],
            f_neck.shape[1],
            self.trans_size[0],
            self.trans_size[1]
        )  # reshape the token sequence back into a 2D feature map
        f_neck = f_neck.reshape(f_neck.shape[0], f_neck.shape[1], *self.trans_size_3D)  # 2D to 3D
        f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
        multi_scale_encoder[-1] = f_neck

        # decoder forward
        multi_scale_decoder = self.decoder(multi_scale_encoder)
        f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])

        # seg_head forward
        logits = self.seg_head(f_for_seg)
        return logits


if __name__ == '__main__':
    x = torch.ones([2, 1, 128, 128, 96])
    label_category_dict = dict(organ=3, tumor=4)
    model = TransUnet(in_channel=1, label_category_dict=label_category_dict, dim=3)
    with torch.no_grad():
        logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 96])
# logits of tumor : torch.Size([2, 4, 128, 128, 96])
```

</details>



*Todo-demo list ( 🚧 under preparation and coming soon...) ↓




<details>
<summary> Demo: Build a C-tran model for multi-label classification</summary>

98 changes: 98 additions & 0 deletions demo/Demo7_2D_TransUnet_Segmentation.py
@@ -0,0 +1,98 @@
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from transformers import ViTModel
from wama_modules.utils import load_weights, tmp_class


class TransUNet(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()

        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)

        # neck
        neck_out_channel = 768
        transformer = ViTModel.from_pretrained('google/vit-base-patch32-224-in21k')
        configuration = transformer.config
        self.trans_downsample_size = configuration.image_size = [8, 8]
        configuration.patch_size = [1, 1]
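        # with a 1x1 patch size, the 8x8 input feature map becomes 8*8 = 64 tokens (plus one class token)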
        configuration.num_channels = Encoder_f_channel_list[-1]
        configuration.encoder_stride = 1  # only used by the MAE decoder, otherwise this parameter is unused
        self.neck = ViTModel(configuration, add_pooling_layer=False)

        pretrained_weights = transformer.state_dict()
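        # the embedding layers no longer match the 224x224 RGB checkpoint (more position tokens, 512-channel 1x1 patches),
        # so keep the neck's own freshly initialized weights for them and load only the transformer blocks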
        pretrained_weights['embeddings.position_embeddings'] = self.neck.state_dict()[
            'embeddings.position_embeddings']
        pretrained_weights['embeddings.patch_embeddings.projection.weight'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.weight']
        pretrained_weights['embeddings.patch_embeddings.projection.bias'] = self.neck.state_dict()[
            'embeddings.patch_embeddings.projection.bias']
        self.neck = load_weights(self.neck, pretrained_weights)  # reload pretrained weights

        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list[:-1]+[neck_out_channel],
            skip_connection=[True, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)

        # seg head
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        # encoder forward
        multi_scale_encoder = self.encoder(x)

        # neck forward
        f_neck = self.neck(resizeTensor(multi_scale_encoder[-1], size=self.trans_downsample_size))
        f_neck = f_neck.last_hidden_state
        f_neck = f_neck[:, 1:]  # remove class token
        f_neck = f_neck.permute(0, 2, 1)
        f_neck = f_neck.reshape(
            f_neck.shape[0],
            f_neck.shape[1],
            self.trans_downsample_size[0],
            self.trans_downsample_size[1]
        )  # reshape the token sequence back into a 2D feature map
        f_neck = resizeTensor(f_neck, size=multi_scale_encoder[-1].shape[2:])
        multi_scale_encoder[-1] = f_neck

        # decoder forward
        multi_scale_decoder = self.decoder(multi_scale_encoder)
        f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])

        # seg_head forward
        logits = self.seg_head(f_for_seg)
        return logits


if __name__ == '__main__':
    x = torch.ones([2, 1, 256, 256])
    label_category_dict = dict(organ=3, tumor=4)
    model = TransUNet(in_channel=1, label_category_dict=label_category_dict, dim=2)
    with torch.no_grad():
        logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# out
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 256, 256])
# logits of tumor : torch.Size([2, 4, 256, 256])