
Commit

update readme
WAMAWAMA committed Feb 5, 2023
1 parent 3c2d1f6 commit e90d479
Showing 19 changed files with 177 additions and 10 deletions.
README.md: 18 additions & 0 deletions
@@ -20,6 +20,24 @@ Pretrained weights:



Quick start with demo code for 6 novel multi-label network structures

*All model code is re-implemented in a very simple way*

|Network| Publication | Demo code | Paper link| Supports multi-class per label|
|---|---|---|---|---|
|CNNRNN|CVPR 2016|[code](demo/multi_label/Demo_CVPR2016_MultiLabel_CNNRNN.py)|[link](https://openaccess.thecvf.com/content_cvpr_2016/html/Wang_CNN-RNN_A_Unified_CVPR_2016_paper.html)|×|
|ML-GCN|CVPR 2019|[code](demo/multi_label/Demo_CVPR2019_MultiLabel_ML_GCN.py)|[link](https://arxiv.org/abs/1904.03582)|×|
|SSGRL|ICCV 2019|[code](demo/multi_label/Demo_ICCV2019_MultiLabel_SSGRL.py)|[link](https://arxiv.org/abs/1908.07325)||
|C-tran|CVPR 2021|[code](demo/multi_label/Demo_CVPR2021_MultiLabel_C_tran.py)|[link](https://arxiv.org/abs/2011.14027)||
|ML-decoder|arXiv 2021|[code](demo/multi_label/Demo_Arxiv2021_MultiLabel_ML_decoder.py)|[link](https://arxiv.org/abs/2111.12933)||
|Q2L|arXiv 2021|[code](demo/multi_label/Demo_ArXiv2021_MultiLabel_Query2Label.py)|[link](https://arxiv.org/abs/2107.10834)||
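
For the last column: a model "supports multi-class per label" when each label slot predicts one of several mutually exclusive classes, rather than a single binary on/off logit. Below is a minimal, hypothetical sketch of such a head, for illustration only; the `MultiClassPerLabelHead` name and shapes are assumptions, not the API of the demos above:

```python
import torch
import torch.nn as nn

# Hypothetical illustration (not the demos' API): a "multi-class per label" head
# outputs [batch, label_num, class_num] logits, i.e. a class distribution for
# every label, instead of one binary logit per label.
class MultiClassPerLabelHead(nn.Module):
    def __init__(self, in_channels=512, label_num=10, class_num=4):
        super().__init__()
        self.label_num, self.class_num = label_num, class_num
        self.fc = nn.Linear(in_channels, label_num * class_num)

    def forward(self, feat):  # feat: [batch, in_channels]
        return self.fc(feat).view(-1, self.label_num, self.class_num)

head = MultiClassPerLabelHead()
logits = head(torch.rand(2, 512))
print(logits.shape)  # torch.Size([2, 10, 4]); apply softmax over the last dim
```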






## 1. Installation
🔥 [wama_modules](https://github.com/WAMAWAMA/wama_modules)
`Basic` `1D/2D/3D`
demo/Demo_RelationNet4LesionSeg.py: 159 additions & 10 deletions
@@ -5,9 +5,15 @@
from wama_modules.BaseModule import MakeNorm, GlobalMaxPool
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import load_weights, resizeTensor, tmp_class, tensor2array
from wama_modules.thirdparty_lib.MedicalNet_Tencent.model import generate_model
from wama_modules.PositionEmbedding import PositionalEncoding_1D_learnable
import matplotlib.pyplot as plt
import numpy as np

def show2D(img):
    plt.imshow(img)
    plt.show()


class TransformerEncoder(nn.Module):
Expand Down Expand Up @@ -201,6 +207,7 @@ def forward(self, x):

        # obtain the lesion token
        lesion_token = self.pool(f_encoder[-1])
        print(lesion_token.shape)

        # relation module forward
        if self.attention_type == 'cross':
@@ -228,17 +235,142 @@ def forward(self, x):
        lesion_seg_map = self.seg_head_lesion(f_for_seg)


        return lesion_seg_map, organ_seg_map, organ_seg_map_tensor


class RalationNet_v2(nn.Module):
    def __init__(self,
                 organ_num=16,  # actually organ_num + 1, since background is included
                 encoder_weights_pth=None,  # e.g. r"D:\pretrainedweights\MedicalNet_Tencent\MedicalNet_weights\resnet_18_23dataset.pth"
                 attention_type='cross',  # 'cross' or 'self'
                 relation_layer=2,
                 relation_head=8,
                 add_organ_embeddings=False,
                 dim=3):
        super().__init__()
        # self = tmp_class()  # for debug
        self.organ_num = organ_num
        self.attention_type = attention_type
        self.add_organ_embeddings = add_organ_embeddings
        self.dim = dim

        # encoder from thirdparty_lib.MedicalNet_Tencent
        Encoder_f_channel_list = [64, 64, 128, 256, 512]
        self.encoder = generate_model(18)
        if encoder_weights_pth is not None:
            pretrain_weights = torch.load(encoder_weights_pth, map_location='cpu')['state_dict']
            self.encoder = load_weights(self.encoder, pretrain_weights, drop_modelDOT=True, silence=True)
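            # note (assumption): drop_modelDOT presumably strips the leading
            # "model."/"module." prefix that DataParallel-saved checkpoints such as
            # MedicalNet's carry on their state-dict keys; see wama_modules.utils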

        # decoder for lesion
        Decoder_f_channel_list = [32, 64, 128, 256]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list,
            skip_connection=[True, True, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)

        # seg head for lesion
        self.seg_head_lesion = SegmentationHead(
            label_category_dict=dict(lesion=1),
            in_channel=Decoder_f_channel_list[0],
            dim=dim)

        # seg head for organ
        self.seg_head_organ = SegmentationHead(
            label_category_dict=dict(organ=organ_num),
            in_channel=Encoder_f_channel_list[-1],  # encoder, not decoder
            dim=dim)

        # pool
        self.pool = GlobalMaxPool()

        # organ embeddings (optional)
        if add_organ_embeddings:
            self.organ_embed = PositionalEncoding_1D_learnable(
                embedding_dim=Encoder_f_channel_list[-1],
                token_num=organ_num)
        else:
            self.organ_embed = None

        # relation module
        # 'cross': the flattened lesion tokens query the organ tokens via a Transformer
        # decoder; 'self': lesion and organ tokens are concatenated into one sequence
        # for a Transformer encoder (see forward below)
        if attention_type == 'cross':
            self.relation_module = TransformerDecoder(
                token_channels=Encoder_f_channel_list[-1],
                depth=relation_layer,
                heads=relation_head,
                dim_head=Encoder_f_channel_list[-1],
                mlp_dim=Encoder_f_channel_list[-1],
            )
        elif attention_type == 'self':
            self.norm = MakeNorm(dim, Encoder_f_channel_list[-1], norm='ln')
            self.relation_module = TransformerEncoder(
                token_channels=Encoder_f_channel_list[-1],
                depth=relation_layer,
                heads=relation_head,
                dim_head=Encoder_f_channel_list[-1],
                mlp_dim=Encoder_f_channel_list[-1],
            )
        else:
            raise ValueError("attention_type must be 'cross' or 'self'")

    def forward(self, x):
        bz = x.shape[0]

        # obtain encoder features
        f_encoder = self.encoder(x)  # _ = [print(i.shape) for i in f_encoder]

        # obtain the organ seg map from the deepest encoder feature (deep supervision)
        organ_seg_map = self.seg_head_organ(f_encoder[-1])  # _ = [print(key, organ_seg_map[key].shape) for key in organ_seg_map.keys()]
        organ_seg_map_tensor = organ_seg_map['organ']
        organ_seg_map_tensor = organ_seg_map_tensor.contiguous().view(bz, organ_seg_map_tensor.shape[1], -1)
        organ_seg_map_tensor = torch.softmax(organ_seg_map_tensor, dim=1)  # per-voxel probabilities over organ channels; print(organ_seg_map_tensor.sum(1).shape, organ_seg_map_tensor.sum(1))
        organ_seg_map_tensor = organ_seg_map_tensor.contiguous().view(*organ_seg_map['organ'].shape)

        # obtain organ tokens by weighting the deepest feature map with each organ's
        # probability map, then global-pooling (one token per organ)
        organ_tokens = [self.pool(torch.unsqueeze(organ_seg_map_tensor[:, i], 1) * f_encoder[-1]) for i in range(self.organ_num)]  # _ = [print(i.shape) for i in organ_tokens]

        # obtain lesion tokens: flatten the deepest feature map to [bz, N, C] with
        # channels last, so every spatial position becomes one token
        lesion_token = f_encoder[-1].flatten(2).permute(0, 2, 1).contiguous()
        print(lesion_token.shape)

        # relation module forward
        if self.attention_type == 'cross':
            lesion_token, self_attn_map_list, cross_attn_map_list = self.relation_module(
                lesion_token,
                torch.stack(organ_tokens, dim=1),
                q_pos_embeddings=None,
                v_pos_embeddings=self.organ_embed)
            lesion_token = lesion_token.permute(0, 2, 1).contiguous().view(*f_encoder[-1].shape)
        elif self.attention_type == 'self':
            organ_tokens = torch.stack(organ_tokens, dim=1)
            if self.add_organ_embeddings:
                print('add organ emb')
                organ_tokens += self.organ_embed
            all_tokens = torch.cat([lesion_token, organ_tokens], dim=1)
            all_tokens, self_attn_map_list = self.relation_module(self.norm(all_tokens), pos_emb=None)
            lesion_token = all_tokens[:, :lesion_token.shape[1], :]
            lesion_token = lesion_token.permute(0, 2, 1).contiguous().view(*f_encoder[-1].shape)

        # get lesion seg map
        f_decoder = self.decoder(f_encoder[:-1] + [lesion_token])
        f_for_seg = resizeTensor(f_decoder[0], size=x.shape[2:])
        lesion_seg_map = self.seg_head_lesion(f_for_seg)

        return lesion_seg_map, organ_seg_map, organ_seg_map_tensor


if __name__ == '__main__':
    batchsize = 2
    channel = 1
    shape = [256, 256, 16]
    organ_num = 16 + 1  # +1 for background
    image = torch.ones([batchsize, channel] + shape)
    image[:, :, :64, :64, :16] = 0  # zero out one corner so slices show spatial structure
    organ_GT = resizeTensor(torch.ones([batchsize, organ_num] + shape), size=[32, 32, 2])
    lesion_GT = torch.ones([batchsize, 1] + shape)

    RalationNet = RalationNet_v2  # switch between RalationNet and RalationNet_v2 here

    # mode1: self attention w/ organ emb ------------------------
    print('# mode1: self attention w/ organ emb ------------------------')
    model = RalationNet(
@@ -252,7 +384,7 @@ def forward(self, x):
    optimized_parameters = list(model.parameters())
    optimizer = optim.Adam(optimized_parameters, 1e-3, [0.5, 0.999], weight_decay=5e-4)

    lesion_seg_map, organ_seg_map, organ_seg_map_tensor = model(image)
    loss = ((lesion_seg_map['lesion']-lesion_GT)**2).sum() + ((organ_seg_map['organ']-organ_GT)**2).sum()
    model.zero_grad()
    loss.backward()
@@ -271,7 +403,7 @@ def forward(self, x):
    optimized_parameters = list(model.parameters())
    optimizer = optim.Adam(optimized_parameters, 1e-3, [0.5, 0.999], weight_decay=5e-4)

    lesion_seg_map, organ_seg_map, organ_seg_map_tensor = model(image)
    loss = ((lesion_seg_map['lesion']-lesion_GT)**2).sum() + ((organ_seg_map['organ']-organ_GT)**2).sum()
    model.zero_grad()
    loss.backward()
@@ -290,7 +422,7 @@ def forward(self, x):
    optimized_parameters = list(model.parameters())
    optimizer = optim.Adam(optimized_parameters, 1e-3, [0.5, 0.999], weight_decay=5e-4)

    lesion_seg_map, organ_seg_map, organ_seg_map_tensor = model(image)
    loss = ((lesion_seg_map['lesion']-lesion_GT)**2).sum() + ((organ_seg_map['organ']-organ_GT)**2).sum()
    model.zero_grad()
    loss.backward()
@@ -309,11 +441,28 @@ def forward(self, x):
    optimized_parameters = list(model.parameters())
    optimizer = optim.Adam(optimized_parameters, 1e-3, [0.5, 0.999], weight_decay=5e-4)

    lesion_seg_map, organ_seg_map, organ_seg_map_tensor = model(image)
    loss = ((lesion_seg_map['lesion']-lesion_GT)**2).sum() + ((organ_seg_map['organ']-organ_GT)**2).sum()
    model.zero_grad()
    loss.backward()
    optimizer.step()



    # vis to check whether the token flatten/reshape preserves spatial information:
    # plot three orthogonal slices of each predicted map
    lesion_seg_map = tensor2array(lesion_seg_map['lesion'])
    show2D(np.squeeze(lesion_seg_map[0, 0, 0]))
    show2D(np.squeeze(lesion_seg_map[0, 0, :, 0]))
    show2D(np.squeeze(lesion_seg_map[0, 0, :, :, 0]))

    organ_seg_map = tensor2array(organ_seg_map['organ'])
    show2D(np.squeeze(organ_seg_map[0, 0, 0]))
    show2D(np.squeeze(organ_seg_map[0, 0, :, 0]))
    show2D(np.squeeze(organ_seg_map[0, 0, :, :, 0]))
    show2D(np.squeeze(organ_seg_map[0, 1, :, :, 0]))
    show2D(np.squeeze(organ_seg_map[0, 2, :, :, 0]))

    organ_seg_map = tensor2array(organ_seg_map_tensor)  # after softmax
    show2D(np.squeeze(organ_seg_map[0, 1, 0]))
    show2D(np.squeeze(organ_seg_map[0, 1, :, 0]))
    show2D(np.squeeze(organ_seg_map[0, 0, :, :, 0]))
    show2D(np.squeeze(organ_seg_map[0, 1, :, :, 0]))
    show2D(np.squeeze(organ_seg_map[0, 2, :, :, 0]))
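    # expected outcome: the orthogonal views should roughly reflect the input's
    # spatial layout (including the zeroed corner); a scrambled flatten/reshape
    # would destroy any coherent corner pattern in the predicted maps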
17 binary files not shown.
