-- File: wk_sbs_hdl/hw/rtl/hu_dp.vhd

-- hu_dp
-- Data path for Update H using stream of weights
-- Trivial fixed-point implementation
library ieee;
use ieee.std_logic_1164.all;
use work.pkg_sbs.all;
entity hu_dp is
  generic (
    -- Additional bits for the sums. Not referenced by architecture rtl.
    K : natural := 3;
    -- Bit width of the input/output data words.
    B : natural := 10
  );
  port (
    -- Clock; registers and memories update on its rising edge.
    clk    : in  std_logic;
    -- Asynchronous reset, active low (clears the data-path registers).
    rstn   : in  std_logic;
    -- Control word for the data path; individual bits are decoded in the
    -- architecture.
    ctr_hu : in  std_logic_vector(BW_HU_CTR-1 downto 0);
    -- Current location in H (registered address), for the control path.
    loc_h  : out std_logic_vector(ADDR_H_MAX-1 downto 0);
    -- Epsilon operand.
    eps    : in  std_logic_vector(B-1 downto 0);
    -- Stream of weights.
    wi     : in  std_logic_vector(B-1 downto 0);
    -- Stream of incoming state values.
    hi     : in  std_logic_vector(B-1 downto 0);
    -- Stream of outgoing state values.
    ho     : out std_logic_vector(B-1 downto 0)
  );
end entity hu_dp;
library ieee;
use ieee.numeric_std.all;
architecture rtl of hu_dp is

  -- Memory: two arrays of N_H_MAX words, each B bits wide.
  subtype word is std_logic_vector(B-1 downto 0);
  type array_as_h_w is array (N_H_MAX-1 downto 0) of word;
  signal mem_hp : array_as_h_w;  -- State (internal): written with h_eff
  signal mem_hw : array_as_h_w;  -- Copy of w*h: written with upper half of hw_nxt

  -- addr_wr is the registered address (write address, also driven to loc_h);
  -- addr_nxt is its next value and is also used as the read address.
  signal addr_wr, addr_nxt : std_logic_vector(ADDR_H_MAX-1 downto 0);

  -- Data path for hp (i.e. h un-normalized) and hw (hp*w)
  signal hp_new, hw_nxt         : unsigned(2*B-1 downto 0);
  signal hp_new_rg, hp_p, h_eff : std_logic_vector(B-1 downto 0);
  signal hw_p                   : std_logic_vector(B-1 downto 0);

  -- Accumulators for normalization.
  -- NOTE(review): all accumulators are B bits wide, so the additions below
  -- wrap modulo 2**B; generic K ("additional bits for sum") is not used here.
  -- Confirm that the missing headroom is intended.
  signal sum_hw, sum_hw_nxt     : std_logic_vector(B-1 downto 0);  -- Running sum of hw
  signal sum_hw_p, sum_hw_p_nxt : std_logic_vector(B-1 downto 0);  -- Saved copy of sum_hw
  signal sum_hp, sum_hp_nxt     : std_logic_vector(B-1 downto 0);  -- Running sum of h_eff
  -- Held operand for the hp update: loaded with eps, or with the upper half
  -- of hp_new (see sum_hp_p_nxt below).
  signal sum_hp_p, sum_hp_p_nxt : std_logic_vector(B-1 downto 0);

  -- Control signals decoded from ctr_hu
  signal ctr_sel_ini, ctr_sum_ini, ctr_update_sum, ctr_update_sum2 : std_logic;
  signal ctr_addr_rst, ctr_addr_inc                                : std_logic;
  signal ctr_wr_hw, ctr_wr_hp                                      : std_logic;

