(🚧 Still under construction, but feel free to try the current module implementations~)
*A PyTorch module library for building 1D/2D/3D networks flexibly ~*
Highlights (simple to use & feature-rich!)
- Simple code that shows the whole forward process succinctly
- Outputs rich features and attention maps for fast reuse
- Supports 1D / 2D / 3D networks (CNNs, GNNs, Transformers...)
- Easy and flexible to integrate with any other network
- 🚀 Abundant pretrained weights: including 80000+ 2D weights and 80+ 3D weights
🔥 wama_modules (Basic)
1D / 2D / 3D
Install wama_modules with the command below ↓
pip install git+https://github.com/WAMAWAMA/wama_modules.git
Or you can directly copy the wama_modules folder into your project to use it.
💧 segmentation_models_pytorch (Optional)
2D
100+ pretrained weights
Introduction and installation command
segmentation_models_pytorch (smp for short) is a 2D CNN library that includes many backbones and decoders; installing it is highly recommended for use together with this library.
Our code already bundles smp, but you can still install the latest version with the commands below.
Install with pip:
pip install segmentation-models-pytorch
Install the latest version:
pip install git+https://github.com/qubvel/segmentation_models.pytorch
💧 transformers (Optional)
2D
80000+ pretrained weights
Introduction and installation command
transformers is a library from Hugging Face that includes abundant CNN and Transformer structures; installing it is highly recommended for use together with this library.
Install transformers with ↓
pip install transformers
💧 timm Optional
2D
400+ pretrained weights
Introduction and installation command
timm is a library that includes abundant CNN and Transformer structures; installing it is highly recommended for use together with this library.
Install with pip:
pip install timm
Install the latest version:
pip install git+https://github.com/rwightman/pytorch-image-models.git
- 2022/11/5: Source code released, version v0.0.1-beta
- ...
Here I'll give an overview of this repo
How to build a network modularly?
The answer is a paradigm of building networks:
'Design architecture according to tasks, pick modules according to architecture'
So, network architectures for different tasks can be viewed modularly such as:
- vgg = vgg_encoder + cls_head
- Unet = encoder + decoder + seg_head
- resnet = resnet_encoder + cls_head
- densenet = densenet_encoder + cls_head
- a multi-task net for classification and segmentation = encoder + decoder + cls_head + seg_head
For example, build a 3D resnet50
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel=3, per_stage_channel=[8, 16, 32, 64], dim=3)
decoder = ws.unet(encoder=encoder, output_channel=3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Here are more demos shown below ↓ (Click to view the code, or visit the demo folder)
Demo1: Build a 2D vgg16
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel=3, per_stage_channel=[8, 16, 32, 64], dim=2)
decoder = ws.unet(encoder=encoder, output_channel=3, dim=2)
inputs = torch.ones([3, 3, 128, 128])  # dim=2, so the input is 4D: (batch, channel, H, W)
Demo2: Build a 3D resnet50
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo3: Build a 3D densenet121
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo4: Build a Unet
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo5: Build a Unet with a resnet50 encoder
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo6: Build a Unet with a resnet50 encoder and a FPN
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo7: Build a multi-task model for segmentation and classification
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo8: Build a C-tran model for multi-label classification
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo9: Build a Q2L model for multi-label classification
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo10: Build a ML-Decoder model for multi-label classification
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo11: Build a ML-GCN model for multi-label classification
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo12: Build a UCTransNet model for segmentation
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo13: Build a model for multiple inputs (1D signal and 2D image)
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo14: Build a 2D Unet with a pretrained ResNet50 encoder
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel=3, per_stage_channel=[8, 16, 32, 64], dim=2)
decoder = ws.unet(encoder=encoder, output_channel=3, dim=2)
inputs = torch.ones([3, 3, 128, 128])  # dim=2, so the input is 4D: (batch, channel, H, W)
Demo15: Build a 3D DETR model for object detection
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
Demo16: Build a 3D VGG with an SE-attention module for multi-instance classification
import wama_modules as ws
import torch
encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)
inputs = torch.ones([3, 3, 128, 128, 128])  # dim=3, so the input is 5D: (batch, channel, D, H, W)
All pretrained weights are from third-party code or repos.
Currently supported pretrained weights: (TODO: add a table here listing the source, the number of weights, the pretraining data type, and whether the weights are 2D or 3D)
- 2D: smp, timm, radimagenet...
- 3D: medicalnet, 3D resnet, 3D densenet...
smp (119 pretrained weights)
import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder
m = get_encoder('resnet18', in_channels=3, depth=5, weights='ssl')  # replace 'resnet18' / 'ssl' with any supported backbone name and weights tag
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
timm (400+ pretrained weights)
import torch
import timm
m = timm.create_model(
    'adv_inception_v3',
    features_only=True,
    pretrained=False,  # set True to load the pretrained weights
)
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
transformers, super-powered by Hugging Face (with 80000+ pretrained weights)
import torch
from transformers import ConvNextModel
from wama_modules.utils import load_weights
# Load the pretrained convnext-base-224-22k weights
m = ConvNextModel.from_pretrained('facebook/convnext-base-224-22k')
f = m(torch.ones([2,3,224,224]), output_hidden_states=True)
f_list = f.hidden_states
_ = [print(i.shape) for i in f_list]
# reload the pretrained weights into a newly built model
weights = m.state_dict()
m1 = ConvNextModel(m.config)  # same config, randomly initialized
m1 = load_weights(m1, weights)
import torch
from transformers import SwinModel
from wama_modules.utils import load_weights
m = SwinModel.from_pretrained('microsoft/swin-base-patch4-window12-384')
f = m(torch.ones([2,3,384,384]), output_hidden_states=True)
f_list = f.reshaped_hidden_states  # for Swin-style Transformers, use reshaped_hidden_states to get spatial feature maps
_ = [print(i.shape) for i in f_list]
# reload the pretrained weights into a newly built model
weights = m.state_dict()
m1 = SwinModel(m.config)  # same config, randomly initialized
m1 = load_weights(m1, weights)
3D ResNets3D_kenshohara (21 weights)
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet\r3d18_KM_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet2p1d import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet2p1d\r2p1d18_K_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
3D VC3D_kenshohara (13 weights)
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnet\resnet-18-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model
from wama_modules.utils import load_weights
m = generate_model(101)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnext\resnext-101-64f-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model()
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\wideresnet\wideresnet-50-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
- GlobalAvgPool: global average pooling
- GlobalMaxPool: global maximum pooling
- GlobalMaxAvgPool: (GlobalAvgPool + GlobalMaxPool) / 2
Click here to see demo code
""" demo """
# import libs
import torch
from wama_modules.BaseModule import GlobalAvgPool, GlobalMaxPool, GlobalMaxAvgPool
# make tensor
inputs1D = torch.ones([3,12,13]) # 1D
inputs2D = torch.ones([3,12,13,13]) # 2D
inputs3D = torch.ones([3,12,13,13,13]) # 3D
# build layer
GAP = GlobalAvgPool()
GMP = GlobalMaxPool()
GAMP = GlobalMaxAvgPool()
# test GAP & GMP & GAMP
print(inputs1D.shape, GAP(inputs1D).shape)
print(inputs2D.shape, GAP(inputs2D).shape)
print(inputs3D.shape, GAP(inputs3D).shape)
print(inputs1D.shape, GMP(inputs1D).shape)
print(inputs2D.shape, GMP(inputs2D).shape)
print(inputs3D.shape, GMP(inputs3D).shape)
print(inputs1D.shape, GAMP(inputs1D).shape)
print(inputs2D.shape, GAMP(inputs2D).shape)
print(inputs3D.shape, GAMP(inputs3D).shape)
- customLayerNorm: a custom implementation of layer normalization
- MakeNorm: builds a normalization layer (BN / GN / IN / LN)
- MakeActive: builds an activation layer (ReLU / LeakyReLU)
- MakeConv: builds a 1D / 2D / 3D convolutional layer
Click here to see demo code
""" demo """
- ConvNormActive: 'Convolution → Normalization → Activation', used in VGG or ResNet
- NormActiveConv: 'Normalization → Activation → Convolution', used in DenseNet
- VGGBlock: the basic module in VGG
- VGGStage: a VGGStage = a few VGGBlocks
- ResBlock: the basic module in ResNet
- ResStage: a ResStage = a few ResBlocks
- DenseLayer: the basic module in DenseNet
- DenseBlock: a DenseBlock = a few DenseLayers
Click here to see demo code
""" demo """
- resizeTensor: scales a torch tensor, similar to scipy's zoom
- tensor2array: transforms a tensor into an ndarray
- load_weights: loads torch weights and prints loading details (missed keys and matched keys)
Click here to see demo code
""" demo """
- SCSEModule: concurrent spatial and channel squeeze-and-excitation attention
- NonLocal: non-local attention block
Click here to see demo code
""" demo """
- VGGEncoder
- ResNetEncoder
- DenseNetEncoder
Click here to see demo code
""" demo """
- UNet_decoder
Click here to see demo code
""" demo """
FPN
Click here to see demo code
""" demo """
import torch
from wama_modules.Neck import FPN
# make multi-scale feature maps
featuremaps = [
torch.ones([3,16,32,32,32]),
torch.ones([3,32,24,24,24]),
torch.ones([3,64,16,16,16]),
torch.ones([3,128,8,8,8]),
]
# build FPN
fpn_AddSmall2Big = FPN(in_channels_list=[16, 32, 64, 128],
c1=128,
c2=256,
active='relu',
norm='bn',
gn_c=8,
mode='AddSmall2Big', # Add small size feature to big size feature, for segmentation
dim=3,)
fpn_AddBig2Small = FPN(in_channels_list=[16, 32, 64, 128],
c1=128,
c2=256,
active='relu',
norm='bn',
gn_c=8,
mode='AddBig2Small', # Add big size feature to small size feature, for classification
dim=3,)
# forward
f_listA = fpn_AddSmall2Big(featuremaps)
f_listB = fpn_AddBig2Small(featuremaps)
_ = [print(i.shape) for i in featuremaps]
_ = [print(i.shape) for i in f_listA]
_ = [print(i.shape) for i in f_listB]
- FeedForward
- MultiHeadAttention
- TransformerEncoderLayer
- TransformerDecoderLayer
Click here to see demo code
""" demo """
Thanks to these authors and their code: