A PyTorch Computer Vision (CV) module library for building n-D networks flexibly ~



ωαмα m⚙️dules

(🚧 Still under construction, but the current module implementations work...)

*A PyTorch module library for building 1D/2D/3D networks flexibly ~*

Highlights (Simple-to-use & Function-rich!)

  • Simple code that shows every forward pass succinctly
  • Outputs as many features as possible for fast reuse
  • Supports 1D / 2D / 3D networks
  • Easy to integrate with any other network
  • 🚀 Pretrained weights (both 2D and 3D): 20+ 2D networks and 30+ 3D networks

1. Installation

Install wama_modules using ↓

pip install git+

Alternatively, copy the wama_modules folder directly into your project.

Introduction and installation command

segmentation_models_pytorch (smp) is a 2D CNN library that includes many backbones and decoders. Installing it is highly recommended for use alongside this library.

Install with pip:

pip install segmentation-models-pytorch

Install the latest version:

pip install git+
Introduction and installation command

transformers is a library that includes many Transformer architectures. Installing it is highly recommended for use alongside this library. Install transformers using ↓

pip install transformers
  • 💧1.4 timm (Optional)
Introduction and installation command

timm is a library that includes many CNN and Transformer architectures. Installing it is highly recommended for use alongside this library.

Install with pip:

pip install timm

Install the latest version:

pip install git+

2. Update list

  • 2022/11/5: Released the source code, version v0.0.1-beta
  • ...

3. Main modules and network architectures

Here I'll give an overview of this repo

4. Guideline 1: Build networks modularly

How to build a network modularly?

The answer is a paradigm for building networks:

'Design architecture according to tasks, pick modules according to architecture'

Network architectures for different tasks can therefore be viewed as combinations of modules, for example:

  • vgg = vgg_encoder + cls_head
  • Unet = encoder + decoder + seg_head
  • resnet = resnet_encoder + cls_head
  • densenet = densenet_encoder + cls_head
  • a multi-task net for classification and segmentation = encoder + decoder + cls_head + seg_head
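The last item in the list can be sketched in plain PyTorch. This is only an illustration of the "encoder + decoder + cls_head + seg_head" idea. The class names (TinyEncoder, TinyDecoder, MultiTaskNet) and all channel sizes are made up for this sketch and are not part of the wama_modules API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyEncoder(nn.Module):
    """Hypothetical 2D encoder that returns the feature map of every stage."""

    def __init__(self, in_channels=3, stage_channels=(8, 16, 32)):
        super().__init__()
        stages, prev = [], in_channels
        for ch in stage_channels:
            stages.append(nn.Sequential(
                nn.Conv2d(prev, ch, kernel_size=3, stride=2, padding=1),
                nn.BatchNorm2d(ch),
                nn.ReLU(),
            ))
            prev = ch
        self.stages = nn.ModuleList(stages)

    def forward(self, x):
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features  # ordered from shallow (large) to deep (small)


class TinyDecoder(nn.Module):
    """Hypothetical U-Net-style decoder: upsample, concatenate the skip, convolve."""

    def __init__(self, stage_channels=(8, 16, 32)):
        super().__init__()
        chans = list(stage_channels)
        self.blocks = nn.ModuleList([
            nn.Conv2d(chans[i] + chans[i - 1], chans[i - 1], kernel_size=3, padding=1)
            for i in range(len(chans) - 1, 0, -1)
        ])

    def forward(self, features):
        x = features[-1]
        for block, skip in zip(self.blocks, reversed(features[:-1])):
            x = F.interpolate(x, size=skip.shape[2:], mode="nearest")
            x = F.relu(block(torch.cat([x, skip], dim=1)))
        return x


class MultiTaskNet(nn.Module):
    """encoder + decoder + cls_head + seg_head"""

    def __init__(self, in_channels=3, num_classes=4, seg_classes=2):
        super().__init__()
        chans = (8, 16, 32)
        self.encoder = TinyEncoder(in_channels, chans)
        self.decoder = TinyDecoder(chans)
        self.cls_head = nn.Linear(chans[-1], num_classes)
        self.seg_head = nn.Conv2d(chans[0], seg_classes, kernel_size=1)

    def forward(self, x):
        features = self.encoder(x)
        # classification uses the deepest feature, pooled over space
        cls_logits = self.cls_head(features[-1].mean(dim=(2, 3)))
        # segmentation uses the decoded feature at the shallowest resolution
        seg_logits = self.seg_head(self.decoder(features))
        return cls_logits, seg_logits


net = MultiTaskNet()
cls_logits, seg_logits = net(torch.ones(2, 3, 64, 64))
print(cls_logits.shape, seg_logits.shape)
```

Because the encoder returns every stage's feature map, the same encoder can feed a classification head, a decoder, or both, without being rewritten.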

For example, build a 3D resnet50

import wama_modules as ws
import torch

encoder = ws.resnet(input_channel = 3, per_stage_channel = [8,16,32,64], dim=3)
decoder = ws.unet(encoder = encoder, output_channel = 3, dim=3)

input = torch.ones([3,3,128,128])

More demos are listed below ↓ (see also the demo folder)

Demo1: Build a 2D vgg16
Demo2: Build a 3D resnet50
Demo3: Build a 3D densenet121
Demo4: Build a Unet
Demo5: Build a Unet with a resnet50 encoder
Demo6: Build a Unet with a resnet50 encoder and a FPN
Demo7: Build a multi-task model for segmentation and classification
Demo8: Build a C-tran model for multi-label classification
Demo9: Build a Q2L model for multi-label classification
Demo10: Build a ML-Decoder model for multi-label classification
Demo11: Build a ML-GCN model for multi-label classification
Demo12: Build a UCTransNet model for segmentation
Demo13: Build a model for multiple inputs (1D signal and 2D image)
Demo14: Build a 2D Unet with pretrained Resnet50 encoder (1D signal and 2D image)
Demo15: Build a 3D DETR model for object detection
Demo16: Build a 3D VGG with SE-attention module for multi-instance classification

5. Guideline 2: Use pretrained weights

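All pretrained weights come from third-party code or repositories.

Currently supported pretrained weights: (TODO: add a table here listing the source, the number of weights, the pretraining data type, and whether they are 2D or 3D)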

  • 2D: smp, timm, radimagenet...
  • 3D: medicalnet, 3D resnet, 3D densenet...

5.1 smp encoders 2D


5.2 timm encoders 2D


5.3 radimagenet 2D medical image


5.4 ??? 3D video


5.5 ??? 3D video


5.6 ??? 3D video


5.7 ??? 3D medical image


6. All modules and functions

6.1 wama_modules.BaseModule

6.1.1 Pooling

  • GlobalAvgPool Global average pooling
  • GlobalMaxPool Global maximum pooling
  • GlobalMaxAvgPool GlobalMaxAvgPool = (GlobalAvgPool + GlobalMaxPool) / 2.
Click here to see demo code
""" demo """
# import libs
import torch
from wama_modules.BaseModule import GlobalAvgPool, GlobalMaxPool, GlobalMaxAvgPool

# make tensor
inputs1D = torch.ones([3,12,13]) # 1D
inputs2D = torch.ones([3,12,13,13]) # 2D
inputs3D = torch.ones([3,12,13,13,13]) # 3D

# build layer
GAP = GlobalAvgPool()
GMP = GlobalMaxPool()
GAMP = GlobalMaxAvgPool()

# test GAP & GMP & GAMP
print(inputs1D.shape, GAP(inputs1D).shape)
print(inputs2D.shape, GAP(inputs2D).shape)
print(inputs3D.shape, GAP(inputs3D).shape)

print(inputs1D.shape, GMP(inputs1D).shape)
print(inputs2D.shape, GMP(inputs2D).shape)
print(inputs3D.shape, GMP(inputs3D).shape)

print(inputs1D.shape, GAMP(inputs1D).shape)
print(inputs2D.shape, GAMP(inputs2D).shape)
print(inputs3D.shape, GAMP(inputs3D).shape)
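To make the GlobalMaxAvgPool formula concrete, here is a minimal plain-PyTorch sketch of the three poolings. It assumes the pooled output has shape [batch, channel]. The function names are hypothetical stand-ins, not the library's own implementation.

```python
import torch


def global_avg_pool(x):
    # average over every spatial axis (everything after batch and channel)
    return x.flatten(start_dim=2).mean(dim=-1)


def global_max_pool(x):
    # maximum over every spatial axis
    return x.flatten(start_dim=2).amax(dim=-1)


def global_max_avg_pool(x):
    # (GlobalAvgPool + GlobalMaxPool) / 2
    return (global_avg_pool(x) + global_max_pool(x)) / 2


# channel 0 holds 0..11, channel 1 holds 12..23
x = torch.arange(24, dtype=torch.float32).reshape(1, 2, 3, 4)
print(global_avg_pool(x))      # values: [[5.5, 17.5]]
print(global_max_pool(x))      # values: [[11., 23.]]
print(global_max_avg_pool(x))  # values: [[8.25, 20.25]]

# the same functions work for 1D, 2D and 3D inputs
for shape in ([3, 12, 13], [3, 12, 13, 13], [3, 12, 13, 13, 13]):
    print(shape, global_max_avg_pool(torch.ones(shape)).shape)
```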

6.1.2 Norm&Activation

  • customLayerNorm a custom implementation of layer normalization
  • MakeNorm make normalization layer, includes BN / GN / IN / LN
  • MakeActive make activation layer, includes Relu / LeakyRelu
  • MakeConv make 1D / 2D / 3D convolutional layer
Click here to see demo code
""" demo """
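Until the demo is filled in, the idea behind such "Make*" helpers can be sketched as small factory functions that pick the right torch.nn class from a dimension and a name. The helpers below (make_conv, make_norm, make_active) and their arguments are hypothetical. The real MakeConv / MakeNorm / MakeActive signatures may differ.

```python
import torch
import torch.nn as nn

_CONV = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}
_BN = {1: nn.BatchNorm1d, 2: nn.BatchNorm2d, 3: nn.BatchNorm3d}
_IN = {1: nn.InstanceNorm1d, 2: nn.InstanceNorm2d, 3: nn.InstanceNorm3d}


def make_conv(in_channels, out_channels, kernel_size=3, dim=2):
    """Hypothetical helper: size-preserving conv for odd kernel sizes."""
    return _CONV[dim](in_channels, out_channels, kernel_size, padding=kernel_size // 2)


def make_norm(channels, norm="bn", groups=4, dim=2):
    """Hypothetical helper: choose BN / GN / IN / LN by name."""
    if norm == "bn":
        return _BN[dim](channels)
    if norm == "in":
        return _IN[dim](channels)
    if norm == "gn":
        return nn.GroupNorm(groups, channels)
    if norm == "ln":
        # one possible choice: GroupNorm with a single group normalizes
        # over channels and all spatial axes of each sample
        return nn.GroupNorm(1, channels)
    raise ValueError(f"unknown norm: {norm}")


def make_active(name="relu"):
    """Hypothetical helper: choose Relu / LeakyRelu by name."""
    if name == "relu":
        return nn.ReLU()
    if name == "leakyrelu":
        return nn.LeakyReLU(0.1)
    raise ValueError(f"unknown activation: {name}")


x = torch.randn(2, 8, 10, 10, 10)  # a 3D input
y = make_active("leakyrelu")(make_norm(16, "gn", dim=3)(make_conv(8, 16, dim=3)(x)))
print(y.shape)
```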

6.1.3 Conv

  • ConvNormActive 'Convolution→Normalization→Activation', used in VGG or ResNet
  • NormActiveConv 'Normalization→Activation→Convolution', used in DenseNet
  • VGGBlock the basic module in VGG
  • VGGStage a VGGStage = a few VGGBlocks
  • ResBlock the basic module in ResNet
  • ResStage a ResStage = a few ResBlocks
  • DenseLayer the basic module in DenseNet
  • DenseBlock a DenseBlock = a few DenseLayers
Click here to see demo code
""" demo """
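The difference between the two orderings, and how a block is stacked from layers, can be sketched in plain PyTorch as follows. The function and class names are hypothetical and restricted to 2D for brevity. They only illustrate the ordering described above, not the library's own classes.

```python
import torch
import torch.nn as nn


def conv_norm_active(in_ch, out_ch):
    # Convolution -> Normalization -> Activation (VGG / ResNet ordering)
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


def norm_active_conv(in_ch, out_ch):
    # Normalization -> Activation -> Convolution (DenseNet ordering)
    return nn.Sequential(
        nn.BatchNorm2d(in_ch),
        nn.ReLU(),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
    )


class DenseLayerSketch(nn.Module):
    """Appends `growth` new channels to its input, as a DenseNet layer does."""

    def __init__(self, in_ch, growth):
        super().__init__()
        self.body = norm_active_conv(in_ch, growth)

    def forward(self, x):
        return torch.cat([x, self.body(x)], dim=1)


def dense_block(in_ch, growth, num_layers):
    # a block = a few layers; each layer sees all channels produced so far
    layers = [DenseLayerSketch(in_ch + i * growth, growth) for i in range(num_layers)]
    return nn.Sequential(*layers)


x = torch.randn(2, 8, 16, 16)
vgg_like = conv_norm_active(8, 16)(x)
dense_out = dense_block(8, growth=4, num_layers=3)(x)
print(vgg_like.shape)   # 16 output channels
print(dense_out.shape)  # 8 + 3 * 4 = 20 output channels
```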

6.2 wama_modules.utils

  • resizeTensor scale torch tensor, similar to scipy's zoom
  • tensor2array transform tensor to ndarray
Click here to see demo code
""" demo """
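As a rough idea of what these two utilities do, the sketch below rescales a tensor with torch.nn.functional.interpolate and converts the result to a NumPy array. The names resize_tensor and tensor_to_array are hypothetical stand-ins, and the choice of linear interpolation is an assumption. The library's resizeTensor and tensor2array may behave differently.

```python
import numpy as np
import torch
import torch.nn.functional as F

# interpolation mode per number of spatial axes
_MODES = {1: "linear", 2: "bilinear", 3: "trilinear"}


def resize_tensor(x, scale_factor):
    """Rescale the spatial axes of a [batch, channel, *spatial] tensor."""
    spatial_dims = x.dim() - 2
    return F.interpolate(x, scale_factor=scale_factor,
                         mode=_MODES[spatial_dims], align_corners=False)


def tensor_to_array(x):
    """Detach from the graph, move to CPU, and convert to a NumPy array."""
    return x.detach().cpu().numpy()


x = torch.ones(2, 3, 8, 8, 8)
y = resize_tensor(x, 0.5)   # halve every spatial axis
arr = tensor_to_array(y)
print(y.shape, type(arr), arr.shape)
```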

6.3 wama_modules.Attention

  • SCSEModule
  • NonLocal
Click here to see demo code
""" demo """
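For readers unfamiliar with SCSE (concurrent spatial and channel squeeze-and-excitation), here is a minimal 2D sketch in plain PyTorch. It follows the commonly used formulation in which the input is re-weighted by a per-channel gate and a per-pixel gate and the two results are summed. The class name and the reduction ratio are assumptions, and the library's SCSEModule may differ in details.

```python
import torch
import torch.nn as nn


class SCSESketch(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (2D only)."""

    def __init__(self, channels, reduction=4):
        super().__init__()
        # channel gate: squeeze space, then excite each channel
        self.channel_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, kernel_size=1),
            nn.Sigmoid(),
        )
        # spatial gate: squeeze channels, then excite each pixel
        self.spatial_gate = nn.Sequential(
            nn.Conv2d(channels, 1, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # gates broadcast over [B, C, 1, 1] and [B, 1, H, W] respectively
        return x * self.channel_gate(x) + x * self.spatial_gate(x)


x = torch.randn(2, 16, 8, 8)
attn = SCSESketch(channels=16)
y = attn(x)
print(x.shape, y.shape)  # attention keeps the input shape
```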

6.4 wama_modules.Encoder

  • VGGEncoder
  • ResNetEncoder
  • DenseNetEncoder
  • ???
Click here to see demo code
""" demo """

6.5 wama_modules.Decoder

  • UNet_decoder
Click here to see demo code
""" demo """

6.6 wama_modules.Neck

  • FPN
Click here to see demo code
""" demo """
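import torch
from wama_modules.Neck import FPN

# make multi-scale feature maps
# (assumed shapes: channels follow in_channels_list, batch and spatial sizes are illustrative)
featuremaps = [
    torch.ones([3, 16, 32, 32]),
    torch.ones([3, 32, 24, 24]),
    torch.ones([3, 64, 16, 16]),
    torch.ones([3, 128, 8, 8]),
]

# build FPN (constructor arguments other than in_channels_list and mode are omitted here)
fpn_AddSmall2Big = FPN(in_channels_list=[16, 32, 64, 128],
         mode='AddSmall2Big', # Add small size feature to big size feature
         )
fpn_AddBig2Small = FPN(in_channels_list=[16, 32, 64, 128],
         mode='AddBig2Small', # Add big size feature to small size feature, for classification
         )

# forward
f_listA = fpn_AddSmall2Big(featuremaps)
f_listB = fpn_AddBig2Small(featuremaps)

# compare input and output shapes
for f in featuremaps + f_listA + f_listB:
    print(f.shape)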
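For intuition, the sketch below implements a classic top-down feature pyramid in plain PyTorch: each smaller (deeper) map is upsampled and added to the next larger one. This assumes that is what the 'AddSmall2Big' mode name refers to. The class TopDownFPN, its arguments, and the nearest-neighbour upsampling are illustrative choices, not the library's FPN.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopDownFPN(nn.Module):
    """Hypothetical 2D FPN: upsample small maps and add them to bigger ones."""

    def __init__(self, in_channels_list, out_channels):
        super().__init__()
        # 1x1 convs bring every level to the same channel count
        self.lateral = nn.ModuleList(
            [nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels_list])
        # 3x3 convs smooth each fused level
        self.smooth = nn.ModuleList(
            [nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
             for _ in in_channels_list])

    def forward(self, features):
        # features are ordered from the largest to the smallest spatial size
        laterals = [conv(f) for conv, f in zip(self.lateral, features)]
        for i in range(len(laterals) - 1, 0, -1):
            up = F.interpolate(laterals[i], size=laterals[i - 1].shape[2:], mode="nearest")
            laterals[i - 1] = laterals[i - 1] + up
        return [conv(f) for conv, f in zip(self.smooth, laterals)]


feats = [
    torch.ones(1, 16, 32, 32),
    torch.ones(1, 32, 16, 16),
    torch.ones(1, 64, 8, 8),
    torch.ones(1, 128, 4, 4),
]
fpn = TopDownFPN(in_channels_list=[16, 32, 64, 128], out_channels=32)
outs = fpn(feats)
for f_in, f_out in zip(feats, outs):
    print(f_in.shape, "->", f_out.shape)
```

Every output level keeps its spatial size but shares the same channel count, which is what lets a single head be applied to all levels.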

6.7 wama_modules.Transformer

  • FeedForward
  • MultiHeadAttention
  • TransformerEncoderLayer
  • TransformerDecoderLayer
Click here to see demo code
""" demo """
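How a feed-forward block and multi-head attention combine into an encoder layer can be sketched with standard torch.nn building blocks. The class names, the pre-norm arrangement, and returning the attention weights are assumptions for this illustration. The library's FeedForward, MultiHeadAttention and TransformerEncoderLayer may be organized differently.

```python
import torch
import torch.nn as nn


class FeedForwardSketch(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear."""

    def __init__(self, dim, hidden_dim, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class EncoderLayerSketch(nn.Module):
    """Pre-norm encoder layer: self-attention and feed-forward, each with a residual."""

    def __init__(self, dim, heads, hidden_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = FeedForwardSketch(dim, hidden_dim)

    def forward(self, tokens):
        h = self.norm1(tokens)
        attn_out, attn_weights = self.attn(h, h, h)  # self-attention: q = k = v
        tokens = tokens + attn_out
        tokens = tokens + self.ff(self.norm2(tokens))
        # the attention map is returned as well so it can be reused or visualized
        return tokens, attn_weights


tokens = torch.randn(2, 10, 32)  # [batch, sequence length, embedding dim]
layer = EncoderLayerSketch(dim=32, heads=4, hidden_dim=64)
out, weights = layer(tokens)
print(out.shape, weights.shape)
```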

7. Acknowledgment 🥰

Thanks to these authors and their codes:

  2. pytorch vit
  3. SMP
  4. transformers
  5. medicalnet
  6. timm