Skip to content

Commit

Permalink
improve the full bnn testbench
Browse files Browse the repository at this point in the history
Use numpy functions, because they are easier to use.
  • Loading branch information
marph91 committed May 13, 2021
1 parent 1cd81c1 commit 3b54a4a
Showing 1 changed file with 11 additions and 14 deletions.
25 changes: 11 additions & 14 deletions sim/test_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,18 @@
async def run_test(dut):
height = dut.C_INPUT_HEIGHT.value.integer
width = dut.C_INPUT_WIDTH.value.integer
input_image = [i % 255 for i in range(height * width)]
input_image = np.random.randint(0, 255, (1, height, width, 1), dtype=np.uint8)

model = tf.keras.models.load_model("../../models/test")
lq.models.summary(model)
data = np.reshape(np.array(input_image), (1, height, width, 1))

# https://keras.io/getting_started/faq/#how-can-i-obtain-the-output-of-an-intermediate-layer-feature-extraction
extractor = tf.keras.Model(
inputs=model.inputs, outputs=[layer.output for layer in model.layers]
)
features = extractor(data)
class_scores_pos = list(features[-2].numpy().flat)
class_scores = [
round((score + features[-4].shape[-1]) / 2) for score in class_scores_pos
]
features = extractor(input_image)
# class scores = (output before softmax + fan in) / 2
class_scores_pos = np.around((features[-2].numpy() + features[-4].shape[-1]) / 2)

output_bitwitdh = dut.C_OUTPUT_CHANNEL_BITWIDTH.value.integer
output_mon = ImageMonitor(
Expand All @@ -57,21 +54,21 @@ async def run_test(dut):
dut.isl_valid <= 0
await tick.wait()

for pixel in input_image:
for pixel in input_image.flat:
dut.isl_valid <= 1
dut.islv_data <= pixel
dut.islv_data <= int(pixel)
await tick.wait()
dut.isl_valid <= 0
await tick.wait()
await tick.wait()

await tick.wait_multiple(height * width)

print(class_scores_pos)
print(class_scores)
print(output_mon.output)
print(features)
assert output_mon.output == class_scores
np.testing.assert_almost_equal(
np.resize(np.array(output_mon.output), class_scores_pos.shape),
class_scores_pos,
decimal=0,
)


def test_bnn():
Expand Down

0 comments on commit 3b54a4a

Please sign in to comment.