Skip to content

Commit

Permalink
Remove C_POST_CONVOLUTION_BITWIDTH
Browse files Browse the repository at this point in the history
Instead, calculate that bitwidth when needed.
  • Loading branch information
marph91 committed Feb 21, 2021
1 parent 59d65c6 commit 36f43d4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 24 deletions.
4 changes: 1 addition & 3 deletions sim/test_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ def input_weights_int(self) -> int:
def output_data(self) -> int:
ones_count = 0
for act, weight in zip(self.input_activations, self.input_weights):
ones_count = ones_count + (
act and weight
) # TODO: should be xnor: (not (act ^ weight))
ones_count = ones_count + (not (act ^ weight))
return ones_count

input_channel = dut.C_INPUT_CHANNEL.value.integer
Expand Down
5 changes: 1 addition & 4 deletions sim/test_window_convolution_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ async def run_test(dut):
dut.C_INPUT_CHANNEL.value.integer,
)
output_channel = dut.C_OUTPUT_CHANNEL.value.integer
post_convolution_bitwidth = dut.C_POST_CONVOLUTION_BITWIDTH.value.integer

# define the reference model
batch_shape = (1,) + image_shape
Expand Down Expand Up @@ -120,7 +119,6 @@ def get_threshold(self):
# use batch normalization as activation
# see also: https://arxiv.org/pdf/1612.07119.pdf, 4.2.2 Batchnorm-activation as Threshold
threshold_batchnorm = mean - (beta / (variance * epsilon - gamma))
print(threshold_batchnorm)

# Conversion from LARQ format [-1, 1] to pocket-bnn format [0, 1] (positive only):
# make the threshold compatible to positive only values
Expand All @@ -131,7 +129,7 @@ def get_threshold(self):
threshold.append(int(threshold_pos))

return concatenate_integers(
self.replace_minus(threshold), bitwidth=post_convolution_bitwidth
self.replace_minus(threshold), bitwidth=math.ceil(math.log2(kernel_size[0] ** 2 * image_shape[2] + 1))
)

cases = (
Expand Down Expand Up @@ -220,7 +218,6 @@ def test_window_convolution_activation(record_waveform, kernel_size):
"C_OUTPUT_CHANNEL": 8,
"C_IMG_WIDTH": 4,
"C_IMG_HEIGHT": 4,
"C_POST_CONVOLUTION_BITWIDTH": 8,
}
run(
vhdl_sources=get_files(
Expand Down
10 changes: 5 additions & 5 deletions src/convolution.vhd
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,23 @@ library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;

library util;
use util.math_pkg.all;

entity convolution is
generic (
-- TODO: input bitwidth, for now = 1

C_KERNEL_SIZE : integer range 2 to 7 := 2;
C_INPUT_CHANNEL : integer := 1;

C_POST_CONVOLUTION_BITWIDTH : integer := 8
C_INPUT_CHANNEL : integer := 1
);
port (
isl_clk : in std_logic;
isl_valid : in std_logic;
islv_data : in std_logic_vector(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL - 1 downto 0);
-- maybe weights + 1 for bias
islv_weights : in std_logic_vector(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL - 1 downto 0);
oslv_data : out std_logic_vector(C_POST_CONVOLUTION_BITWIDTH - 1 downto 0);
oslv_data : out std_logic_vector(log2(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL + 1) - 1 downto 0);
osl_valid : out std_logic
);
end entity convolution;
Expand All @@ -35,7 +36,6 @@ begin

proc_convolution : process (isl_clk) is

-- TODO: log2(islv_data'range)
variable usig_ones_count : unsigned(oslv_data'range);

begin
Expand Down
4 changes: 1 addition & 3 deletions src/processing_element.vhd
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ entity processing_element is
-- TODO: input bitwidth, for now = 1

C_KERNEL_SIZE : integer range 2 to 3 := 2;
C_INPUT_CHANNEL : integer;
C_INPUT_CHANNEL : integer
-- C_OUTPUT_CHANNEL : integer; --> 1 output channel

C_POST_CONVOLUTION_BITWIDTH : integer := 8
);
port (
isl_clk : in std_logic;
Expand Down
16 changes: 7 additions & 9 deletions src/window_convolution_activation.vhd
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ library ieee;
use ieee.numeric_std.all;

library cnn_lib;
library util;
use util.math_pkg.all;

library window_ctrl_lib;

Expand All @@ -18,9 +20,7 @@ entity window_convolution_activation is
C_OUTPUT_CHANNEL : integer;

C_IMG_WIDTH : integer;
C_IMG_HEIGHT : integer;

C_POST_CONVOLUTION_BITWIDTH : integer := 8
C_IMG_HEIGHT : integer
);
port (
isl_clk : in std_logic;
Expand All @@ -29,7 +29,7 @@ entity window_convolution_activation is
islv_data : in std_logic_vector(C_INPUT_CHANNEL - 1 downto 0);
-- islv_weights and islv_threshold are constants
islv_weights : in std_logic_vector(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL * C_OUTPUT_CHANNEL - 1 downto 0);
islv_threshold : in std_logic_vector(C_POST_CONVOLUTION_BITWIDTH * C_OUTPUT_CHANNEL - 1 downto 0);
islv_threshold : in std_logic_vector(log2(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL + 1) * C_OUTPUT_CHANNEL - 1 downto 0);
oslv_data : out std_logic_vector(C_OUTPUT_CHANNEL - 1 downto 0);
osl_valid : out std_logic
);
Expand All @@ -44,6 +44,7 @@ architecture behavioral of window_convolution_activation is

type t_slv_array_1d is array(natural range <>) of std_logic_vector;

constant C_POST_CONVOLUTION_BITWIDTH : integer := log2(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL + 1);
signal a_data_convolution : t_slv_array_1d(0 to C_OUTPUT_CHANNEL - 1)(C_POST_CONVOLUTION_BITWIDTH - 1 downto 0);

signal sl_valid_batch_normalization : std_logic := '0';
Expand Down Expand Up @@ -98,8 +99,7 @@ begin
i_convolution : entity cnn_lib.convolution
generic map (
C_KERNEL_SIZE => C_KERNEL_SIZE,
C_INPUT_CHANNEL => C_INPUT_CHANNEL,
C_POST_CONVOLUTION_BITWIDTH => C_POST_CONVOLUTION_BITWIDTH
C_INPUT_CHANNEL => C_INPUT_CHANNEL
)
port map (
isl_clk => isl_clk,
Expand All @@ -109,11 +109,9 @@ begin
oslv_data => a_data_convolution(output_channel),
osl_valid => sl_valid_convolution
);

-- TODO: output channel increments fastest, not slowest
-- output channel increments fastest
a_weights(output_channel) <= get_fastest_increment(islv_weights, output_channel, C_OUTPUT_CHANNEL);

-- one batch normalization for each output channel
i_batch_normalization : entity cnn_lib.batch_normalization
generic map (
C_POST_CONVOLUTION_BITWIDTH => C_POST_CONVOLUTION_BITWIDTH
Expand Down

0 comments on commit 36f43d4

Please sign in to comment.