Skip to content

Commit

Permalink
update for tf version
Browse files Browse the repository at this point in the history
  • Loading branch information
brianmanderson committed Apr 16, 2020
1 parent 26c0ab8 commit c44e9f8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
9 changes: 8 additions & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@
# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_policy(policy)
import tensorflow.compat.v1 as tf
version_split = tf.__version__.split('.')
if version_split[0] == '2' and int(version_split[1]) > 1:
from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import layers
from tensorflow.python.keras.layers import Input
Expand Down Expand Up @@ -446,7 +451,9 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)

if activation in {'softmax', 'sigmoid'}:
x = Activation(activation)(x)
# x = Activation('linear', dtype='float32')(x)
version_split = tf.__version__.split('.')
if version_split[0] == '2' and int(version_split[1]) > 1:
x = Activation('linear', dtype='float32')(x)
model = Model(inputs=inputs, outputs=x, name='deeplabv3plus')

# load weights
Expand Down
12 changes: 8 additions & 4 deletions model_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
version_split = tf.__version__.split('.')
if version_split[0] == '2' and int(version_split[1]) > 1:
from tensorflow.keras.mixed_precision import experimental as mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_policy(policy)
from tensorflow.keras.models import Model
from tensorflow.keras import layers
from tensorflow.keras.layers import Input
Expand Down Expand Up @@ -445,7 +447,9 @@ def Deeplabv3(weights='pascal_voc', input_tensor=None, input_shape=(512, 512, 3)

if activation in {'softmax', 'sigmoid'}:
x = Activation(activation)(x)
x = Activation('linear', dtype='float32')(x)
version_split = tf.__version__.split('.')
if version_split[0] == '2' and int(version_split[1]) > 1:
x = Activation('linear', dtype='float32')(x)
model = Model(inputs=inputs, outputs=x, name='deeplabv3plus')

# load weights
Expand Down

0 comments on commit c44e9f8

Please sign in to comment.