Skip to content

Commit

Permalink
general update
Browse files Browse the repository at this point in the history
- Add docstrings and type hints.
- Remove unused functions.
  • Loading branch information
marph91 committed May 13, 2021
1 parent 3b54a4a commit c2577ab
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 50 deletions.
Empty file removed .gitmodules
Empty file.
1 change: 0 additions & 1 deletion .kateproject

This file was deleted.

4 changes: 1 addition & 3 deletions playground/04_custom_toplevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,9 +645,7 @@ def bnn_from_larq(path: str) -> Bnn:

# used at the next batch norm
fan_in = (
get_kernel_size(parameter["kernel_size"]) ** 2
* channel
* channel_bw
get_kernel_size(parameter["kernel_size"]) ** 2 * channel * channel_bw
)
# used at the next conv
channel = layer.output.shape[-1]
Expand Down
7 changes: 5 additions & 2 deletions sim/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import random

from test_utils.extra_libs import analyze_json, analyze_util, analyze_window_ctrl_lib
import numpy as np

from test_utils.extra_libs import analyze_util, analyze_window_ctrl_lib

# https://stackoverflow.com/questions/44624407/how-to-reduce-log-line-size-in-cocotb
os.environ["COCOTB_REDUCED_LOG_FMT"] = "1"
Expand All @@ -11,15 +13,16 @@
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # ERROR


# Use fixed seeds to get reproducible results.
random.seed(42)
np.random.seed(42)


def pytest_addoption(parser):
parser.addoption("--waves", action="store_true", help="Record the waveform.")


def pytest_configure(config):
analyze_json()
analyze_util()
analyze_window_ctrl_lib()

Expand Down
3 changes: 0 additions & 3 deletions sim/pytest.ini

This file was deleted.

4 changes: 3 additions & 1 deletion sim/test_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def run_test(dut):
width = dut.C_INPUT_WIDTH.value.integer
input_image = np.random.randint(0, 255, (1, height, width, 1), dtype=np.uint8)

# TODO: How to disable the custom gradient warning?
model = tf.keras.models.load_model("../../models/test")
lq.models.summary(model)

Expand All @@ -34,6 +35,7 @@ async def run_test(dut):
inputs=model.inputs, outputs=[layer.output for layer in model.layers]
)
features = extractor(input_image)

# class scores = (output before softmax + fan in) / 2
class_scores_pos = np.around((features[-2].numpy() + features[-4].shape[-1]) / 2)

Expand All @@ -51,6 +53,7 @@ async def run_test(dut):
clock_period = 10 # ns
tick = Tick(clock_period=clock_period)
cocotb.fork(Clock(dut.isl_clk, clock_period, units="ns").start())

dut.isl_valid <= 0
await tick.wait()

Expand All @@ -60,7 +63,6 @@ async def run_test(dut):
await tick.wait()
dut.isl_valid <= 0
await tick.wait()
await tick.wait()

await tick.wait_multiple(height * width)