begin  -- architecture rtl

  -- Get control signals.
  -- ctr_hu(5) is not decoded: it previously drove a signal (ctr_write_hw)
  -- that nothing in this architecture read.
  ctr_sel_ini     <= ctr_hu(0);
  ctr_sum_ini     <= ctr_hu(1);
  ctr_update_sum  <= ctr_hu(2);
  ctr_addr_rst    <= ctr_hu(3);
  ctr_addr_inc    <= ctr_hu(4);
  -- NOTE(review): both write enables are taken from bit 6 while bit 5 is
  -- unused. Confirm against the control path that this sharing is intended.
  ctr_wr_hw       <= ctr_hu(6);
  ctr_wr_hp       <= ctr_hu(6);
  ctr_update_sum2 <= ctr_hu(7);

  -- Main calculation: hp*sum_hw_p + sum_hp_p*hw.
  -- Each product is 2*B bits and the sum is kept in 2*B bits, so a carry out
  -- of the addition is dropped.
  hp_new <= unsigned(hp_p) * unsigned(sum_hw_p) + unsigned(sum_hp_p) * unsigned(hw_p);

  -- Mux to select first h (input stream) or the saved, registered update
  h_eff <= hi when ctr_sel_ini = '1' else hp_new_rg;

  -- Calculate hw
  hw_nxt <= unsigned(h_eff) * unsigned(wi);

  -- Output h (note latency of a complete group)
  ho <= h_eff;

  -- Accumulate hw: ctr_sum_ini restarts the sum with the current term,
  -- otherwise the upper half of the product is added to the running sum.
  sum_hw_nxt <= std_logic_vector(hw_nxt(2*B-1 downto B)) when ctr_sum_ini = '1' else
                std_logic_vector(unsigned(sum_hw) + hw_nxt(2*B-1 downto B));

  -- Saved sum of hw: cleared by ctr_update_sum, loaded from sum_hw by
  -- ctr_update_sum2, held otherwise.
  sum_hw_p_nxt <= (others => '0') when ctr_update_sum = '1' else
                  sum_hw          when ctr_update_sum2 = '1' else
                  sum_hw_p;

  -- Accumulate h, with the same restart behaviour as sum_hw
  sum_hp_nxt <= h_eff when ctr_sum_ini = '1' else
                std_logic_vector(unsigned(sum_hp) + unsigned(h_eff));

  -- Earlier variant, kept for reference:
  --sum_hp_p_nxt <= sum_hp_nxt * eps when ctr_update_sum='1' else sum_hp_p;
  -- Loaded with eps by ctr_update_sum, with the upper half of hp_new by
  -- ctr_update_sum2, held otherwise.
  sum_hp_p_nxt <= eps                                      when ctr_update_sum = '1' else
                  std_logic_vector(hp_new(2*B-1 downto B)) when ctr_update_sum2 = '1' else
                  sum_hp_p;

  -- Read from memory (combinational read at the next address).
  -- While ctr_update_sum2 is set, sum_hp is fed to the multiplier instead of
  -- the stored hw word.
  --hw_p <= mem_hw(to_integer(unsigned(addr_nxt)));
  hw_p <= mem_hw(to_integer(unsigned(addr_nxt))) when ctr_update_sum2 = '0' else
          sum_hp;
  hp_p <= mem_hp(to_integer(unsigned(addr_nxt)));

  -- Address calculation: reset has priority over increment.
  -- NOTE(review): addr_nxt indexes arrays of N_H_MAX entries; this assumes
  -- every reachable address is below N_H_MAX - confirm against pkg_sbs.
  addr_nxt <= (others => '0')                         when ctr_addr_rst = '1' else
              std_logic_vector(unsigned(addr_wr) + 1) when ctr_addr_inc = '1' else
              addr_wr;

  loc_h <= addr_wr;  -- Output for ctrl path

  -- Registers: asynchronous active-low reset, update on the rising clock edge
  rg : process (clk, rstn) is
  begin  -- process rg
    if rstn = '0' then
      hp_new_rg <= (others => '0');
      sum_hw    <= (others => '0');
      sum_hw_p  <= (others => '0');
      sum_hp    <= (others => '0');
      sum_hp_p  <= (others => '0');
      addr_wr   <= (others => '0');
    elsif rising_edge(clk) then
      -- Only the upper B bits of the 2*B-bit result are kept
      hp_new_rg <= std_logic_vector(hp_new(2*B-1 downto B));
      sum_hw    <= sum_hw_nxt;
      sum_hw_p  <= sum_hw_p_nxt;
      sum_hp    <= sum_hp_nxt;
      sum_hp_p  <= sum_hp_p_nxt;
      addr_wr   <= addr_nxt;
    end if;
  end process rg;

  -- Memory: synchronous writes at addr_wr, no reset
  mem : process (clk) is
  begin  -- process mem
    if rising_edge(clk) then
      if ctr_wr_hw = '1' then
        mem_hw(to_integer(unsigned(addr_wr))) <= std_logic_vector(hw_nxt(2*B-1 downto B));
      end if;
      if ctr_wr_hp = '1' then
        mem_hp(to_integer(unsigned(addr_wr))) <= h_eff;
      end if;
    end if;
  end process mem;

end architecture rtl;