*A PyTorch module library for flexibly building 1D/2D/3D networks* (simple to use and feature-rich!)
Highlights
- Simple code that shows the whole forward process succinctly
- Outputs rich features and attention maps for fast reuse
- Supports 1D / 2D / 3D networks (CNNs, GNNs, Transformers...)
- Easy and flexible to integrate with any other network
- 🚀 Abundant pretrained weights: including 80000+ 2D weights and 80+ 3D weights

*Download pretrained weights from [Google Drive] or [Baidu Netdisk psw: wama]

*All modules are detailed in the [Document] (🚧 still under construction)
🔥 wama_modules (Basic, 1D/2D/3D)

Install `wama_modules` with the command below ↓
```shell
pip install git+https://github.com/WAMAWAMA/wama_modules.git
```
Other ways to install (or use) `wama_modules`:
- Way 1: download the code and run `python setup.py install`
- Way 2: directly copy the folder `wama_modules` into your project path
💧 segmentation_models_pytorch (Optional, 2D, 100+ pretrained weights)
Introduction and installation commands

`segmentation_models_pytorch` (called `smp`) is a 2D CNN library including many backbones and decoders; installing it is highly recommended for use with this library.

*Our code already bundles `smp`, but you can still install the latest version with the commands below.

Install with pip ↓
```shell
pip install segmentation-models-pytorch
```
Install the latest version ↓
```shell
pip install git+https://github.com/qubvel/segmentation_models.pytorch
```
💧 transformers (Optional, 2D, 80000+ pretrained weights)

Introduction and installation command

`transformers` (powered by Hugging Face) is a library including a super-abundant collection of CNN and Transformer structures; installing it is highly recommended for use with this library.

Install `transformers` with pip ↓
```shell
pip install transformers
```
💧 timm (Optional, 2D, 400+ pretrained weights)

Introduction and installation command

`timm` is a library including abundant CNN and Transformer structures; installing it is highly recommended for use with this library.

Install with pip ↓
```shell
pip install timm
```
Install the latest version ↓
```shell
pip install git+https://github.com/rwightman/pytorch-image-models.git
```
- 2022/11/11: the birthday of this repo, version v0.0.1
- ...
An overview of this repo (`wama_modules` is abbreviated as `wm` below):
| File | Description | Main classes or functions |
|---|---|---|
| `wm.utils` | Operations on tensors and pretrained weights | `resizeTensor()` `tensor2array()` `load_weights()` |
| `wm.thirdparty_lib` | 2D/3D network structures (CNN/GNN/Transformer) from other repositories, all with pretrained weights 🚀 | `MedicalNet` `C3D` `3D_DenseNet` `3D_shufflenet` `transformers.ConvNextModel` `transformers.SwinModel` `Radimagenet` |
| `wm.Attention` | Attention-based plugins | `SCSEModule` `NonLocal` |
| `wm.BaseModule` | Basic modules (layers), e.g. the bottleneck block (ResBlock) in ResNet and the DenseBlock in DenseNet | `MakeNorm()` `MakeConv()` `MakeActive()` `VGGBlock` `ResBlock` `DenseBlock` |
| `wm.Encoder` | Encoders such as ResNet or DenseNet, with more flexibility for building networks modularly; 1D/2D/3D are all supported | `VGGEncoder` `ResNetEncoder` `DenseNetEncoder` |
| `wm.Decoder` | Decoders with more flexibility for building networks modularly; 1D/2D/3D are all supported | `UNet_decoder` |
| `wm.Neck` | Modules that let the multi-scale features from the encoder interact with each other to generate stronger features | `FPN` |
| `wm.Transformer` | Self-attention and cross-attention modules, which can be used to build ViT, DETR, or TransUnet | `TransformerEncoderLayer` `TransformerDecoderLayer` |
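As a concrete example of the `wm.utils` helpers above, `resizeTensor()` is used in the demos below to resize a decoder feature map back to the input's spatial size. Its effect corresponds to plain PyTorch interpolation; the sketch below uses only `torch` (the full signature of `resizeTensor` beyond `size=...` is not shown here):

```python
import torch
import torch.nn.functional as F

# a 3D feature map: batch 2, 32 channels, 16^3 voxels
f = torch.ones([2, 32, 16, 16, 16])

# resize the spatial dims to 64^3, as resizeTensor(f, size=x.shape[2:]) does in the demos
f_resized = F.interpolate(f, size=(64, 64, 64), mode='trilinear', align_corners=False)
print(f_resized.shape)  # torch.Size([2, 32, 64, 64, 64])
```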
- How to build your networks modularly and freely? 👉 See 'Guideline 1: Build networks modularly' below ~
- How to use pretrained models with `wm.thirdparty_lib`? 👉 See 'Guideline 2: Use pretrained weights' below ~
How to build a network modularly? Here's a paradigm ↓

'Design the architecture according to the task, and pick modules according to the architecture.'

So network architectures for different tasks can be viewed modularly, for example:
- VGG for classification = `VGG_encoder` + `classification_head`
- ResNet for classification = `ResNet_encoder` + `classification_head`
- Unet for segmentation = `encoder` + `decoder` + `segmentation_head`
- A multi-task net for classification and segmentation = `encoder` + `decoder` + `cls_head` + `seg_head`
For example, build a 3D resnet50 encoder to output multi-scale feature maps ↓
```python
from wama_modules.Encoder import ResNetEncoder
import torch

dim = 3  # the input is a 3D volume
in_channels = 3
input = torch.ones([2, in_channels, 64, 64, 64])
encoder = ResNetEncoder(
    in_channels,
    stage_output_channels=[64, 128, 256],
    blocks=[6, 12, 24],
    downsample_ration=[0.5, 0.5, 0.5],  # set your downsampling speed
    dim=dim)
multi_scale_f = encoder(input)
_ = [print(i.shape) for i in multi_scale_f]
# --------------------------------
# output 👇
# torch.Size([2, 64, 15, 15, 15])
# torch.Size([2, 128, 7, 7, 7])
# torch.Size([2, 256, 3, 3, 3])
# --------------------------------
```
More demos are shown below ↓ (click to view the code, or visit the demo folder)
Demo0: Build a 3D VGG for Single Label Classification
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import VGGEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool

class Model(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # encoder
        f_channel_list = [64, 128, 256, 512]
        self.encoder = VGGEncoder(
            in_channel,
            stage_output_channels=f_channel_list,
            blocks=[1, 2, 3, 4],
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # cls head
        self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
        self.pooling = GlobalMaxPool()

    def forward(self, x):
        f = self.encoder(x)
        logits = self.cls_head(self.pooling(f[-1]))
        return logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 64, 64, 64])
    label_category_dict = dict(is_malignant=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('single-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# single-label predicted logits
# logits of is_malignant : torch.Size([2, 4])
```
Demo1: Build a 3D ResNet for Single Label Classification
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool

class Model(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # encoder
        f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=f_channel_list,
            stage_middle_channels=f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # cls head
        self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
        self.pooling = GlobalMaxPool()

    def forward(self, x):
        f = self.encoder(x)
        logits = self.cls_head(self.pooling(f[-1]))
        return logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 64, 64, 64])
    label_category_dict = dict(is_malignant=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('single-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# single-label predicted logits
# logits of is_malignant : torch.Size([2, 4])
```
Demo2: Build a ResNet for Multi-Label Classification
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Head import ClassificationHead
from wama_modules.BaseModule import GlobalMaxPool

class Model(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # encoder
        f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=f_channel_list,
            stage_middle_channels=f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # cls head
        self.cls_head = ClassificationHead(label_category_dict, f_channel_list[-1])
        self.pooling = GlobalMaxPool()

    def forward(self, x):
        f = self.encoder(x)
        logits = self.cls_head(self.pooling(f[-1]))
        return logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 64, 64, 64])
    label_category_dict = dict(shape=4, color=3, other=13)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# multi-label predicted logits
# logits of shape : torch.Size([2, 4])
# logits of color : torch.Size([2, 3])
# logits of other : torch.Size([2, 13])
```
Demo3: Build a ResNetUnet for Single Label Segmentation
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor

class Model(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list,
            skip_connection=[False, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)
        # seg head
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        multi_scale_f1 = self.encoder(x)
        multi_scale_f2 = self.decoder(multi_scale_f1)
        f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
        logits = self.seg_head(f_for_seg)
        return logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('single-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# single-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
```
Demo4: Build a ResNetUnet for Multi-Label Segmentation
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor

class Model(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list,
            skip_connection=[False, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)
        # seg head
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        multi_scale_f1 = self.encoder(x)
        multi_scale_f2 = self.decoder(multi_scale_f1)
        f_for_seg = resizeTensor(multi_scale_f2[0], size=x.shape[2:])
        logits = self.seg_head(f_for_seg)
        return logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3, tumor=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])
```
Demo5: Build a MultiTask net for Segmentation and Classification
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead, ClassificationHead
from wama_modules.utils import resizeTensor
from wama_modules.BaseModule import GlobalMaxPool

class Model(nn.Module):
    def __init__(self,
                 in_channel,
                 seg_label_category_dict,
                 cls_label_category_dict,
                 dim=2):
        super().__init__()
        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=Encoder_f_channel_list,
            skip_connection=[False, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)
        # seg head
        self.seg_head = SegmentationHead(
            seg_label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)
        # cls head
        self.cls_head = ClassificationHead(cls_label_category_dict, Encoder_f_channel_list[-1])
        # pooling
        self.pooling = GlobalMaxPool()

    def forward(self, x):
        # get encoder features
        multi_scale_encoder = self.encoder(x)
        # get decoder features
        multi_scale_decoder = self.decoder(multi_scale_encoder)
        # perform segmentation
        f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
        seg_logits = self.seg_head(f_for_seg)
        # perform classification
        cls_logits = self.cls_head(self.pooling(multi_scale_encoder[-1]))
        return seg_logits, cls_logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 128, 128, 128])
    seg_label_category_dict = dict(organ=3, tumor=2)
    cls_label_category_dict = dict(shape=4, color=3, other=13)
    model = Model(
        in_channel=1,
        cls_label_category_dict=cls_label_category_dict,
        seg_label_category_dict=seg_label_category_dict,
        dim=3)
    seg_logits, cls_logits = model(x)
    print('multi-label predicted logits')
    _ = [print('seg logits of ', key, ':', seg_logits[key].shape) for key in seg_logits.keys()]
    print('-' * 30)
    _ = [print('cls logits of ', key, ':', cls_logits[key].shape) for key in cls_logits.keys()]

# output 👇
# multi-label predicted logits
# seg logits of organ : torch.Size([2, 3, 128, 128, 128])
# seg logits of tumor : torch.Size([2, 2, 128, 128, 128])
# ------------------------------
# cls logits of shape : torch.Size([2, 4])
# cls logits of color : torch.Size([2, 3])
# cls logits of other : torch.Size([2, 13])
```
Demo6: Build a Unet with a ResNet encoder and an FPN neck
```python
import torch
import torch.nn as nn
from wama_modules.Encoder import ResNetEncoder
from wama_modules.Decoder import UNet_decoder
from wama_modules.Head import SegmentationHead
from wama_modules.utils import resizeTensor
from wama_modules.Neck import FPN

class Model(nn.Module):
    def __init__(self, in_channel, label_category_dict, dim=2):
        super().__init__()
        # encoder
        Encoder_f_channel_list = [64, 128, 256, 512]
        self.encoder = ResNetEncoder(
            in_channel,
            stage_output_channels=Encoder_f_channel_list,
            stage_middle_channels=Encoder_f_channel_list,
            blocks=[1, 2, 3, 4],
            type='131',
            downsample_ration=[0.5, 0.5, 0.5, 0.5],
            dim=dim)
        # neck
        FPN_output_channel = 256
        FPN_channels = [FPN_output_channel] * len(Encoder_f_channel_list)
        self.neck = FPN(
            in_channels_list=Encoder_f_channel_list,
            c1=FPN_output_channel // 2,
            c2=FPN_output_channel,
            mode='AddSmall2Big',
            dim=dim)
        # decoder
        Decoder_f_channel_list = [32, 64, 128]
        self.decoder = UNet_decoder(
            in_channels_list=FPN_channels,
            skip_connection=[True, True, True],
            out_channels_list=Decoder_f_channel_list,
            dim=dim)
        # seg head
        self.seg_head = SegmentationHead(
            label_category_dict,
            Decoder_f_channel_list[0],
            dim=dim)

    def forward(self, x):
        multi_scale_encoder = self.encoder(x)
        multi_scale_neck = self.neck(multi_scale_encoder)
        multi_scale_decoder = self.decoder(multi_scale_neck)
        f_for_seg = resizeTensor(multi_scale_decoder[0], size=x.shape[2:])
        logits = self.seg_head(f_for_seg)
        return logits

if __name__ == '__main__':
    x = torch.ones([2, 1, 128, 128, 128])
    label_category_dict = dict(organ=3, tumor=4)
    model = Model(in_channel=1, label_category_dict=label_category_dict, dim=3)
    logits = model(x)
    print('multi-label predicted logits')
    _ = [print('logits of ', key, ':', logits[key].shape) for key in logits.keys()]

# output 👇
# multi-label predicted logits
# logits of organ : torch.Size([2, 3, 128, 128, 128])
# logits of tumor : torch.Size([2, 4, 128, 128, 128])
```
*Todo-demo list (under preparation, coming soon...) ↓
- Demo: Build a TransUnet
- Demo: Build a C-Tran model for multi-label classification
- Demo: Build a Q2L model for multi-label classification
- Demo: Build an ML-Decoder model for multi-label classification
- Demo: Build an ML-GCN model for multi-label classification
- Demo: Build a UCTransNet model for segmentation
- Demo: Build a model for multiple inputs (1D signal and 2D image)
- Demo: Build a 2D Unet with a pretrained ResNet50 encoder
- Demo: Build a 3D DETR model for object detection
- Demo: Build a 3D VGG with an SE-attention module for multi-instance classification
(*All pretrained weights come from third-party code or repos)

The currently available pretrained models are listed below ↓
| # | Module name | Number of pretrained weights | Pretraining data | Dimension |
|---|---|---|---|---|
| 1 | `.ResNets3D_kenshohara` | 21 | video | 3D |
| 2 | `.VC3D_kenshohara` | 13 | video | 3D |
| 3 | `.Efficient3D_okankop` | 39 | video | 3D |
| 4 | `.MedicalNet_tencent` | 11 | medical image | 3D |
| 5 | `.C3D_jfzhang95` | 1 | video | 3D |
| 6 | `.C3D_yyuanad` | 1 | video | 3D |
| 7 | `.SMP_qubvel` | 119 | image | 2D |
| 8 | `timm` | 400+ | image | 2D |
| 9 | `transformers` | 80000+ | video/image | 2D/3D |
| 10 | `radimagenet` | 1 | medical image | 2D |
*Download all pretrained weights from [Google Drive] or [Baidu Netdisk psw: wama]
ResNets3D_kenshohara (21 weights)

Demo code ↓
```python
# resnet
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights

m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\kenshohara_ResNets3D_weights\resnet\r3d18_KM_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# resnet2p1d
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet2p1d import generate_model
from wama_modules.utils import load_weights

m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\kenshohara_ResNets3D_weights\resnet2p1d\r2p1d18_K_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```
VC3D_kenshohara (13 weights)

Demo code ↓
```python
# resnet
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights

m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnet\resnet-18-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# resnext
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model
from wama_modules.utils import load_weights

m = generate_model(101)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnext\resnext-101-64f-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# wide_resnet
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
from wama_modules.utils import load_weights

m = generate_model()
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\wideresnet\wideresnet-50-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```
Efficient3D_okankop (39 weights)

Demo code ↓
```python
# c3d
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.c3d import get_model

m = get_model()  # c3d has no pretrained weights
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# mobilenet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.mobilenet import get_model
from wama_modules.utils import load_weights

m = get_model(width_mult=1.)  # e.g. width_mult = 1 for mobilenet_1.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenet\jester_mobilenet_1.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]

m = get_model(width_mult=2.)  # e.g. width_mult = 2 for mobilenet_2.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenet\jester_mobilenet_2.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# mobilenetv2
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.mobilenetv2 import get_model
from wama_modules.utils import load_weights

m = get_model(width_mult=1.)  # e.g. width_mult = 1 for mobilenetv2_1.0x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenetv2\jester_mobilenetv2_1.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]

m = get_model(width_mult=0.45)  # e.g. width_mult = 0.45 for mobilenetv2_0.45x
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\mobilenetv2\jester_mobilenetv2_0.45x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# resnet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.resnet import resnet18, resnet50, resnet101
from wama_modules.utils import load_weights

m = resnet18()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnet\kinetics_resnet_18_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]

m = resnet50()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnet\kinetics_resnet_50_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]

m = resnet101()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnet\kinetics_resnet_101_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# resnext
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.resnext import resnext101
from wama_modules.utils import load_weights

m = resnext101()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\resnext\jester_resnext_101_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# shufflenet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.shufflenet import get_model
from wama_modules.utils import load_weights

m = get_model(groups=3, width_mult=1)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenet\jester_shufflenet_1.0x_G3_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]

m = get_model(groups=3, width_mult=1.5)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenet\jester_shufflenet_1.5x_G3_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# shufflenetv2
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.shufflenetv2 import get_model
from wama_modules.utils import load_weights

m = get_model(width_mult=1)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenetv2\jester_shufflenetv2_1.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]

m = get_model(width_mult=2)
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\shufflenetv2\jester_shufflenetv2_2.0x_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```

```python
# squeezenet
import torch
from wama_modules.thirdparty_lib.Efficient3D_okankop.models.squeezenet import get_model
from wama_modules.utils import load_weights

m = get_model()
pretrain_path = r"D:\pretrainedweights\Efficient3D_okankop\Efficient3D_okankop_weights\squeezenet\jester_squeezenet_RGB_16_best.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```
MedicalNet_tencent (11 weights)

Demo code ↓
```python
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.MedicalNet_Tencent.model import generate_model

m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\MedicalNet_Tencent\MedicalNet_weights\resnet_18_23dataset.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2, 1, 64, 64, 64]))  # the input channel is 1 (not 3 as for video)
_ = [print(i.shape) for i in f_list]
```
C3D_jfzhang95 (1 weight)

Demo code ↓
```python
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.C3D_jfzhang95.c3d import C3D

m = C3D()
pretrain_path = r"D:\pretrainedweights\C3D_jfzhang95\C3D_jfzhang95_weights\C3D_jfzhang95_C3D.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```
C3D_yyuanad (1 weight)

Demo code ↓
```python
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.C3D_yyuanad.c3d import C3D

m = C3D()
pretrain_path = r"D:\pretrainedweights\C3D_yyuanad\C3D_yyuanad_weights\C3D_yyuanad.pickle"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 64, 64, 64]))
_ = [print(i.shape) for i in f_list]
```
SMP_qubvel (119 weights, automatic online download)

Demo code ↓
```python
import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder

m = get_encoder('resnet18', in_channels=3, depth=5, weights='ssl')
f_list = m(torch.ones([2, 3, 128, 128]))
_ = [print(i.shape) for i in f_list]
```
timm (400+ weights, automatic online download)

Demo code ↓
```python
import torch
import timm

m = timm.create_model(
    'adv_inception_v3',
    features_only=True,
    pretrained=True)
f_list = m(torch.ones([2, 3, 128, 128]))
_ = [print(i.shape) for i in f_list]
```
transformers (80000+ weights, automatic online download)

For all models, see the Hugging Face [ModelHub]

Demo code ↓
```python
# ConvNeXt
import torch
from transformers import ConvNextModel
from wama_modules.utils import load_weights

# load a pretrained ConvNeXt (convnext-base-224-22k)
m = ConvNextModel.from_pretrained('facebook/convnext-base-224-22k')
f = m(torch.ones([2, 3, 224, 224]), output_hidden_states=True)
f_list = f.hidden_states
_ = [print(i.shape) for i in f_list]

# reload the weights into a freshly built model
weights = m.state_dict()
m1 = ConvNextModel(m.config)
m1 = load_weights(m1, weights)
```

```python
# Swin
import torch
from transformers import SwinModel
from wama_modules.utils import load_weights

m = SwinModel.from_pretrained('microsoft/swin-base-patch4-window12-384')
f = m(torch.ones([2, 3, 384, 384]), output_hidden_states=True)
f_list = f.reshaped_hidden_states  # for Transformer models, use reshaped_hidden_states
_ = [print(i.shape) for i in f_list]

# reload the weights into a freshly built model
weights = m.state_dict()
m1 = SwinModel(m.config)
m1 = load_weights(m1, weights)
```
radimagenet (1 weight)

Demo code ↓
```python
import torch
from wama_modules.utils import load_weights
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder

m = get_encoder('resnet50', in_channels=3, depth=5, weights=None)
pretrain_path = r"D:\pretrainedweights\radimagnet\RadImageNet_models-20221104T172755Z-001\RadImageNet_models\RadImageNet-ResNet50_notop_torch.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2, 3, 128, 128]))
_ = [print(i.shape) for i in f_list]
```
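All of the snippets above pass a plain state dict to `wm.utils.load_weights`, which copies matching parameters into a freshly built model (and, with `drop_modelDOT=True`, strips the `module.` prefix that `DataParallel` checkpoints carry). Below is a plain-PyTorch sketch of that matching-and-loading pattern — an illustration of the idea, not the library's actual implementation:

```python
import torch
import torch.nn as nn

def load_matching_weights(model, state_dict, drop_module_prefix=True):
    """Copy only the parameters whose name and shape match the model.

    Illustrative sketch of the behavior the demos rely on from
    wm.utils.load_weights (the real implementation may differ);
    drop_module_prefix mirrors drop_modelDOT, stripping the 'module.'
    prefix left by DataParallel checkpoints.
    """
    if drop_module_prefix:
        state_dict = {k[len('module.'):] if k.startswith('module.') else k: v
                      for k, v in state_dict.items()}
    own = model.state_dict()
    matched = {k: v for k, v in state_dict.items()
               if k in own and own[k].shape == v.shape}
    model.load_state_dict(matched, strict=False)
    print(f'loaded {len(matched)}/{len(own)} tensors')
    return model

# usage: transfer weights between two models that share only some layers
src = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
dst = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 3))  # last layer differs in shape
checkpoint = {'module.' + k: v for k, v in src.state_dict().items()}
dst = load_matching_weights(dst, checkpoint)  # loaded 2/4 tensors
```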
Thanks to these authors and their code:
- https://github.com/ZhugeKongan/torch-template-for-deep-learning
- pytorch vit: https://github.com/lucidrains/vit-pytorch
- SMP: https://github.com/qubvel/segmentation_models.pytorch
- transformers: https://github.com/huggingface/transformers
- medicalnet: https://github.com/Tencent/MedicalNet
- timm: https://github.com/rwightman/pytorch-image-models
- ResNets3D_kenshohara: https://github.com/kenshohara/3D-ResNets-PyTorch
- VC3D_kenshohara: https://github.com/kenshohara/video-classification-3d-cnn-pytorch
- Efficient3D_okankop: https://github.com/okankop/Efficient-3DCNNs
- C3D_jfzhang95: https://github.com/jfzhang95/pytorch-video-recognition
- C3D_yyuanad: https://github.com/yyuanad/Pytorch_C3D_Feature_Extractor
- radimagenet: https://github.com/BMEII-AI/RadImageNet
- BMEII-AI/RadImageNet#3