Expand Down
22 changes: 14 additions & 8 deletions sim/test_utils/cocotb_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""Collection of common cocotb helper functions."""

from cocotb.monitors import Monitor
from cocotb_bus.monitors import Monitor
from cocotb.triggers import RisingEdge, Timer


def get_safe_int(signal, default: int = 0):
try:
return signal.value.integer
except ValueError:
return default


# TODO: evaluate streambusmonitor
# https://github.com/cocotb/cocotb/blob/master/examples/mean/tests/test_mean.py#L16
class ImageMonitor(Monitor):
Expand All @@ -29,11 +36,8 @@ async def _monitor_recv(self):

while True:
await clock_edge
try:
valid = self.valid.value.integer
except ValueError:
valid = 0
if valid == 1:

if get_safe_int(self.valid) == 1:
vec = self.signal.value.binstr
output = [
int(vec[ch * self.bitwidth : (ch + 1) * self.bitwidth], 2)
Expand All @@ -45,13 +49,15 @@ async def _monitor_recv(self):
class Tick:
"""Convenience class to avoid specifying the unit always."""

def __init__(self, clock_period=10, units="ns"):
def __init__(self, clock_period: int = 10, units: str = "ns"):
self.clock_period = clock_period
self.units = units
self.tick = Timer(clock_period, units=units)

async def wait(self):
"""Wait a single clock tick."""
await self.tick

async def wait_multiple(self, tick_count=1):
async def wait_multiple(self, tick_count: int = 1):
"""Wait multiple clock ticks."""
await Timer(self.clock_period * tick_count, units=self.units)
18 changes: 4 additions & 14 deletions sim/test_utils/extra_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,8 @@
ABSOLUTE_PATH = pathlib.Path(__file__).parent.absolute()


def analyze_json():
work = "json"
source_path = ABSOLUTE_PATH / ".." / ".." / "submodules" / "JSON-for-VHDL" / "src"
source_files = get_files(source_path, "*.vhdl")

if outdated(f"{WORK_DIR}/{work}-obj08.cf", source_files):
os.makedirs(f"{WORK_DIR}", exist_ok=True)

analyze_command = ["ghdl", "-i", STD, f"--work={work}", f"--workdir={WORK_DIR}"]
analyze_command.extend(source_files)
subprocess.run(analyze_command, check=True)


def analyze_util():
"""Analyze the utility library."""
work = "util"
source_path = ABSOLUTE_PATH / ".." / ".." / "src" / "util"
source_files = get_files(source_path, "*.vhd")
Expand All @@ -39,6 +27,7 @@ def analyze_util():


def analyze_window_ctrl_lib():
"""Analyze the window control library."""
analyze_util()

work = "window_ctrl_lib"
Expand All @@ -53,7 +42,8 @@ def analyze_window_ctrl_lib():
subprocess.run(analyze_command, check=True)


def outdated(output, dependencies):
def outdated(output: str, dependencies: list) -> bool:
"""Check whether files are outdated with regards to a reference output."""
if not os.path.isfile(output):
return True

Expand Down
22 changes: 4 additions & 18 deletions sim/test_utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,15 @@
from typing import List, Optional, Sequence

from bitstring import Bits
import pytest


def position_to_index(col: int, row: int, width: int, height: int) -> int:
"""Convert a position into an index of a one dimensional stream."""
index = row * width + col
assert index < width * height
return index


def generate_random_image(
channel: int, width: int, height: int, bitwidth: int = 8
) -> List[int]:
image = []
for _ in range(height * width * channel):
image.append(randint(0, 2 ** bitwidth - 1))
return image


def index_to_position(index: int, width: int, height: int) -> tuple:
"""Convert a one dimensional stream index to a two dimensional
image position."""
Expand All @@ -36,19 +27,13 @@ def index_to_position(index: int, width: int, height: int) -> tuple:
return xpos, ypos


def print_row_by_row(
name: str, list_: List[int], width: int, height: int, channel: int
):
print(f"{name}:")
for row in range(height):
print(list_[row * width * channel : (row + 1) * width * channel])


def get_files(path: pathlib.Path, pattern: str) -> List[str]:
"""Obtain all files matching a pattern in a specific path."""
return [p.resolve() for p in list(path.glob(pattern))]


def concatenate_integers(integer_list: List[int], bitwidth=1) -> int:
"""Concatenate multiple integers into a single integer."""
concatenated_integer = 0
for value in integer_list:
if value > 2 ** bitwidth:
Expand All @@ -58,6 +43,7 @@ def concatenate_integers(integer_list: List[int], bitwidth=1) -> int:


def concatenate_channel(image, channel, bitwidth=1):
"""Concatenate the channels of an image."""
return [
concatenate_integers(
image[pixel_index : pixel_index + channel], bitwidth=bitwidth
Expand Down

0 comments on commit c2577ab

Please sign in to comment.