Skip to content

Commit

Permalink
Add support for variable output bitwidth
Browse files Browse the repository at this point in the history
I.e., make the batch normalization optional. This is done to get viable class
scores at the output of the BNN.

Also fixed a small synthesis error caused by multiple assignments to the
batch normalization signals.
  • Loading branch information
marph91 committed Mar 18, 2021
1 parent 62bc74a commit 1e7ead9
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 28 deletions.
2 changes: 1 addition & 1 deletion sim/test_window_convolution_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def get_threshold(self):
),
# mixed
Testcase(
# choice([-1, 1])
[choice([-1, 1]) for _ in range(math.prod(image_shape))],
[
choice([-1, 1])
Expand Down Expand Up @@ -234,6 +233,7 @@ def get_threshold(self):
def test_window_convolution_activation(
record_waveform, kernel_size, stride, input_channel, output_channel
):
# TODO: Add test for output bitwidth /= 1.
generics = {
"C_KERNEL_SIZE": kernel_size,
"C_STRIDE": stride,
Expand Down
60 changes: 34 additions & 26 deletions src/window_convolution_activation.vhd
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ entity window_convolution_activation is
C_KERNEL_SIZE : integer range 1 to 7 := 3;
C_STRIDE : integer := 1;

C_INPUT_CHANNEL : integer := 4;
C_OUTPUT_CHANNEL : integer := 8;
C_INPUT_CHANNEL : integer := 4;
C_OUTPUT_CHANNEL : integer := 8;
C_OUTPUT_CHANNEL_BITWIDTH : integer range 1 to 32 := 1;

C_IMG_WIDTH : integer := 4;
C_IMG_HEIGHT : integer := 4
Expand All @@ -30,8 +31,8 @@ entity window_convolution_activation is
islv_data : in std_logic_vector(C_INPUT_CHANNEL - 1 downto 0);
-- islv_weights and islv_threshold are constants
islv_weights : in std_logic_vector(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL * C_OUTPUT_CHANNEL - 1 downto 0);
islv_threshold : in std_logic_vector(log2(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL + 1) * C_OUTPUT_CHANNEL - 1 downto 0);
oslv_data : out std_logic_vector(C_OUTPUT_CHANNEL - 1 downto 0);
islv_threshold : in std_logic_vector(C_OUTPUT_CHANNEL * log2(C_KERNEL_SIZE ** 2 * C_INPUT_CHANNEL + 1) - 1 downto 0);
oslv_data : out std_logic_vector(C_OUTPUT_CHANNEL * C_OUTPUT_CHANNEL_BITWIDTH - 1 downto 0);
osl_valid : out std_logic
);
end entity window_convolution_activation;
Expand All @@ -41,16 +42,16 @@ architecture behavioral of window_convolution_activation is
signal sl_valid_window_ctrl : std_logic := '0';
signal slv_data_window_ctrl : std_logic_vector(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL - 1 downto 0);

signal sl_valid_convolution : std_logic := '0';
signal slv_valid_convolution : std_logic_vector(C_OUTPUT_CHANNEL - 1 downto 0);

type t_slv_array_1d is array(natural range <>) of std_logic_vector;

constant C_POST_CONVOLUTION_BITWIDTH : integer := log2(C_KERNEL_SIZE * C_KERNEL_SIZE * C_INPUT_CHANNEL + 1);
constant C_POST_CONVOLUTION_BITWIDTH : integer := log2(C_KERNEL_SIZE ** 2 * C_INPUT_CHANNEL + 1);
signal a_data_convolution : t_slv_array_1d(0 to C_OUTPUT_CHANNEL - 1)(C_POST_CONVOLUTION_BITWIDTH - 1 downto 0);

signal sl_valid_batch_normalization : std_logic := '0';
signal slv_data_batch_normalization : std_logic_vector(C_OUTPUT_CHANNEL * 1 - 1 downto 0);
signal a_data_batch_normalization : t_slv_array_1d(0 to C_OUTPUT_CHANNEL - 1)(1 - 1 downto 0);
signal slv_valid_batch_normalization : std_logic_vector(C_OUTPUT_CHANNEL - 1 downto 0);
signal slv_data_batch_normalization : std_logic_vector(C_OUTPUT_CHANNEL * C_OUTPUT_CHANNEL_BITWIDTH - 1 downto 0);
signal a_data_batch_normalization : t_slv_array_1d(0 to C_OUTPUT_CHANNEL - 1)(C_OUTPUT_CHANNEL_BITWIDTH - 1 downto 0);

signal sl_valid_out : std_logic := '0';
signal slv_data_out : std_logic_vector(oslv_data'range);
Expand Down Expand Up @@ -108,31 +109,38 @@ begin
islv_data => slv_data_window_ctrl,
islv_weights => a_weights(output_channel),
oslv_data => a_data_convolution(output_channel),
osl_valid => sl_valid_convolution
osl_valid => slv_valid_convolution(output_channel)
);

-- output channel increments fastest
a_weights(output_channel) <= get_fastest_increment(islv_weights, output_channel, C_OUTPUT_CHANNEL);

i_batch_normalization : entity cnn_lib.batch_normalization
generic map (
C_POST_CONVOLUTION_BITWIDTH => C_POST_CONVOLUTION_BITWIDTH
)
port map (
isl_clk => isl_clk,
isl_valid => sl_valid_convolution,
islv_data => a_data_convolution(output_channel),
islv_threshold => a_threshold(output_channel),
oslv_data => a_data_batch_normalization(output_channel),
osl_valid => sl_valid_batch_normalization
);
gen_batch_normalization : if C_OUTPUT_CHANNEL_BITWIDTH = 1 generate

i_batch_normalization : entity cnn_lib.batch_normalization
generic map (
C_POST_CONVOLUTION_BITWIDTH => C_POST_CONVOLUTION_BITWIDTH
)
port map (
isl_clk => isl_clk,
isl_valid => slv_valid_convolution(output_channel),
islv_data => a_data_convolution(output_channel),
islv_threshold => a_threshold(output_channel),
oslv_data => a_data_batch_normalization(output_channel),
osl_valid => slv_valid_batch_normalization(output_channel)
);

-- TODO: output channel increments fastest, not slowest
a_threshold(output_channel) <= get_slice(islv_threshold, output_channel, C_POST_CONVOLUTION_BITWIDTH);
slv_data_batch_normalization((output_channel + 1) * C_OUTPUT_CHANNEL_BITWIDTH - 1 downto output_channel * C_OUTPUT_CHANNEL_BITWIDTH) <= a_data_batch_normalization(output_channel);
else generate
slv_data_batch_normalization((output_channel + 1) * C_OUTPUT_CHANNEL_BITWIDTH - 1 downto output_channel * C_OUTPUT_CHANNEL_BITWIDTH) <= std_logic_vector(resize(unsigned(a_data_convolution(output_channel)), C_OUTPUT_CHANNEL_BITWIDTH));
slv_valid_batch_normalization(output_channel) <= slv_valid_convolution(output_channel);
end generate gen_batch_normalization;

-- TODO: output channel increments fastest, not slowest
a_threshold(output_channel) <= get_slice(islv_threshold, output_channel, C_POST_CONVOLUTION_BITWIDTH);
slv_data_batch_normalization((output_channel + 1) * 1 - 1 downto output_channel * 1) <= a_data_batch_normalization(output_channel);
end generate gen_convolution;

oslv_data <= slv_data_batch_normalization;
osl_valid <= sl_valid_batch_normalization;
osl_valid <= slv_valid_batch_normalization(0);

end architecture behavioral;
2 changes: 1 addition & 1 deletion src/window_ctrl/window_ctrl.vhd
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ library window_ctrl_lib;
entity window_ctrl is
generic (
-- global data bitwidth
C_BITWIDTH : integer range 1 to 16 := 8;
C_BITWIDTH : integer := 8;

-- image properties
C_CH_IN : integer range 1 to 512 := 1; -- input channel
Expand Down

0 comments on commit 1e7ead9

Please sign in to comment.