ωαмα m⚙️dules

(🚧 Still under construction, but the current module implementations are already working)

*A PyTorch module library for building 1D/2D/3D networks flexibly ~*

Highlights (Simple-to-use & Function-rich!)

  • Simple code that shows all forward processes succinctly
  • Output as many features as possible for fast reuse
  • Support 1D / 2D / 3D networks
  • Easy to integrate with any other networks
  • 🚀 Abundant Pretrained weights: Including 80000+ 2D weights and 80+ 3D weights

1. Installation

Install wama_modules with ↓

pip install git+https://github.com/WAMAWAMA/wama_modules.git

Or you can directly copy the wama_modules folder into your project

  • 💧1.2 segmentation_models_pytorch (Optional)
Introduction and installation command

segmentation_models_pytorch (smp for short) is a 2D CNN library that includes many backbones and decoders; it is highly recommended to install it alongside this library.

Install with pip:

pip install segmentation-models-pytorch

Install the latest version:

pip install git+https://github.com/qubvel/segmentation_models.pytorch
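
To quickly check the installation, smp can also be used on its own. The snippet below is plain, standard smp usage (independent of wama_modules):

import torch
import segmentation_models_pytorch as smp

# a standard 2D smp Unet with an ImageNet-pretrained resnet34 encoder
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights='imagenet',
    in_channels=3,
    classes=2,
)
mask = model(torch.ones([2, 3, 128, 128]))
print(mask.shape)  # torch.Size([2, 2, 128, 128])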
  • 💧1.3 transformers (Optional)
Introduction and installation command

transformers is a library that includes abundant Transformer structures; it is highly recommended to install it alongside this library. Install transformers with ↓

pip install transformers
  • 💧1.4 timm (Optional)
Introduction and installation command

timm is a library that includes abundant CNN and Transformer structures; it is highly recommended to install it alongside this library. Install timm with ↓

Install with pip:

pip install timm

Install the latest version:

pip install git+https://github.com/rwightman/pytorch-image-models.git
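
To see which architectures timm provides, you can list models by wildcard (standard timm API):

import timm

# list a few available resnet variants, and those that ship pretrained weights
print(timm.list_models('resnet*')[:5])
print(timm.list_models('resnet*', pretrained=True)[:5])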

2. Update list

  • 2022/11/5: Open-sourced the code, version v0.0.1-beta
  • ...

3. Main modules and network architectures

Here I'll give an overview of this repo

4. Guideline 1: Build networks modularly

How to build a network modularly?

The answer is a paradigm of building networks:

'Design architecture according to tasks, pick modules according to architecture'

So, network architectures for different tasks can be decomposed modularly, for example:

  • vgg = vgg_encoder + cls_head
  • Unet = encoder + decoder + seg_head
  • resnet = resnet_encoder + cls_head
  • densenet = densenet_encoder + cls_head
  • a multi-task net for classification and segmentation = encoder + decoder + cls_head + seg_head

For example, build a 3D resnet50

import wama_modules as ws
import torch

encoder = ws.resnet(input_channel=3, per_stage_channel=[8,16,32,64], dim=3)
decoder = ws.unet(encoder=encoder, output_channel=3, dim=3)

# a 3D network expects a 5D input: [batch, channel, D, H, W]
input = torch.ones([3,3,128,128,128])

More demos are listed below ↓ (the corresponding code can be found in the demo folder)

  • Demo1: Build a 2D vgg16
  • Demo2: Build a 3D resnet50
  • Demo3: Build a 3D densenet121
  • Demo4: Build a Unet
  • Demo5: Build a Unet with a resnet50 encoder
  • Demo6: Build a Unet with a resnet50 encoder and a FPN
  • Demo7: Build a multi-task model for segmentation and classification
  • Demo8: Build a C-tran model for multi-label classification
  • Demo9: Build a Q2L model for multi-label classification
  • Demo10: Build a ML-Decoder model for multi-label classification
  • Demo11: Build a ML-GCN model for multi-label classification
  • Demo12: Build a UCTransNet model for segmentation
  • Demo13: Build a model for multiple inputs (1D signal and 2D image)
  • Demo14: Build a 2D Unet with a pretrained Resnet50 encoder
  • Demo15: Build a 3D DETR model for object detection
  • Demo16: Build a 3D VGG with SE-attention module for multi-instance classification

5. Guideline 2: Use pretrained weights

All pretrained weights are from third-party codes or repos

Currently supported pretrained weights (a table will be added here listing the source, number of weights, pretraining data type, and whether 2D or 3D):

  • 2D: smp, timm, radimagenet...
  • 3D: medicalnet, 3D resnet, 3D densenet...

5.1 smp encoders 2D

smp (119 pretrained weights)

import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder

# build a resnet18 encoder with semi-supervised ('ssl') pretrained weights
# (other encoder names and weight tags are available, see the smp docs)
m = get_encoder('resnet18', in_channels=3, depth=5, weights='ssl')

# forward: returns a list of multi-scale feature maps
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]

5.2 timm encoders 2D

timm (400+ pretrained weights)

import torch
import timm

# build an encoder that only outputs multi-scale feature maps
m = timm.create_model(
    'adv_inception_v3',
    features_only=True,
    pretrained=True,)  # set pretrained=True to download the ImageNet weights
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]

5.3 Transformers (🤗 Huggingface ) 2D

transformers, powered by Huggingface (with 80000+ pretrained weights)

import torch
from transformers import ConvNextModel
from wama_modules.utils import load_weights
# load the pretrained convnext-base-224-22k weights and extract multi-scale features
m = ConvNextModel.from_pretrained('facebook/convnext-base-224-22k')
f = m(torch.ones([2,3,224,224]), output_hidden_states=True)
f_list = f.hidden_states
_ = [print(i.shape) for i in f_list]

# initialize a model (with random weights) from the same configuration,
# then load the pretrained weights into it
weights = m.state_dict()
m1 = ConvNextModel(m.config)
m1 = load_weights(m1, weights)


import torch
from transformers import SwinModel
from wama_modules.utils import load_weights

m = SwinModel.from_pretrained('microsoft/swin-base-patch4-window12-384')
f = m(torch.ones([2,3,384,384]), output_hidden_states=True)
f_list = f.reshaped_hidden_states # For transformer, should use reshaped_hidden_states
_ = [print(i.shape) for i in f_list]

# initialize a model (with random weights) from the same configuration,
# then load the pretrained weights into it
weights = m.state_dict()
m1 = SwinModel(m.config)
m1 = load_weights(m1, weights)

5.4 radimagenet 2D medical image

???

5.5 ResNets3D_kenshohara 3D video

3D ResNets3D_kenshohara (21 weights)

import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights

# build a 3D resnet18 and load the pretrained weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet\r3d18_KM_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)

# forward: returns a list of multi-scale feature maps
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]


import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet2p1d import generate_model
from wama_modules.utils import load_weights

# build a 3D resnet(2+1)d-18 and load the pretrained weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet2p1d\r2p1d18_K_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]

5.6 VC3D_kenshohara 3D video

3D VC3D_kenshohara (13 weights)

import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights

# 3D resnet18 with Kinetics pretrained weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnet\resnet-18-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)  # drop_modelDOT strips the prefix in the saved keys
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]


import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model
from wama_modules.utils import load_weights

# 3D resnext101 with Kinetics pretrained weights
m = generate_model(101)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnext\resnext-101-64f-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]


import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
from wama_modules.utils import load_weights

# 3D wide-resnet50 with Kinetics pretrained weights
m = generate_model()
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\wideresnet\wideresnet-50-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]

5.7 ??? 3D video

???

5.8 ??? 3D medical image

???

6. All modules and functions

6.1 wama_modules.BaseModule

6.1.1 Pooling

  • GlobalAvgPool Global average pooling
  • GlobalMaxPool Global maximum pooling
  • GlobalMaxAvgPool GlobalMaxAvgPool = (GlobalAvgPool + GlobalMaxPool) / 2.
Click here to see demo code
""" demo """
# import libs
import torch
from wama_modules.BaseModule import GlobalAvgPool, GlobalMaxPool, GlobalMaxAvgPool

# make tensor
inputs1D = torch.ones([3,12,13]) # 1D
inputs2D = torch.ones([3,12,13,13]) # 2D
inputs3D = torch.ones([3,12,13,13,13]) # 3D

# build layer
GAP = GlobalAvgPool()
GMP = GlobalMaxPool()
GAMP = GlobalMaxAvgPool()

# test GAP & GMP & GAMP
print(inputs1D.shape, GAP(inputs1D).shape)
print(inputs2D.shape, GAP(inputs2D).shape)
print(inputs3D.shape, GAP(inputs3D).shape)

print(inputs1D.shape, GMP(inputs1D).shape)
print(inputs2D.shape, GMP(inputs2D).shape)
print(inputs3D.shape, GMP(inputs3D).shape)

print(inputs1D.shape, GAMP(inputs1D).shape)
print(inputs2D.shape, GAMP(inputs2D).shape)
print(inputs3D.shape, GAMP(inputs3D).shape)

6.1.2 Norm&Activation

  • customLayerNorm a custom implementation of layer normalization
  • MakeNorm make normalization layer, includes BN / GN / IN / LN
  • MakeActive make activation layer, includes Relu / LeakyRelu
  • MakeConv make 1D / 2D / 3D convolutional layer
Click here to see demo code
""" demo """

6.1.3 Conv

  • ConvNormActive 'Convolution→Normalization→Activation', used in VGG or ResNet
  • NormActiveConv 'Normalization→Activation→Convolution', used in DenseNet
  • VGGBlock the basic module in VGG
  • VGGStage a VGGStage = few VGGBlocks
  • ResBlock the basic module in ResNet
  • ResStage a ResStage = few ResBlocks
  • DenseLayer the basic module in DenseNet
  • DenseBlock a DenseBlock = few DenseLayers
Click here to see demo code
""" demo """

6.2 wama_modules.utils

  • resizeTensor scale torch tensor, similar to scipy's zoom
  • tensor2array transform tensor to ndarray
  • load_weights load torch weights and print loading details (missing keys and matched keys)
Click here to see demo code
""" demo """

6.3 wama_modules.Attention

  • SCSEModule
  • NonLocal
Click here to see demo code
""" demo """

6.4 wama_modules.Encoder

  • VGGEncoder
  • ResNetEncoder
  • DenseNetEncoder
  • ???
Click here to see demo code
""" demo """

6.5 wama_modules.Decoder

  • UNet_decoder
Click here to see demo code
""" demo """

6.6 wama_modules.Neck

  • FPN
Click here to see demo code
""" demo """
import torch
from wama_modules.Neck import FPN

# make multi-scale feature maps
featuremaps = [
    torch.ones([3,16,32,32,32]),
    torch.ones([3,32,24,24,24]),
    torch.ones([3,64,16,16,16]),
    torch.ones([3,128,8,8,8]),
]

# build FPN
fpn_AddSmall2Big = FPN(in_channels_list=[16, 32, 64, 128],
         c1=128,
         c2=256,
         active='relu',
         norm='bn',
         gn_c=8,
         mode='AddSmall2Big', # Add small size feature to big size feature, for segmentation
         dim=3,)
fpn_AddBig2Small = FPN(in_channels_list=[16, 32, 64, 128],
         c1=128,
         c2=256,
         active='relu',
         norm='bn',
         gn_c=8,
         mode='AddBig2Small', # Add big size feature to small size feature, for classification
         dim=3,)

# forward
f_listA = fpn_AddSmall2Big(featuremaps)
f_listB = fpn_AddBig2Small(featuremaps)
_ = [print(i.shape) for i in featuremaps]
_ = [print(i.shape) for i in f_listA]
_ = [print(i.shape) for i in f_listB]

6.7 wama_modules.Transformer

  • FeedForward
  • MultiHeadAttention
  • TransformerEncoderLayer
  • TransformerDecoderLayer
Click here to see demo code
""" demo """

7. Acknowledgment 🥰

Thanks to these authors and their codes:

  1. https://github.com/ZhugeKongan/torch-template-for-deep-learning
  2. pytorch vit
  3. SMP: https://github.com/qubvel/segmentation_models.pytorch
  4. transformers
  5. medicalnet
  6. timm: https://github.com/rwightman/pytorch-image-models