Skip to content

Commit

Permalink
Merge pull request bonlime#1 from brianmanderson/tf2_branch
Browse files Browse the repository at this point in the history
Tf2 branch
  • Loading branch information
brianmanderson authored Apr 16, 2020
2 parents 2860c1d + 5bd4f5c commit de8ff33
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions model.py → model_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,27 @@
from __future__ import division
from __future__ import print_function

import tensorflow.compat.v1 as tf

from tensorflow.python.keras.models import Model
from tensorflow.python.keras import layers
from tensorflow.python.keras.layers import Input
from tensorflow.python.keras.layers import Lambda
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Concatenate
from tensorflow.python.keras.layers import Add
from tensorflow.python.keras.layers import Dropout
from tensorflow.python.keras.layers import BatchNormalization
from tensorflow.python.keras.layers import Conv2D
from tensorflow.python.keras.layers import DepthwiseConv2D
from tensorflow.python.keras.layers import ZeroPadding2D
from tensorflow.python.keras.layers import GlobalAveragePooling2D
from tensorflow.python.keras.layers import UpSampling2D
from tensorflow.python.keras.utils.layer_utils import get_source_inputs
from tensorflow.python.keras.utils.data_utils import get_file
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.activations import relu
from tensorflow.python.keras.applications.imagenet_utils import preprocess_input
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import DepthwiseConv2D
from tensorflow.keras.layers import ZeroPadding2D
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.utils import get_file
from tensorflow.keras.utils import get_source_inputs
from tensorflow.keras.applications.imagenet_utils import preprocess_input
import tensorflow.keras.backend as K

WEIGHTS_PATH_X = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_xception_tf_dim_ordering_tf_kernels.h5"
WEIGHTS_PATH_MOBILE = "https://github.com/bonlime/keras-deeplab-v3-plus/releases/download/1.1/deeplabv3_mobilenetv2_tf_dim_ordering_tf_kernels.h5"
Expand Down Expand Up @@ -169,7 +169,7 @@ def _make_divisible(v, divisor, min_value=None):


def _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id, skip_connection, rate=1):
in_channels = inputs.shape[-1].value # inputs._keras_shape[-1]
in_channels = inputs.shape[-1] # inputs._keras_shape[-1]
pointwise_conv_filters = int(filters * alpha)
pointwise_filters = _make_divisible(pointwise_conv_filters, 8)
x = inputs
Expand Down Expand Up @@ -364,8 +364,8 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
# Image Feature branch
b4 = GlobalAveragePooling2D()(x)
# from (b_size, channels)->(b_size, 1, 1, channels)
b4 = Lambda(lambda x: K.expand_dims(x, 1))(b4)
b4 = Lambda(lambda x: K.expand_dims(x, 1))(b4)
b4 = Lambda(lambda x: tf.expand_dims(x, 1))(b4)
b4 = Lambda(lambda x: tf.expand_dims(x, 1))(b4)
b4 = Conv2D(256, (1, 1), padding='same',
use_bias=False, name='image_pooling')(b4)
b4 = BatchNormalization(name='image_pooling_BN', epsilon=1e-5)(b4)
Expand All @@ -374,7 +374,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
size_before = K.int_shape(x)

b4 = Lambda(lambda x: tf.image.resize(x, size_before[1:3],
method='bilinear', align_corners=True))(b4)
method='bilinear'))(b4)
# b4 = UpSampling2D(size=(size_before[1],size_before[2]),interpolation='bilinear')(b4)
# simple 1x1
b0 = Conv2D(256, (1, 1), padding='same', use_bias=False, name='aspp0')(x)
Expand All @@ -401,8 +401,8 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
x = Conv2D(256, (1, 1), padding='same',
use_bias=False, name='concat_projection')(x)
x = BatchNormalization(name='concat_projection_BN', epsilon=1e-5)(x)
x = Activation('elu')(x)
x = Dropout(0.1)(x)
x = Activation('elu')(x)
# DeepLab v.3+ decoder

if backbone == 'xception':
Expand Down Expand Up @@ -435,7 +435,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)
# size_out = K.int_shape(img_input)
# x = UpSampling2D(size=(size_out[1] // size_in[1], size_out[2] // size_in[2]), interpolation='bilinear')(x)
size_before3 = K.int_shape(img_input)
x = Lambda(lambda xx: tf.image.resize(xx, size_before3[1:3], method='bilinear', align_corners=True))(x)
x = Lambda(lambda xx: tf.image.resize(xx, size_before3[1:3], method='bilinear'))(x)
# Ensure that the model takes into account
# any potential predecessors of `input_tensor`.
if input_tensor is not None:
Expand All @@ -445,7 +445,7 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)

if activation in {'softmax', 'sigmoid'}:
x = Activation(activation)(x)

x = Activation('linear', dtype='float32')(x)
model = Model(inputs=inputs, outputs=x, name='deeplabv3plus')

# load weights
Expand Down

0 comments on commit de8ff33

Please sign in to comment.