-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
324 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from dataclasses import dataclass | ||
import math | ||
import pathlib | ||
from random import randint | ||
from typing import List | ||
|
||
import cocotb | ||
from cocotb.clock import Clock | ||
from cocotb.triggers import Timer | ||
from cocotb_test.simulator import run | ||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from test_utils.cocotb_helpers import ImageMonitor, Tick | ||
from test_utils.general import concatenate_channel, get_files | ||
|
||
|
||
@cocotb.test()
async def run_test(dut):
    """Stream test images into the average pooling DUT and check the result.

    The layer parameters (bitwidth, image width/height, channel count) are
    read from the generics of ``dut``. Each testcase is sent to the design one
    pixel (i. e. all channels) per valid cycle. The values collected by the
    output monitor are compared against a Keras ``GlobalAveragePooling2D``
    reference model with an absolute tolerance of 1.
    """
    # layer parameter
    bitwidth = dut.C_BITWIDTH.value.integer
    image_shape = (
        dut.C_IMG_WIDTH.value.integer,
        dut.C_IMG_HEIGHT.value.integer,
        dut.C_CHANNEL.value.integer,
    )
    # total number of values (width * height * channel) of one image
    value_count = math.prod(image_shape)
    max_value = 2 ** bitwidth - 1

    # define the reference model
    batch_shape = (1,) + image_shape
    input_ = tf.keras.Input(batch_shape=batch_shape, name="img")
    output_ = tf.keras.layers.GlobalAveragePooling2D()(input_)
    model = tf.keras.Model(inputs=input_, outputs=output_)

    # define the testcases
    @dataclass
    class Testcase:
        # flat list of all values of one image
        input_image: List[int]

        @property
        def input_data(self) -> List[int]:
            """Return the stimuli that are driven onto ``islv_data``."""
            # send all channels (i. e. one pixel) at a time
            # NOTE(review): the exact return type of concatenate_channel is
            # not visible here; it is iterated below, one item per valid cycle.
            return concatenate_channel(self.input_image, image_shape[2], bitwidth)

        @property
        def output_data(self) -> List[int]:
            """Return the expected values, calculated by the reference model."""
            # inference
            image_tensor = tf.convert_to_tensor(self.input_image)
            result = model(tf.reshape(image_tensor, batch_shape))
            # TODO: Not bit accurate in corner cases.
            # NOTE: the cast to uint8 limits the reference to values <= 255.
            result_list = list(np.rint(result.numpy()).astype("uint8").flat)
            return result_list

    cases = (
        # all zeros
        Testcase([0] * value_count),
        # all ones
        Testcase([max_value] * value_count),
        # mixed
        Testcase([randint(0, max_value) for _ in range(value_count)]),
    )

    # prepare coroutines
    clock_period = 10  # ns
    tick = Tick(clock_period=clock_period)
    cocotb.fork(Clock(dut.isl_clk, clock_period, units="ns").start())
    output_mon = ImageMonitor(
        "output",
        dut.oslv_data,
        dut.osl_valid,
        dut.isl_clk,
        1,
        bitwidth * image_shape[2],
    )
    dut.isl_valid <= 0
    dut.isl_start <= 0
    await tick.wait()

    # run the specific testcases
    for case in cases:
        # a start pulse announces a new image
        dut.isl_start <= 1
        await tick.wait()
        dut.isl_start <= 0
        await tick.wait()

        # one valid cycle followed by one idle cycle for each input datum
        for datum in case.input_data:
            dut.isl_valid <= 1
            dut.islv_data <= datum
            await tick.wait()
            dut.isl_valid <= 0
            await tick.wait()

        # give the design some time to output all averages
        dut.isl_valid <= 0
        await tick.wait_multiple(40)

        # run the inference only once per testcase
        expected = case.output_data
        actual = output_mon.output
        print("Expected output:", expected)
        print("Actual output:", actual)
        # zip() stops at the shorter sequence. Without the length check,
        # missing output values would let the comparison pass vacuously.
        assert len(actual) == len(expected), (
            f"Expected {len(expected)} output values, got {len(actual)}"
        )
        assert all(
            math.isclose(act, exp, abs_tol=1) for act, exp in zip(actual, expected)
        )
        output_mon.clear()
