125 lines
3.8 KiB
VHDL
125 lines
3.8 KiB
VHDL
|
-- hu_dp
|
||
|
-- Data path for the H update, driven by a stream of weights
|
||
|
|
||
|
use work.pkg_sbs.all;
|
||
|
|
||
|
entity hu_dp is
  port (
    -- Clock and active-low reset
    clk    : in  bit;
    rstn   : in  bit;
    -- Control word from the control path (bit meanings decoded in the architecture)
    ctr_hu : in  bit_vector(BW_HU_CTR-1 downto 0);
    -- Current location (address) in H, reported back to the control path
    loc_h  : out bit_vector(ADDR_H_MAX-1 downto 0);
    -- Epsilon parameter of the update
    eps    : in  real;
    -- Input stream of weights
    wi     : in  real;
    -- Input stream of states h
    hi     : in  real;
    -- Output stream of states h
    ho     : out real
  );
end entity hu_dp;
|
||
|
|
||
|
library ieee;
|
||
|
use ieee.numeric_bit.all;
|
||
|
|
||
|
architecture rtlf of hu_dp is

  -- Memories (array_as_h is declared in work.pkg_sbs)
  signal mem_hp : array_as_h;  -- State hp (internal, un-normalized h)
  signal mem_hw : array_as_h;  -- Copy of the products w*h

  -- Address register (addr_wr) and its next value (addr_nxt)
  signal addr_wr, addr_nxt : bit_vector(ADDR_H_MAX-1 downto 0);

  -- Data path for hp (i.e. h un-normalized) and hw (hp*w)
  signal hp_new, hp_new_rg, hp_p, h_eff : real := 0.0;
  signal hw_p, hw_nxt                   : real := 0.0;

  -- Accumulators for normalization
  signal sum_hw, sum_hw_nxt     : real := 0.0;  -- Running sum of hw
  signal sum_hw_p, sum_hw_p_nxt : real := 0.0;  -- Saved copy of sum_hw (cleared / loaded by the update strobes)
  signal sum_hp, sum_hp_nxt     : real := 0.0;  -- Running sum of hp
  signal sum_hp_p, sum_hp_p_nxt : real := 0.0;  -- Saved hp-side coefficient (loaded with eps, then with hp_new)

  -- Control signals decoded from ctr_hu
  signal ctr_sel_ini, ctr_sum_ini, ctr_update_sum, ctr_update_sum2 : bit;
  signal ctr_addr_rst, ctr_addr_inc                                : bit;
  signal ctr_wr_hw, ctr_wr_hp                                      : bit;

begin  -- architecture rtlf

  -- Get control signals.
  -- Bit 5 of ctr_hu is not used by this data path; both memory write
  -- enables are taken from bit 6.
  -- NOTE(review): confirm with the control path that bit 5 is not meant
  -- to be a separate write enable for mem_hw.
  ctr_sel_ini     <= ctr_hu(0);  -- select hi instead of hp_new_rg
  ctr_sum_ini     <= ctr_hu(1);  -- restart the running sums with the current term
  ctr_update_sum  <= ctr_hu(2);  -- load 0.0 / eps into the saved coefficients
  ctr_addr_rst    <= ctr_hu(3);  -- reset the address to 0
  ctr_addr_inc    <= ctr_hu(4);  -- increment the address
  ctr_wr_hw       <= ctr_hu(6);  -- write enable for mem_hw
  ctr_wr_hp       <= ctr_hu(6);  -- write enable for mem_hp
  ctr_update_sum2 <= ctr_hu(7);  -- load sum_hw / hp_new into the saved coefficients

  -- Main calculation. The same multiply-add is reused while
  -- ctr_update_sum2='1', when hw_p carries sum_hp instead of a memory word.
  hp_new <= hp_p * sum_hw_p + sum_hp_p * hw_p;

  -- Mux to select the first h (input stream) or the registered update
  h_eff <= hi when ctr_sel_ini = '1' else hp_new_rg;

  -- Calculate hw
  hw_nxt <= h_eff * wi;

  -- Output h (note latency of a complete group)
  ho <= h_eff;

  -- Accumulate hw and hp; ctr_sum_ini restarts each sum with the current term
  sum_hw_nxt <= hw_nxt when ctr_sum_ini = '1' else sum_hw + hw_nxt;
  sum_hp_nxt <= h_eff  when ctr_sum_ini = '1' else sum_hp + h_eff;

  -- Saved coefficients used by hp_new (ctr_update_sum has priority):
  --   ctr_update_sum  : sum_hw_p <= 0.0,    sum_hp_p <= eps
  --   ctr_update_sum2 : sum_hw_p <= sum_hw, sum_hp_p <= hp_new
  -- Presumably ctr_update_sum2 is issued in the cycle after ctr_update_sum,
  -- so that hp_new evaluates to eps * sum_hp at that point (confirm with
  -- the control path).
  sum_hw_p_nxt <= 0.0    when ctr_update_sum  = '1' else
                  sum_hw when ctr_update_sum2 = '1' else
                  sum_hw_p;
  sum_hp_p_nxt <= eps    when ctr_update_sum  = '1' else
                  hp_new when ctr_update_sum2 = '1' else
                  sum_hp_p;

  -- Read from memory: combinational read at the next address (addr_nxt).
  -- While ctr_update_sum2='1', sum_hp is fed into the multiplier instead.
  hw_p <= mem_hw(to_integer(unsigned(addr_nxt))) when ctr_update_sum2 = '0' else sum_hp;
  hp_p <= mem_hp(to_integer(unsigned(addr_nxt)));

  -- Address calculation (numeric_bit "+" keeps the operand width, so the
  -- increment wraps around)
  addr_nxt <= (others => '0')                   when ctr_addr_rst = '1' else
              bit_vector(unsigned(addr_wr) + 1) when ctr_addr_inc = '1' else
              addr_wr;
  loc_h <= addr_wr;  -- Output for ctrl path

  -- Registers, asynchronous active-low reset
  rg : process (clk, rstn) is
  begin  -- process rg
    if rstn = '0' then
      hp_new_rg <= 0.0;
      sum_hw    <= 0.0;
      sum_hw_p  <= 0.0;
      sum_hp    <= 0.0;
      sum_hp_p  <= 0.0;
      addr_wr   <= (others => '0');
    elsif clk'event and clk = '1' then
      hp_new_rg <= hp_new;
      sum_hw    <= sum_hw_nxt;
      sum_hw_p  <= sum_hw_p_nxt;
      sum_hp    <= sum_hp_nxt;
      sum_hp_p  <= sum_hp_p_nxt;
      addr_wr   <= addr_nxt;
    end if;
  end process rg;

  -- Memory: synchronous write at the current address addr_wr (no reset)
  mem : process (clk) is
  begin  -- process mem
    if clk'event and clk = '1' then  -- rising clock edge
      if ctr_wr_hw = '1' then
        mem_hw(to_integer(unsigned(addr_wr))) <= hw_nxt;
      end if;
      if ctr_wr_hp = '1' then
        mem_hp(to_integer(unsigned(addr_wr))) <= h_eff;
      end if;
    end if;
  end process mem;

end architecture rtlf;
|