This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

Commit 2952f2a

fixes se_resnet spatial avg.
wyli committed Dec 18, 2019
1 parent 2cde540 commit 2952f2a
Showing 2 changed files with 147 additions and 52 deletions.
niftynet/network/se_resnet.py (131 changes: 79 additions & 52 deletions)
@@ -6,14 +6,17 @@
 
 import tensorflow as tf
 
-from niftynet.layer.bn import BNLayer
-from niftynet.layer.fully_connected import FCLayer
+from niftynet.layer import layer_util
 from niftynet.layer.base_layer import TrainableLayer
+from niftynet.layer.bn import BNLayer
 from niftynet.layer.convolution import ConvolutionalLayer
+from niftynet.layer.fully_connected import FCLayer
 from niftynet.layer.squeeze_excitation import ChannelSELayer
 from niftynet.network.base_net import BaseNet
 
 SE_ResNetDesc = namedtuple('SE_ResNetDesc', ['bn', 'fc', 'conv1', 'blocks'])
 
+
 class SE_ResNet(BaseNet):
     """
     ### Description
@@ -35,11 +38,10 @@ class SE_ResNet(BaseNet):
     ### Constraints
     """
 
     def __init__(self,
                  num_classes,
-                 n_features = [16, 64, 128],
-                 n_blocks_per_resolution = 1,
+                 n_features=[16, 64, 128],
+                 n_blocks_per_resolution=1,
                  w_initializer=None,
                  w_regularizer=None,
                  b_initializer=None,
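A side note the commit does not address: `n_features=[16, 64, 128]` is a mutable default argument, evaluated once at definition time and shared across calls. A hypothetical variant (not part of this commit) would avoid the pitfall with a tuple default:

    # Hypothetical alternative, not in this commit: tuples are immutable,
    # so no two SE_ResNet instances can mutate a shared default list.
    def __init__(self, num_classes, n_features=(16, 64, 128), **kwargs):
        self.n_features = list(n_features)  # per-instance copy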
@@ -59,14 +61,13 @@ def __init__(self,
         :param name: layer name
         """
 
-        super(SE_ResNet, self).__init__(
-            num_classes=num_classes,
-            w_initializer=w_initializer,
-            w_regularizer=w_regularizer,
-            b_initializer=b_initializer,
-            b_regularizer=b_regularizer,
-            acti_func=acti_func,
-            name=name)
+        super(SE_ResNet, self).__init__(num_classes=num_classes,
+                                        w_initializer=w_initializer,
+                                        w_regularizer=w_regularizer,
+                                        b_initializer=b_initializer,
+                                        b_regularizer=b_regularizer,
+                                        acti_func=acti_func,
+                                        name=name)
 
         self.n_features = n_features
         self.n_blocks_per_resolution = n_blocks_per_resolution
@@ -83,14 +84,21 @@ def create(self):
         :return: tuple with batch norm layer, fully connected layer, first conv layer and all residual blocks
         """
-        bn=BNLayer()
-        fc=FCLayer(self.num_classes)
-        conv1=self.Conv(self.n_features[0], acti_func=None, feature_normalization=None)
-        blocks=[]
-        blocks+=[DownResBlock(self.n_features[1], self.n_blocks_per_resolution, 1, self.Conv)]
+        bn = BNLayer()
+        fc = FCLayer(self.num_classes)
+        conv1 = self.Conv(self.n_features[0],
+                          acti_func=None,
+                          feature_normalization=None)
+        blocks = []
+        blocks += [
+            DownResBlock(self.n_features[1], self.n_blocks_per_resolution, 1,
+                         self.Conv)
+        ]
         for n in self.n_features[2:]:
-            blocks+=[DownResBlock(n, self.n_blocks_per_resolution, 2, self.Conv)]
-        return SE_ResNetDesc(bn=bn,fc=fc,conv1=conv1,blocks=blocks)
+            blocks += [
+                DownResBlock(n, self.n_blocks_per_resolution, 2, self.Conv)
+            ]
+        return SE_ResNetDesc(bn=bn, fc=fc, conv1=conv1, blocks=blocks)
 
     def layer_op(self, images, is_training=True, **unused_kwargs):
         """
@@ -104,11 +112,19 @@ def layer_op(self, images, is_training=True, **unused_kwargs):
         out = layers.conv1(images, is_training)
         for block in layers.blocks:
             out = block(out, is_training)
-        out = tf.reduce_mean(tf.nn.relu(layers.bn(out, is_training)),axis=[1,2,3])
+
+        spatial_rank = layer_util.infer_spatial_rank(out)
+        axis_to_avg = [dim + 1 for dim in range(spatial_rank)]
+        out = tf.reduce_mean(tf.nn.relu(layers.bn(out, is_training)),
+                             axis=axis_to_avg)
         return layers.fc(out)
 
 
 BottleneckBlockDesc1 = namedtuple('BottleneckBlockDesc1', ['conv'])
-BottleneckBlockDesc2 = namedtuple('BottleneckBlockDesc2', ['common_bn', 'conv', 'conv_shortcut'])
+BottleneckBlockDesc2 = namedtuple('BottleneckBlockDesc2',
+                                  ['common_bn', 'conv', 'conv_shortcut'])
 
 
 class BottleneckBlock(TrainableLayer):
     def __init__(self, n_output_chns, stride, Conv, name='bottleneck'):
         """
@@ -119,11 +135,11 @@ def __init__(self, n_output_chns, stride, Conv, name='bottleneck'):
         :param name: layer name
         """
         self.n_output_chns = n_output_chns
-        self.stride=stride
+        self.stride = stride
         self.bottle_neck_chns = n_output_chns // 4
         self.Conv = Conv
         super(BottleneckBlock, self).__init__(name=name)
 
     def create(self, input_chns):
         """
@@ -132,21 +148,29 @@ def create(self, input_chns):
         """
 
         if self.n_output_chns == input_chns:
-            b1 = self.Conv(self.bottle_neck_chns, kernel_size=1,
+            b1 = self.Conv(self.bottle_neck_chns,
+                           kernel_size=1,
                            stride=self.stride)
             b2 = self.Conv(self.bottle_neck_chns, kernel_size=3)
             b3 = self.Conv(self.n_output_chns, 1)
             return BottleneckBlockDesc1(conv=[b1, b2, b3])
         else:
             b1 = BNLayer()
-            b2 = self.Conv(self.bottle_neck_chns,kernel_size=1,
-                stride=self.stride, acti_func=None, feature_normalization=None)
-            b3 = self.Conv(self.bottle_neck_chns,kernel_size=3)
-            b4 = self.Conv(self.n_output_chns,kernel_size=1)
-            b5 = self.Conv(self.n_output_chns,kernel_size=1,
-                stride=self.stride, acti_func=None,feature_normalization=None)
-            return BottleneckBlockDesc2(common_bn=b1, conv=[b2, b3, b4],
-                conv_shortcut=b5)
+            b2 = self.Conv(self.bottle_neck_chns,
+                           kernel_size=1,
+                           stride=self.stride,
+                           acti_func=None,
+                           feature_normalization=None)
+            b3 = self.Conv(self.bottle_neck_chns, kernel_size=3)
+            b4 = self.Conv(self.n_output_chns, kernel_size=1)
+            b5 = self.Conv(self.n_output_chns,
+                           kernel_size=1,
+                           stride=self.stride,
+                           acti_func=None,
+                           feature_normalization=None)
+            return BottleneckBlockDesc2(common_bn=b1,
+                                        conv=[b2, b3, b4],
+                                        conv_shortcut=b5)
 
     def layer_op(self, images, is_training=True):
         """
@@ -156,24 +180,27 @@ def layer_op(self, images, is_training=True):
         :return: tensor, output of the BottleNeck block
         """
         layers = self.create(images.shape[-1])
-        se=ChannelSELayer()
+        se = ChannelSELayer()
         if self.n_output_chns == images.shape[-1]:
-            out=layers.conv[0](images, is_training)
-            out=layers.conv[1](out, is_training)
-            out=layers.conv[2](out, is_training)
-            out=se(out)
-            out = out+images
+            out = layers.conv[0](images, is_training)
+            out = layers.conv[1](out, is_training)
+            out = layers.conv[2](out, is_training)
+            out = se(out)
+            out = out + images
         else:
             tmp = tf.nn.relu(layers.common_bn(images, is_training))
-            out=layers.conv[0](tmp, is_training)
-            out=layers.conv[1](out, is_training)
-            out=layers.conv[2](out, is_training)
-            out=se(out)
+            out = layers.conv[0](tmp, is_training)
+            out = layers.conv[1](out, is_training)
+            out = layers.conv[2](out, is_training)
+            out = se(out)
             out = layers.conv_shortcut(tmp, is_training) + out
         return out
 
 
 DownResBlockDesc = namedtuple('DownResBlockDesc', ['blocks'])
 
 
 class DownResBlock(TrainableLayer):
     def __init__(self, n_output_chns, count, stride, Conv, name='downres'):
         """
@@ -187,20 +214,20 @@ def __init__(self, n_output_chns, count, stride, Conv, name='downres'):
         self.count = count
         self.stride = stride
         self.n_output_chns = n_output_chns
-        self.Conv=Conv
+        self.Conv = Conv
         super(DownResBlock, self).__init__(name=name)
 
     def create(self):
         """
         :return: tuple, containing all the Bottleneck blocks composing the DownRes block
         """
-        blocks=[]
-        blocks+=[BottleneckBlock(self.n_output_chns, self.stride, self.Conv)]
-        for it in range(1,self.count):
-            blocks+=[BottleneckBlock(self.n_output_chns, 1, self.Conv)]
+        blocks = []
+        blocks += [BottleneckBlock(self.n_output_chns, self.stride, self.Conv)]
+        for it in range(1, self.count):
+            blocks += [BottleneckBlock(self.n_output_chns, 1, self.Conv)]
         return DownResBlockDesc(blocks=blocks)
 
     def layer_op(self, images, is_training):
         """
@@ -211,5 +238,5 @@ def layer_op(self, images, is_training):
         layers = self.create()
         out = images
         for l in layers.blocks:
-            out=l(out,is_training)
+            out = l(out, is_training)
         return out
tests/se_resnet_test.py (68 changes: 68 additions & 0 deletions)
@@ -0,0 +1,68 @@
+from __future__ import absolute_import, print_function
+
+import unittest
+
+import tensorflow as tf
+from tensorflow.contrib.layers.python.layers import regularizers
+
+from niftynet.network.se_resnet import SE_ResNet
+from tests.niftynet_testcase import NiftyNetTestCase
+
+class SeResNet3DTest(NiftyNetTestCase):
+    def test_3d_shape(self):
+        input_shape = (2, 8, 16, 32, 1)
+        x = tf.ones(input_shape)
+
+        resnet_instance = SE_ResNet(num_classes=160)
+        out = resnet_instance(x, is_training=True)
+        print(resnet_instance.num_trainable_params())
+
+        with self.cached_session() as sess:
+            sess.run(tf.global_variables_initializer())
+            out = sess.run(out)
+            self.assertAllClose((2, 160), out.shape)
+
+    def test_2d_shape(self):
+        input_shape = (2, 8, 16, 1)
+        x = tf.ones(input_shape)
+
+        resnet_instance = SE_ResNet(num_classes=160)
+        out = resnet_instance(x, is_training=True)
+        print(resnet_instance.num_trainable_params())
+
+        with self.cached_session() as sess:
+            sess.run(tf.global_variables_initializer())
+            out = sess.run(out)
+            self.assertAllClose((2, 160), out.shape)
+
+    def test_3d_reg_shape(self):
+        input_shape = (2, 8, 16, 24, 1)
+        x = tf.ones(input_shape)
+
+        resnet_instance = SE_ResNet(num_classes=160,
+                                    w_regularizer=regularizers.l2_regularizer(0.4))
+        out = resnet_instance(x, is_training=True)
+        print(resnet_instance.num_trainable_params())
+
+        with self.cached_session() as sess:
+            sess.run(tf.global_variables_initializer())
+            out = sess.run(out)
+            self.assertAllClose((2, 160), out.shape)
+
+    def test_2d_reg_shape(self):
+        input_shape = (2, 8, 16, 1)
+        x = tf.ones(input_shape)
+
+        resnet_instance = SE_ResNet(num_classes=160,
+                                    w_regularizer=regularizers.l2_regularizer(0.4))
+        out = resnet_instance(x, is_training=True)
+        print(resnet_instance.num_trainable_params())
+
+        with self.cached_session() as sess:
+            sess.run(tf.global_variables_initializer())
+            out = sess.run(out)
+            self.assertAllClose((2, 160), out.shape)
+
+
+if __name__ == "__main__":
+    tf.test.main()
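Because the module ends with `tf.test.main()`, the new tests can be run directly, e.g. `python tests/se_resnet_test.py` from the repository root, assuming the TF1-era dependencies (including `tensorflow.contrib`) are installed. The 2-D cases exercise exactly the input rank the old hard-coded `axis=[1,2,3]` could not handle.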