|
||
|
||
def test_average_pooling():
    """Simulate the ``average_pooling`` toplevel with a fixed set of generics.

    All ``*.vhd`` files of the sibling ``src`` directory are handed to the
    simulator and this module is used as cocotb test module.
    """
    source_dir = pathlib.Path(__file__).parent.absolute() / ".." / "src"
    vhdl_files = get_files(source_dir, "*.vhd")
    run(
        vhdl_sources=vhdl_files,
        toplevel="average_pooling",
        module="test_average_pooling",
        compile_args=["--work=cnn_lib", "--std=08"],
        parameters={
            "C_BITWIDTH": 8,
            "C_CHANNEL": 6,
            "C_IMG_WIDTH": 6,
            "C_IMG_HEIGHT": 6,
        },
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
|
||
library ieee; | ||
use ieee.std_logic_1164.all; | ||
use ieee.numeric_std.all; | ||
use ieee.fixed_pkg.all; | ||
use ieee.fixed_float_types.all; | ||
|
||
library util; | ||
use util.array_pkg.all; | ||
use util.math_pkg.all; | ||
|
||
-- Average pooling over a whole image: one average value per channel.
-- The input is received pixel by pixel (all channels of a pixel in parallel),
-- the averages are output one channel at a time.
entity average_pooling is
  generic (
    -- bitwidth of a single channel value at input and output
    C_BITWIDTH : integer := 8;

    -- number of channels per pixel and size of the image
    C_CHANNEL    : integer range 1 to 512 := 4;
    C_IMG_WIDTH  : integer range 1 to 512 := 6;
    C_IMG_HEIGHT : integer range 1 to 512 := 6
  );
  port (
    -- clock, all registers are updated at the rising edge
    isl_clk   : in    std_logic;
    -- clears the channel sums and the input pixel counter
    isl_start : in    std_logic;
    -- islv_data is added to the channel sums while this input is '1'
    isl_valid : in    std_logic;
    -- one pixel: C_CHANNEL slices of C_BITWIDTH bits each
    islv_data : in    std_logic_vector(C_CHANNEL * C_BITWIDTH - 1 downto 0);
    -- average of one channel
    oslv_data : out   std_logic_vector(C_BITWIDTH - 1 downto 0);
    -- marks oslv_data as valid
    osl_valid : out   std_logic
  );
end entity average_pooling;
|
||
architecture behavioral of average_pooling is

  -- temporary higher int width to prevent overflow while summing up channel/pixel
  -- new bitwidth = log2(C_IMG_HEIGHT*C_IMG_WIDTH*(2^old bitwidth)) = log2(C_IMG_HEIGHT*C_IMG_WIDTH) + old bitwidth
  -- e. g. 16 pixel with 8 bit each -> new bw = log2(16*(2^8)) = 12
  constant C_INTW_SUM   : integer := C_BITWIDTH + log2(C_IMG_HEIGHT * C_IMG_WIDTH + 1);
  -- fractional bits of the reciprocal
  constant C_FRACW_REZI : integer range 1 to 16 := 16;

  -- Pipeline control: sl_calculate_average is high for one cycle per channel.
  -- The delayed versions enable the following pipeline stages.
  signal sl_calculate_average    : std_logic := '0';
  signal sl_calculate_average_d1 : std_logic := '0';
  signal sl_calculate_average_d2 : std_logic := '0';

  -- fixed point multiplication yields: A'left + B'left + 1 downto A'right + B'right
  -- here: (C_INTW_SUM - 1) + 0 + 1 downto 0 + (-C_FRACW_REZI)
  signal ufix_average : ufixed(C_INTW_SUM downto - C_FRACW_REZI) := (others => '0');
  attribute use_dsp : string;
  attribute use_dsp of ufix_average : signal is "yes";
  signal ufix_average_d1 : ufixed(C_INTW_SUM downto - C_FRACW_REZI) := (others => '0');

  -- TODO: try real instead of ufixed
  -- The pixel count is converted to ufixed(C_FRACW_REZI - 1 downto 0). The reciprocal of
  -- ufixed(a downto b) has the range (-b downto -a - 1), i. e. (0 downto -C_FRACW_REZI).
  -- NOTE(review): the generic ranges allow up to 512 * 512 pixel, which doesn't fit
  -- in C_FRACW_REZI integer bits. Confirm the supported maximum image size.
  constant C_RECIPROCAL : ufixed(0 downto - C_FRACW_REZI) := reciprocal(to_ufixed(C_IMG_HEIGHT * C_IMG_WIDTH, C_FRACW_REZI - 1, 0));
  signal slv_average : std_logic_vector(C_BITWIDTH - 1 downto 0) := (others => '0');

  -- position of the current input pixel in the image
  signal int_data_in_cnt  : integer range 0 to C_IMG_WIDTH * C_IMG_HEIGHT - 1 := 0;
  -- number of channels, which still have to be averaged
  signal int_data_out_cnt : integer range 0 to C_CHANNEL := 0;

  type t_1d_array is array (natural range <>) of unsigned(C_INTW_SUM - 1 downto 0);

  -- one sum for each channel
  signal a_ch_buffer : t_1d_array(0 to C_CHANNEL - 1) := (others => (others => '0'));

  signal sl_output_valid : std_logic := '0';

begin

  -------------------------------------------------------
  -- Process: Average Pooling (average of each channel)
  -- Stage 1: sum up the values of every channel
  -- Stage 2*: multiply with reciprocal
  -- Stage 3: pipeline DSP output
  -- Stage 4: resize output
  -- *Stage 2 is entered when the last pixel of the image
  --  (pixel number C_IMG_HEIGHT*C_IMG_WIDTH) was received
  -------------------------------------------------------
  proc_average_pooling : process (isl_clk) is
  begin

    if (rising_edge(isl_clk)) then
      -- default, overwritten below while channels are left to average
      sl_calculate_average <= '0';

      if (isl_start = '1') then
        -- prepare for a new image
        a_ch_buffer     <= (others => (others => '0'));
        int_data_in_cnt <= 0;
      else
        if (isl_valid = '1') then
          if (int_data_in_cnt = C_IMG_HEIGHT * C_IMG_WIDTH - 1) then
            -- last pixel of the image: start to output the averages
            int_data_in_cnt  <= 0;
            int_data_out_cnt <= C_CHANNEL;
          else
            int_data_in_cnt <= int_data_in_cnt + 1;
          end if;

          -- add the value of each channel to the corresponding sum
          for ch in 0 to C_CHANNEL - 1 loop
            a_ch_buffer(ch) <= resize(
                                      a_ch_buffer(ch) +
                                      unsigned(get_slice(islv_data, ch, C_BITWIDTH)),
                                      a_ch_buffer(0)'length);
          end loop;
        end if;

        ------------------------DIVIDE OPTIONS---------------------------
        -- 1. simple divide
        -- ufix_average <= a_ch_buffer(0)/to_ufixed(C_IMG_HEIGHT*C_IMG_WIDTH, 8, 0);
        --
        -- 2. divide with round properties (round, guard bits)
        -- ufix_average <= divide(a_ch_buffer(0), to_ufixed(C_IMG_HEIGHT*C_IMG_WIDTH, 8, 0), fixed_truncate, 0)
        --
        -- 3. multiply with reciprocal -> best for timing and resource usage!
        -- ufix_average <= a_ch_buffer(0) * C_RECIPROCAL;
        -----------------------------------------------------------------

        if (int_data_out_cnt /= 0) then
          -- No new input data is allowed until all channels are processed,
          -- because the sums are still needed.
          assert isl_valid = '0'
            report "average_pooling: isl_valid is high while the channel averages are still calculated"
            severity failure;
          int_data_out_cnt     <= int_data_out_cnt - 1;
          sl_calculate_average <= '1';
        end if;
        sl_calculate_average_d1 <= sl_calculate_average;
        sl_calculate_average_d2 <= sl_calculate_average_d1;
        sl_output_valid         <= sl_calculate_average_d2;

        if (sl_calculate_average = '1') then
          -- Rotate the sums, so the highest index contains the next channel.
          -- This assignment overrides the accumulation above in the same cycle.
          a_ch_buffer  <= a_ch_buffer(a_ch_buffer'high) & a_ch_buffer(0 to a_ch_buffer'high - 1);
          ufix_average <= to_ufixed(a_ch_buffer(a_ch_buffer'high), C_INTW_SUM - 1, 0) * C_RECIPROCAL;
        end if;

        if (sl_calculate_average_d1 = '1') then
          ufix_average_d1 <= ufix_average;
        end if;

        if (sl_calculate_average_d2 = '1') then
          -- round to an integer of C_BITWIDTH bits
          slv_average <= to_slv(resize(ufix_average_d1, C_BITWIDTH - 1, 0, fixed_wrap, fixed_round));
        end if;
      end if;
    end if;

  end process proc_average_pooling;

  oslv_data <= slv_average;
  osl_valid <= sl_output_valid;

end architecture behavioral;