Skip to content

Commit

Permalink
average pooling: fix incorrect results for multiple frames
Browse files Browse the repository at this point in the history
Also use FSM instead of pipelining, since it's better readable.
  • Loading branch information
marph91 committed May 14, 2021
1 parent 2a41e7e commit 309a6a6
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 58 deletions.
10 changes: 5 additions & 5 deletions sim/test_average_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def output_data(self) -> int:
dut.isl_start <= 0
await tick.wait()

dut.isl_start <= 1
await tick.wait()
dut.isl_start <= 0
await tick.wait()

# run the specific testcases
for case in cases:
dut.isl_start <= 1
await tick.wait()
dut.isl_start <= 0
await tick.wait()

for datum in case.input_data:
dut.isl_valid <= 1
dut.islv_data <= datum
Expand Down
123 changes: 70 additions & 53 deletions src/average_pooling.vhd
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,45 @@ architecture behavioral of average_pooling is
constant C_INTW_SUM : integer := C_BITWIDTH + log2(C_IMG_HEIGHT * C_IMG_WIDTH + 1);
constant C_FRACW_REZI : integer range 1 to 16 := 16;

signal sl_calculate_average : std_logic := '0';
signal sl_calculate_average_d1 : std_logic := '0';
signal sl_calculate_average_d2 : std_logic := '0';

-- fixed point multiplication yields: A'left + B'left + 1 downto -(A'right + B'right)
signal ufix_average : ufixed(C_INTW_SUM downto - C_FRACW_REZI) := (others => '0');
attribute use_dsp : string;
attribute use_dsp of ufix_average : signal is "yes";
signal ufix_average_d1 : ufixed(C_INTW_SUM downto - C_FRACW_REZI) := (others => '0');

-- TODO: try real instead of ufixed
-- to_ufixed() yields always one fractional bit. Thus the reciprocal has at least 2 integer bits.
-- to_ufixed() yields always one fractional bit. Thus the reciprocal has at least one integer bit.
constant C_RECIPROCAL : ufixed(0 downto - C_FRACW_REZI) := reciprocal(to_ufixed(C_IMG_HEIGHT * C_IMG_WIDTH, C_FRACW_REZI - 1, 0));
signal slv_average : std_logic_vector(C_BITWIDTH - 1 downto 0) := (others => '0');

signal int_data_in_cnt : integer range 0 to C_IMG_WIDTH * C_IMG_HEIGHT - 1 := 0;
signal int_data_out_cnt : integer range 0 to C_CHANNEL := 0;
signal int_open_averages : integer range 0 to C_CHANNEL := 0;
signal sl_full_image : std_logic := '0';

type t_1d_array is array (natural range <>) of unsigned(C_INTW_SUM - 1 downto 0);

signal a_ch_buffer : t_1d_array(0 to C_CHANNEL - 1) := (others => (others => '0'));

signal sl_output_valid : std_logic := '0';

type t_state is (CLEAR_BUFFER, SUM, CALCULATE_AVERAGE, PIPELINE_AVERAGE, OUTPUT_AVERAGE);

signal state : t_state := CLEAR_BUFFER;

begin

i_pixel_count : entity util.basic_counter
generic map (
C_MAX => C_IMG_HEIGHT * C_IMG_WIDTH,
C_COUNT_DOWN => 0
)
port map (
isl_clk => isl_clk,
isl_reset => isl_start,
isl_valid => isl_valid,
oint_count => open,
osl_maximum => sl_full_image
);

-------------------------------------------------------
-- Process: Average Pooling (average of each channel)
-- Stage 1: sum up the values of every channel
Expand All @@ -72,61 +85,65 @@ begin
begin

if (rising_edge(isl_clk)) then
sl_calculate_average <= '0';
sl_output_valid <= '0';

if (isl_start = '1') then
a_ch_buffer <= (others => (others => '0'));
int_data_in_cnt <= 0;
else
if (isl_valid = '1') then
if (int_data_in_cnt = C_IMG_HEIGHT * C_IMG_WIDTH - 1) then
int_data_in_cnt <= 0;
int_data_out_cnt <= C_CHANNEL;
else
int_data_in_cnt <= int_data_in_cnt + 1;
state <= CLEAR_BUFFER;
end if;

