Commit a4b1622

add

WAMAWAMA committed Nov 8, 2022
1 parent ad02226 commit a4b1622
Showing 66 changed files with 6,703 additions and 26 deletions.
125 changes: 116 additions & 9 deletions README.md
@@ -9,11 +9,11 @@ Highlights (*Simple-to-use & Function-rich!*)
- Output as many features as possible for fast reuse
- Support 1D / 2D / 3D networks
- Easy to integrate with any other networks
- 🚀 Abundant pretrained weights: including 80,000+ `2D weights` and 80+ `3D weights`

## 1. Installation
- 🔥1.1 [`wama_modules`](https://github.com/WAMAWAMA/wama_modules) (*Basic*)

Install *wama_modules* with ↓
```
pip install git+https://github.com/WAMAWAMA/wama_modules.git
```

@@ -382,21 +382,127 @@ current pretrained support: (add a table here: where each weight set comes from, how many weights


### 5.1 smp encoders `2D`
smp (119 pretrained weights)

```python
import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder
m = get_encoder('resnet18', in_channels=3, depth=5, weights='ssl')
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
```
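
Note: the `weights` tag must be one the chosen encoder actually ships with (for `resnet18`, e.g. `'imagenet'`, `'ssl'`, or `'swsl'`); see the smp documentation for the per-encoder list.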

### 5.2 timm encoders `2D`
timm (400+ pretrained weights)
```python
import torch
import timm

m = timm.create_model(
    'adv_inception_v3',
    features_only=True,
    pretrained=False,
)
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]
```
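
With `features_only=True`, timm builds the model as a multi-scale feature extractor that returns a list of feature maps at increasing stride; set `pretrained=True` to download the matching ImageNet weights.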
### 5.3 Transformers (🤗 Huggingface) `2D`

transformers, super-powered by Huggingface (with 80,000+ pretrained weights)

```python
import torch
from transformers import ConvNextModel
from wama_modules.utils import load_weights

# Load a pretrained ConvNeXt and extract multi-scale features
m = ConvNextModel.from_pretrained('facebook/convnext-base-224-22k')
f = m(torch.ones([2,3,224,224]), output_hidden_states=True)
f_list = f.hidden_states
_ = [print(i.shape) for i in f_list]

# Reuse the downloaded weights in a freshly constructed model
weights = m.state_dict()
m1 = ConvNextModel(m.config)
m1 = load_weights(m1, weights)


import torch
from transformers import SwinModel
from wama_modules.utils import load_weights

m = SwinModel.from_pretrained('microsoft/swin-base-patch4-window12-384')
f = m(torch.ones([2,3,384,384]), output_hidden_states=True)
f_list = f.reshaped_hidden_states  # for transformers, use reshaped_hidden_states to get NCHW feature maps
_ = [print(i.shape) for i in f_list]

weights = m.state_dict()
m1 = SwinModel(m.config)
m1 = load_weights(m1, weights)
```
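
`hidden_states` from transformer backbones come as token sequences of shape `(B, L, C)`; Swin additionally exposes `reshaped_hidden_states` in `(B, C, H, W)`, which is the layout you want for dense or feature-pyramid heads.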

### 5.4 radimagenet `2D` `medical image`
???


### 5.5 ResNets3D_kenshohara `3D` `video`
ResNets3D_kenshohara (21 pretrained weights)
```python
import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet\r3d18_KM_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]


import torch
from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet2p1d import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet2p1d\r2p1d18_K_200ep.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
```
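Note: `generate_model(depth)` presumably accepts the depths defined in the upstream kenshohara repo (10, 18, 34, 50, 101, 152, 200); the snippets above use depth 18.
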
### 5.6 VC3D_kenshohara `3D` `video`
VC3D_kenshohara (13 pretrained weights)
```python
import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model(18)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnet\resnet-18-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
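# drop_modelDOT likely strips a leading 'module.'/'model.' prefix from checkpoint keys (assumption)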
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]

import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model
from wama_modules.utils import load_weights
m = generate_model(101)
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnext\resnext-101-64f-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]

import torch
from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
from wama_modules.utils import load_weights
m = generate_model()
pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\wideresnet\wideresnet-50-kinetics.pth"
pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
m = load_weights(m, pretrain_weights, drop_modelDOT=True)
f_list = m(torch.ones([2,3,64,64,64]))
_ = [print(i.shape) for i in f_list]
```
### 5.7 ??? `3D` `video`
???

@@ -488,6 +594,7 @@ print(inputs3D.shape, GAMP(inputs3D).shape)
### 6.2 `wama_modules.utils`
- `resizeTensor` scales a torch tensor, similar to scipy's `zoom`
- `tensor2array` converts a torch tensor to a numpy ndarray
- `load_weights` loads torch weights and prints loading details (missing and matched keys)

<details>
<summary> Click here to see demo code </summary>
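
A minimal sketch of these helpers (call signatures are assumptions; the repo's collapsed demo is authoritative):

```python
import torch
from wama_modules.utils import resizeTensor, tensor2array

t = torch.ones([2, 3, 64, 64])
t_big = resizeTensor(t, size=[128, 128])  # rescales spatial dims, like scipy's zoom (signature assumed)
arr = tensor2array(t_big)                 # torch.Tensor -> numpy.ndarray
print(t_big.shape, arr.shape)
```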
128 changes: 111 additions & 17 deletions tmp.py
@@ -1,25 +1,79 @@
# todo 1 3D ResNets3D_kenshohara (21 weights)
if True:
    import torch
    from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet import generate_model
    from wama_modules.utils import load_weights
    m = generate_model(18)
    pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet\r3d18_KM_200ep.pth"
    pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
    m = load_weights(m, pretrain_weights)
    f_list = m(torch.ones([2,3,64,64,64]))
    _ = [print(i.shape) for i in f_list]


    import torch
    from wama_modules.thirdparty_lib.ResNets3D_kenshohara.resnet2p1d import generate_model
    from wama_modules.utils import load_weights
    m = generate_model(18)
    pretrain_path = r"D:\pretrainedweights\ResNets3D_kenshohara\weights\resnet2p1d\r2p1d18_K_200ep.pth"
    pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
    m = load_weights(m, pretrain_weights)
    f_list = m(torch.ones([2,3,64,64,64]))
    _ = [print(i.shape) for i in f_list]


# todo 2 3D VC3D_kenshohara (13 weights)
if True:
    import torch
    from wama_modules.thirdparty_lib.VC3D_kenshohara.resnet import generate_model
    from wama_modules.utils import load_weights
    m = generate_model(18)
    pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnet\resnet-18-kinetics.pth"
    pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
    m = load_weights(m, pretrain_weights, drop_modelDOT=True)
    f_list = m(torch.ones([2,3,64,64,64]))
    _ = [print(i.shape) for i in f_list]

    import torch
    from wama_modules.thirdparty_lib.VC3D_kenshohara.resnext import generate_model
    from wama_modules.utils import load_weights
    m = generate_model(101)
    pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\resnext\resnext-101-64f-kinetics.pth"
    pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
    m = load_weights(m, pretrain_weights, drop_modelDOT=True)
    f_list = m(torch.ones([2,3,64,64,64]))
    _ = [print(i.shape) for i in f_list]

    import torch
    from wama_modules.thirdparty_lib.VC3D_kenshohara.wide_resnet import generate_model
    from wama_modules.utils import load_weights
    m = generate_model()
    pretrain_path = r"D:\pretrainedweights\VC3D_kenshohara\VC3D_weights\wideresnet\wideresnet-50-kinetics.pth"
    pretrain_weights = torch.load(pretrain_path, map_location='cpu')['state_dict']
    m = load_weights(m, pretrain_weights, drop_modelDOT=True)
    f_list = m(torch.ones([2,3,64,64,64]))
    _ = [print(i.shape) for i in f_list]


# todo 3 3D Efficient3D_okankop (39 weights)



# todo 4 3D MedicalNet_tencent (11 weights)



# todo 5 3D C3D_jfzhang95 (1 weight)



# todo 6 3D C3D_yyuanad (1 weight)



# todo 7 2D smp (119 weights)
# smp
import torch
from wama_modules.thirdparty_lib.SMP_qubvel.encoders import get_encoder
m = get_encoder('resnet18', in_channels=3, depth=5, weights='ssl')
@@ -28,10 +82,50 @@
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]


# todo 8 timm (400+)
import torch
import timm
m = timm.create_model(
    'adv_inception_v3',
    features_only=True,
    pretrained=False,
)
f_list = m(torch.ones([2,3,128,128]))
_ = [print(i.shape) for i in f_list]



# todo 9 transformers (80000+ weights)
import torch
from transformers import ConvNextModel
from wama_modules.utils import load_weights
# Load a pretrained ConvNeXt and extract multi-scale features
m = ConvNextModel.from_pretrained('facebook/convnext-base-224-22k')
f = m(torch.ones([2,3,224,224]), output_hidden_states=True)
f_list = f.hidden_states
_ = [print(i.shape) for i in f_list]

weights = m.state_dict()
m1 = ConvNextModel(m.config)
m1 = load_weights(m1, weights)


import torch
from transformers import SwinModel
from wama_modules.utils import load_weights

m = SwinModel.from_pretrained('microsoft/swin-base-patch4-window12-384')
f = m(torch.ones([2,3,384,384]), output_hidden_states=True)
f_list = f.reshaped_hidden_states  # for transformers, use reshaped_hidden_states to get NCHW feature maps
_ = [print(i.shape) for i in f_list]

weights = m.state_dict()
m1 = SwinModel(m.config)
m1 = load_weights(m1, weights)

Binary file modified wama_modules/__pycache__/utils.cpython-38.pyc
Binary file not shown.