add support for colored inputs (i.e. three input channels)
Added cifar10 for testing colored inputs.
marph91 committed Jun 9, 2021
1 parent d4cf856 commit 15f66fa
Showing 5 changed files with 101 additions and 27 deletions.
42 changes: 29 additions & 13 deletions playground/04_custom_toplevel.py
@@ -4,7 +4,9 @@
from typing import Dict, List, Optional

from bitstring import Bits
import larq as lq
import numpy as np
import tensorflow as tf

# TODO: copied from test_utils.general
def to_fixedint(number: int, bitwidth: int, is_unsigned: bool = True):
@@ -382,18 +384,18 @@ def __init__(
self.output_classes = output_classes
self.output_bitwidth = output_bitwidth
self.previous_layer_info = {
"name": "in",
"name": "in_deserialized",
"channel": input_channel,
"bitwidth": input_bitwidth,
"width": image_width,
"height": image_height,
}

self.input_data_signal = Parameter(
self.input_data_signal_deserialized = Parameter(
f"slv_data_{self.previous_layer_info['name']}",
f"std_logic_vector(8 - 1 downto 0)",
f"std_logic_vector(C_INPUT_CHANNEL * C_INPUT_CHANNEL_BITWIDTH - 1 downto 0)",
)
self.input_control_signal = Parameter(
self.input_control_signal_deserialized = Parameter(
f"sl_valid_{self.previous_layer_info['name']}", "std_logic"
)

@@ -419,7 +421,7 @@ def __init__(
isl_clk : in std_logic;
isl_start : in std_logic;
isl_valid : in std_logic;
islv_data : in std_logic_vector(C_INPUT_CHANNEL * C_INPUT_CHANNEL_BITWIDTH - 1 downto 0);
islv_data : in std_logic_vector(C_INPUT_CHANNEL_BITWIDTH - 1 downto 0);
oslv_data : out std_logic_vector(C_OUTPUT_CHANNEL_BITWIDTH - 1 downto 0);
osl_valid : out std_logic;
osl_finish : out std_logic
Expand All @@ -441,14 +443,32 @@ def to_vhdl(self):
declarations.append("-- input signals")
declarations.append(
parameter_to_vhdl(
"signal", [self.input_data_signal, self.input_control_signal]
"signal",
[
self.input_data_signal_deserialized,
self.input_control_signal_deserialized,
],
)
)
declarations.append("")

# connect input signals
implementation.append(f"{self.input_control_signal.name} <= isl_valid;")
implementation.append(f"{self.input_data_signal.name} <= islv_data;")
implementation.append(
f"""
i_deserializer : entity util.deserializer
generic map (
C_DATA_COUNT => C_INPUT_CHANNEL,
C_DATA_BITWIDTH => C_INPUT_CHANNEL_BITWIDTH
)
port map (
isl_clk => isl_clk,
isl_valid => isl_valid,
islv_data => islv_data,
oslv_data => {self.input_data_signal_deserialized.name},
osl_valid => {self.input_control_signal_deserialized.name}
);
"""
)

# parse the bnn
for layer in self.layers:
@@ -576,15 +596,11 @@ def get_stride(strides):


def bnn_from_larq(path: str) -> Bnn:
import larq as lq
import tensorflow as tf

model = tf.keras.models.load_model(path)
lq.models.summary(model)

input_channel = 1
input_channel = model.input.shape[-1]
input_channel_bitwidth = 8
output_channel = 8
output_channel_bitwidth = 8
bnn = Bnn(
*model.input.shape[1:], # h x w x ch
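
The loader above now derives the input channel count from the loaded model instead of hard-coding a single channel. A minimal sketch (not part of the commit) of what model.input.shape provides for a channels-last Keras model:

import tensorflow as tf

# hypothetical model, only to illustrate the shape attributes used in bnn_from_larq
model = tf.keras.Sequential([tf.keras.layers.Conv2D(8, 3, input_shape=(32, 32, 3))])
print(model.input.shape)      # (None, 32, 32, 3)
print(model.input.shape[-1])  # 3 -> input_channel
print(model.input.shape[1:])  # (32, 32, 3) -> height, width, channel passed to Bnn
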
27 changes: 17 additions & 10 deletions playground/05_intro_modified.py
@@ -2,19 +2,22 @@
import tensorflow as tf

# for resizing
import cv2
import numpy as np

(train_images, train_labels), (
test_images,
test_labels,
) = tf.keras.datasets.mnist.load_data()
dataset = "mnist"
if dataset == "mnist":
(
(train_images, train_labels),
(test_images, test_labels,),
) = tf.keras.datasets.mnist.load_data()

# reshape the inputs
# train_images = np.stack([cv2.resize(img, (22, 22)) for img in train_images])
train_images = train_images.reshape((60000, 28, 28, 1))
# test_images = np.stack([cv2.resize(img, (22, 22)) for img in test_images])
test_images = test_images.reshape((10000, 28, 28, 1))
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
else:
(
(train_images, train_labels),
(test_images, test_labels,),
) = tf.keras.datasets.cifar10.load_data()

# All quantized layers except the first will use the same options
kwargs = dict(
@@ -50,6 +53,10 @@
model.add(lq.layers.QuantConv2D(64, (1, 1), use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
# model.add(tf.keras.layers.Dropout(0.2))

if dataset == "mnist":
model.add(lq.layers.QuantConv2D(128, (1, 1), use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
if True:
model.add(lq.layers.QuantConv2D(10, (1, 1), use_bias=False, **kwargs))
model.add(lq.layers.tf.keras.layers.GlobalAveragePooling2D())
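
A quick shape check (a sketch, not part of the diff) of why only the MNIST branch needs the reshape above: MNIST arrives without a channel axis, while CIFAR-10 is already channel-last with three channels.

import tensorflow as tf

(train_images, _), _ = tf.keras.datasets.mnist.load_data()
print(train_images.shape)  # (60000, 28, 28) -> reshaped to (60000, 28, 28, 1)

(train_images, _), _ = tf.keras.datasets.cifar10.load_data()
print(train_images.shape)  # (50000, 32, 32, 3), channel axis already present
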
5 changes: 3 additions & 2 deletions sim/test_bnn.py
@@ -24,8 +24,9 @@
async def run_test(dut):
height = dut.C_INPUT_HEIGHT.value.integer
width = dut.C_INPUT_WIDTH.value.integer
input_image = np.random.randint(0, 255, (1, height, width, 1), dtype=np.uint8)
# input_image = np.full((1, height, width, 1), 0, dtype=np.uint8)
channel = dut.C_INPUT_CHANNEL.value.integer
input_image = np.random.randint(0, 255, (1, height, width, channel), dtype=np.uint8)
# input_image = np.full((1, height, width, channel), 0, dtype=np.uint8)

# TODO: How to disable the custom gradient warning?
model = tf.keras.models.load_model("../../models/test")
5 changes: 3 additions & 2 deletions sim/test_bnn_uart.py
@@ -17,10 +17,11 @@
async def run_test(dut):
height = dut.i_bnn.C_INPUT_HEIGHT.value.integer
width = dut.i_bnn.C_INPUT_WIDTH.value.integer
channel = dut.i_bnn.C_INPUT_CHANNEL.value.integer
classes = dut.i_bnn.C_OUTPUT_CHANNEL.value.integer

# input_image = np.random.randint(0, 255, (1, height, width, 1), dtype=np.uint8)
input_image = np.full((1, height, width, 1), 0, dtype=np.uint8)
# input_image = np.random.randint(0, 255, (1, height, width, channel), dtype=np.uint8)
input_image = np.full((1, height, width, channel), 0, dtype=np.uint8)

# initialize the test
clock_period = 40 # ns
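
Both testbenches now size the input image from the C_INPUT_CHANNEL generic. Since the top-level islv_data port stays one channel wide, the image presumably has to be streamed one channel value per clock; a hypothetical helper (not part of the commit, assuming raster order with channels innermost) could flatten it like this:

import numpy as np

def image_to_stream(image: np.ndarray) -> list:
    # flatten (1, height, width, channel) row by row, pixel by pixel,
    # channel by channel -- assuming the first channel of each pixel
    # should end up in the low bits of the deserialized word
    return image.flatten(order="C").tolist()
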
49 changes: 49 additions & 0 deletions src/util/deserializer.vhd
@@ -0,0 +1,49 @@
library ieee;
use ieee.std_logic_1164.all;

entity deserializer is
generic (
C_DATA_COUNT : integer := 4;
C_DATA_BITWIDTH : integer := 8
);
port (
isl_clk : in std_logic;
isl_valid : in std_logic;
islv_data : in std_logic_vector(C_DATA_BITWIDTH - 1 downto 0);
oslv_data : out std_logic_vector(C_DATA_COUNT * C_DATA_BITWIDTH - 1 downto 0);
osl_valid : out std_logic
);
end entity deserializer;

architecture rtl of deserializer is

signal int_input_count : integer range 0 to C_DATA_COUNT - 1 := 0;
signal slv_data_out : std_logic_vector(oslv_data'range) := (others => '0');
signal sl_valid_out : std_logic := '0';

begin

proc_deserializer : process (isl_clk) is
begin

if (rising_edge(isl_clk)) then
sl_valid_out <= '0';

if (isl_valid = '1') then
slv_data_out <= islv_data & slv_data_out(slv_data_out'high downto C_DATA_BITWIDTH);

if (int_input_count /= C_DATA_COUNT - 1) then
int_input_count <= int_input_count + 1;
else
int_input_count <= 0;
sl_valid_out <= '1';
end if;
end if;
end if;

end process proc_deserializer;

osl_valid <= sl_valid_out;
oslv_data <= slv_data_out;

end architecture rtl;
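
For reference, a rough software model of the new deserializer (a sketch, not part of the commit): each valid input word is shifted in at the high end of the output register, so after C_DATA_COUNT words the first one received sits in the least significant bits and osl_valid pulses for one cycle.

def deserializer_model(words, data_count=4, bitwidth=8):
    # group the serial stream into chunks of data_count words and
    # concatenate each chunk, first word in the least significant bits
    outputs = []
    chunk = []
    for word in words:
        chunk.append(word)
        if len(chunk) == data_count:
            outputs.append(sum(w << (i * bitwidth) for i, w in enumerate(chunk)))
            chunk = []
    return outputs

print([hex(v) for v in deserializer_model(range(1, 9))])  # ['0x4030201', '0x8070605']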