case state is

when CLEAR_BUFFER =>
a_ch_buffer <= (others => (others => '0'));
state <= SUM;

when SUM =>
if (isl_valid = '1') then
for ch in 0 to C_CHANNEL - 1 loop
a_ch_buffer(ch) <= resize(a_ch_buffer(ch) +
unsigned(get_slice(islv_data, ch, C_BITWIDTH)),
a_ch_buffer(0)'length);
end loop;
end if;

for ch in 0 to C_CHANNEL - 1 loop
a_ch_buffer(ch) <= resize(
a_ch_buffer(ch) +
unsigned(get_slice(islv_data, ch, C_BITWIDTH)),
a_ch_buffer(0)'length);
end loop;
end if;

------------------------DIVIDE OPTIONS---------------------------
-- 1. simple divide
-- ufix_average <= a_ch_buffer(0)/to_ufixed(C_IMG_HEIGHT*C_IMG_WIDTH, 8, 0);
--
-- 2. divide with round properties (round, guard bits)
-- ufix_average <= divide(a_ch_buffer(0), to_ufixed(C_IMG_HEIGHT*C_IMG_WIDTH, 8, 0), fixed_truncate, 0)
--
-- 3. multiply with reciprocal -> best for timing and ressource usage!
-- ufix_average <= a_ch_buffer(0) * C_RECIPROCAL;
-----------------------------------------------------------------

if (int_data_out_cnt /= 0) then
assert isl_valid = '0' severity failure;
int_data_out_cnt <= int_data_out_cnt - 1;
sl_calculate_average <= '1';
end if;
sl_calculate_average_d1 <= sl_calculate_average;
sl_calculate_average_d2 <= sl_calculate_average_d1;
sl_output_valid <= sl_calculate_average_d2;

if (sl_calculate_average = '1') then
if (sl_full_image = '1') then
state <= CALCULATE_AVERAGE;
int_open_averages <= C_CHANNEL;
end if;

when CALCULATE_AVERAGE =>
------------------------DIVIDE OPTIONS---------------------------
-- 1. simple divide
-- ufix_average <= a_ch_buffer(0)/to_ufixed(C_IMG_HEIGHT*C_IMG_WIDTH, 8, 0);
--
-- 2. divide with round properties (round, guard bits)
-- ufix_average <= divide(a_ch_buffer(0), to_ufixed(C_IMG_HEIGHT*C_IMG_WIDTH, 8, 0), fixed_truncate, 0)
--
-- 3. multiply with reciprocal -> best for timing and ressource usage!
-- ufix_average <= a_ch_buffer(0) * C_RECIPROCAL;
-----------------------------------------------------------------
a_ch_buffer <= a_ch_buffer(a_ch_buffer'high) & a_ch_buffer(0 to a_ch_buffer'high - 1);
ufix_average <= to_ufixed(a_ch_buffer(a_ch_buffer'high), C_INTW_SUM - 1, 0) * C_RECIPROCAL;
end if;

if (sl_calculate_average_d1 = '1') then
int_open_averages <= int_open_averages - 1;
state <= PIPELINE_AVERAGE;

when PIPELINE_AVERAGE =>
ufix_average_d1 <= ufix_average;
end if;
state <= OUTPUT_AVERAGE;

when OUTPUT_AVERAGE =>
slv_average <= to_slv(resize(ufix_average_d1, C_BITWIDTH - 1, 0, fixed_wrap, fixed_round));
sl_output_valid <= '1';

if (int_open_averages /= 0) then
state <= CALCULATE_AVERAGE;
else
state <= CLEAR_BUFFER;
end if;

end case;

if (sl_calculate_average_d2 = '1') then
slv_average <= to_slv(resize(ufix_average_d1, C_BITWIDTH - 1, 0, fixed_wrap, fixed_round));
end if;
end if;
end if;

end process proc_average_pooling;
Expand Down

0 comments on commit 309a6a6

Please sign in to comment.