-
Notifications
You must be signed in to change notification settings - Fork 0
/
models.py
92 lines (72 loc) · 3.16 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate, Conv2DTranspose, BatchNormalization, Dropout, Lambda
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Activation, MaxPool2D, Concatenate
def conv_block(input, num_filters):
    """Apply two Conv2D -> BatchNormalization -> ReLU stages in sequence.

    Shared building block for both the autoencoder and the U-Net.

    Args:
        input: Tensor fed into the first convolution.
        num_filters: Number of filters used by both 3x3 convolutions.

    Returns:
        The output tensor of the second ReLU activation.
    """
    features = input
    # Two identical stages; layers are created in the same order as if
    # they were written out one after another.
    for _ in range(2):
        features = Conv2D(num_filters, 3, padding="same")(features)
        # Batch normalization is an addition: not in the original network.
        features = BatchNormalization()(features)
        features = Activation("relu")(features)
    return features
def encoder_block(input, num_filters):
    """Run a conv block, then 2x2 max pooling on its output.

    Args:
        input: Tensor fed into the conv block.
        num_filters: Filter count passed through to ``conv_block``.

    Returns:
        A ``(conv_out, pooled)`` tuple: the conv block output before
        pooling, and the max-pooled tensor.
    """
    conv_out = conv_block(input, num_filters)
    pooled = MaxPool2D(pool_size=(2, 2))(conv_out)
    return conv_out, pooled
def decoder_block(input, num_filters):
    """Upsample with a transposed convolution, then apply a conv block.

    Decoder stage for the autoencoder; it takes no skip connection.

    Args:
        input: Tensor to upsample.
        num_filters: Filter count for both the transposed convolution and
            the following conv block.

    Returns:
        The output tensor of the conv block.
    """
    upsample = Conv2DTranspose(
        num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )
    return conv_block(upsample(input), num_filters)
def build_encoder(input_image):
    """Build the autoencoder's encoder path on top of ``input_image``.

    Chains four encoder blocks (64, 128, 256, 512 filters) followed by a
    1024-filter conv block acting as the bridge. Each encoder block returns
    both its pre-pooling output and its pooled output; only the pooled
    output is carried forward here, the pre-pooling output is ignored.

    Args:
        input_image: Input tensor for the first encoder block.

    Returns:
        The output tensor of the bridge conv block.
    """
    pooled = input_image
    for filters in (64, 128, 256, 512):
        # The pre-pooling output is discarded: the autoencoder has no skips.
        _unused_skip, pooled = encoder_block(pooled, filters)
    return conv_block(pooled, 1024)  # Bridge
def build_decoder(encoded):
    """Build the decoder path for the autoencoder ONLY (no skip connections).

    Chains four decoder blocks (512, 256, 128, 64 filters) and finishes with
    a 3-filter, 3x3 sigmoid convolution.

    Args:
        encoded: Tensor produced by the encoder's bridge.

    Returns:
        The output tensor of the final sigmoid convolution.
    """
    upsampled = encoded
    for filters in (512, 256, 128, 64):
        upsampled = decoder_block(upsampled, filters)
    output_layer = Conv2D(3, 3, padding="same", activation="sigmoid")
    return output_layer(upsampled)
def build_autoencoder(input_shape):
    """Assemble the autoencoder model from the encoder and decoder paths.

    Args:
        input_shape: Shape passed to ``Input(shape=...)``.

    Returns:
        A Keras ``Model`` mapping the input tensor to the decoder output.
    """
    input_img = Input(shape=input_shape)
    encoded = build_encoder(input_img)
    decoded = build_decoder(encoded)
    return Model(input_img, decoded)
# Example usage:
# model = build_autoencoder((256, 256, 3))
# model.summary()
def decoder_block_for_unet(input, skip_features, num_filters):
    """Upsample, concatenate with encoder skip features, then apply a conv block.

    Args:
        input: Tensor to upsample with a stride-2 transposed convolution.
        skip_features: Encoder output concatenated with the upsampled tensor.
        num_filters: Filter count for both the transposed convolution and
            the following conv block.

    Returns:
        The output tensor of the conv block.
    """
    upsampled = Conv2DTranspose(
        num_filters, kernel_size=(2, 2), strides=2, padding="same"
    )(input)
    # Upsampled tensor goes first, skip features second.
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)
def build_unet(input_shape):
    """Build the U-Net model from the encoder/decoder blocks.

    Four encoder blocks (64-512 filters) feed a 1024-filter bridge; four
    decoder blocks (512-64 filters) then upsample while concatenating the
    matching encoder outputs as skip connections. The head is a 1-filter,
    1x1 sigmoid convolution (binary output; can be adapted for multiclass).

    The model summary is printed as a side effect before returning.

    Args:
        input_shape: Shape passed to ``Input``.
            NOTE(review): presumably height and width must be divisible by
            16 so the skip tensors match the upsampled ones - confirm.

    Returns:
        A Keras ``Model`` named "U-Net".
    """
    inputs = Input(input_shape)

    # Encoder: keep each pre-pooling output (s*) for the skip connections.
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)  # Bridge

    # Decoder: skips are consumed in reverse order of creation.
    d1 = decoder_block_for_unet(b1, s4, 512)
    d2 = decoder_block_for_unet(d1, s3, 256)
    d3 = decoder_block_for_unet(d2, s2, 128)
    d4 = decoder_block_for_unet(d3, s1, 64)

    # Binary (can be multiclass)
    outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

    model = Model(inputs, outputs, name="U-Net")
    # summary() prints the table itself and returns None, so wrapping it in
    # print() would emit a stray "None" line after the summary.
    model.summary()
    return model