Skip to content

Commit

Permalink
improve the batch normalization and fix an error in the reference for…
Browse files Browse the repository at this point in the history
…mula
  • Loading branch information
marph91 committed Apr 25, 2021
1 parent 13377cf commit ae27c59
Showing 1 changed file with 33 additions and 20 deletions.
53 changes: 33 additions & 20 deletions sim/test_window_convolution_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,34 @@ async def run_test(dut):
name="test_conv",
)(input_)
if output_channel_bitwidth == 1:
x = tf.keras.layers.BatchNormalization(name="test_batchnorm")(x)
# Scale is not needed, since we clip afterwards anyway.
x = tf.keras.layers.BatchNormalization(name="test_batchnorm", scale=False)(x)
output_ = lq.quantizers.SteHeaviside()(x)
else:
# There is no batchnorm for output bitwidth > 1
output_ = x
model = tf.keras.Model(inputs=input_, outputs=output_)

# TODO: Try to set realistic batchnorm parameter.
# beta=offset, gamma=scale, mean, variance
# original_batch_weights = model.get_layer("test_batchnorm").get_weights()
## np.array([kernel_size[0] ** 2 * image_shape[2] / 2]
# model.get_layer("test_batchnorm").set_weights(
# [
# original_batch_weights[0],
# original_batch_weights[1],
# np.array([16] * output_channel),
# original_batch_weights[3],
# ]
# )
if output_channel_bitwidth == 1:
# Try to set realistic batchnorm parameter.
input_channel_bitwidth = dut.C_INPUT_CHANNEL_BITWIDTH.value.integer
model.get_layer("test_batchnorm").set_weights(
[
np.array([random.uniform(0, 0.5) for _ in range(output_channel)]),
np.array(
[
random.uniform(-1, 1) * input_channel_bitwidth
for _ in range(output_channel)
]
),
np.array(
[
random.uniform(0, 1) * fan_in * input_channel_bitwidth
for _ in range(output_channel)
]
),
]
)

# define the testcases
@dataclass
Expand All @@ -101,7 +110,12 @@ def output_data(self) -> int:
# inference
image_tensor = tf.convert_to_tensor(self.input_image)
reshaped_tensor = tf.reshape(image_tensor, batch_shape)
result = model(reshaped_tensor).numpy()

extractor = tf.keras.Model(
inputs=model.inputs, outputs=[layer.output for layer in model.layers]
)
features = extractor(reshaped_tensor)
result = features[-1].numpy()

if output_channel_bitwidth > 1:
# compensate (see also threshold for batchnorm)
Expand All @@ -126,14 +140,11 @@ def get_threshold(self):
batchnorm_params = [
a.tolist() for a in model.get_layer("test_batchnorm").get_weights()
]
for gamma, beta, mean, variance in zip(*batchnorm_params):
# TODO: Could be wrong order.

epsilon = 0.001 # prevent division by 0

for beta, mean, variance in zip(*batchnorm_params):
# use batch normalization as activation
# see also: https://arxiv.org/pdf/1612.07119.pdf, 4.2.2 Batchnorm-activation as Threshold
threshold_batchnorm = mean - (beta / (variance * epsilon - gamma))
# 0.001 is added to avoid division by 0.
threshold_batchnorm = mean - beta * math.sqrt(variance + 0.001)

# get the following formula by solving:
# x - y = fan_in; x + y = threshold
Expand Down Expand Up @@ -202,6 +213,8 @@ def get_threshold(self):
)
]
model.get_layer("test_conv").set_weights(reshaped_weights)
# This is the way how we convert the weights of the model back.
assert list(reshaped_weights[0].flat) == case.weights

dut.islv_weights <= case.get_weights()
dut.islv_threshold <= case.get_threshold()
Expand Down

0 comments on commit ae27c59

Please sign in to comment.