-- Listing metadata: 149 lines, 4.6 KiB, VHDL
library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
use work.real_vector_pkg.all; -- defines my_real, my_real_vector, my_real_matrix, etc.
-- tile_unit
--   Computes one Nv x Nh output tile from an Nv x Nm chunk of M1 and an
--   Nh x Nm chunk of M2 (M2 is supplied in transposed view, i.e. one row of
--   m2_chunk per output column). The common dimension Nm is split across
--   Nd DPTC cores, each handling N consecutive columns; the architecture
--   slices columns d*N .. d*N+N-1 for core d, so N * Nd is expected to
--   equal Nm.
entity tile_unit is
    generic (
        N  : integer := 2;  -- columns of the common dimension handled per DPTC core
        Nv : integer := 2;  -- rows taken from M1 (= rows of the output tile)
        Nh : integer := 2;  -- columns taken from M2 (= columns of the output tile)
        Nm : integer := 4;  -- common dimension: columns of M1 / rows of M2
        Nd : integer := 2   -- number of DPTC cores instantiated
    );
    port (
        clk        : in  std_logic;                              -- rising-edge clock
        reset_n    : in  std_logic;                              -- active-low reset ('0' clears the result register)
        enable     : in  std_logic;                              -- forwarded unchanged to every DPTC core
        m1_chunk   : in  my_real_matrix(0 to Nv-1, 0 to Nm-1);   -- Nv x Nm block of M1
        m2_chunk   : in  my_real_matrix(0 to Nh-1, 0 to Nm-1);   -- Nh x Nm block of M2, transposed view
        out_valid  : out std_logic;                              -- '1' while result_out holds a freshly summed tile
        result_out : out real_matrix(0 to Nv-1, 0 to Nh-1)       -- Nv x Nh output tile
    );
end entity tile_unit;
-- Behavioral architecture of tile_unit.
--
-- Data flow:
--   1. slicer_proc (combinational) cuts the common dimension Nm into Nd
--      consecutive groups of N columns; group d of both m1_chunk and
--      m2_chunk is routed to DPTC core d.
--   2. dptc_gen instantiates Nd dptc cores that operate on their slices,
--      each producing an Nv x Nh partial result plus a valid flag.
--   3. sum_proc, on every rising clock edge where all Nd cores report
--      valid, registers the element-wise sum of their partial results and
--      sets out_valid to '1'; on any other edge out_valid returns to '0'.
architecture Behavioral of tile_unit is

    component dptc
        generic (
            Nv : integer := 2;
            Nh : integer := 2;
            N  : integer := 2
        );
        port (
            clk           : in  std_logic;
            reset_n       : in  std_logic;
            enable        : in  std_logic;
            x_matrix      : in  my_real_matrix(0 to Nv-1, 0 to N-1);  -- N columns of M1
            y_matrix      : in  my_real_matrix(0 to Nh-1, 0 to N-1);  -- N columns of M2 (transposed view)
            out_valid     : out std_logic;
            -- Per the original author: Result(i,j) += sum(x(i,k) * y(j,k)) over k;
            -- NOTE(review): confirm against the dptc entity itself.
            result_matrix : out real_matrix(0 to Nv-1, 0 to Nh-1)
        );
    end component;

    -- Per-core partial results and valid flags.
    type matrix_array is array (0 to Nd-1) of real_matrix(0 to Nv-1, 0 to Nh-1);
    signal dptc_outputs  : matrix_array;
    signal valid_signals : std_logic_vector(0 to Nd-1);

    -- Per-core input slices. M1 slices are Nv x N, M2 slices are Nh x N.
    -- They need distinct array types because Nv and Nh may differ; sharing
    -- one Nv-row type for both (as before) breaks whenever Nv /= Nh.
    type chunk_matrix_array    is array (0 to Nd-1) of my_real_matrix(0 to Nv-1, 0 to N-1);
    type m2_chunk_matrix_array is array (0 to Nd-1) of my_real_matrix(0 to Nh-1, 0 to N-1);
    signal m1_slices : chunk_matrix_array;
    signal m2_slices : m2_chunk_matrix_array;

    -- Registered tile result and its valid flag.
    signal sum_result_reg : real_matrix(0 to Nv-1, 0 to Nh-1);
    signal sum_valid_reg  : std_logic := '0';

begin

    -- Elaboration-time sanity check: the Nd cores of width N must cover the
    -- Nm common-dimension columns exactly. If N * Nd > Nm the slicing below
    -- indexes past m1_chunk/m2_chunk; if N * Nd < Nm trailing columns would
    -- be silently ignored and the tile result would be wrong.
    assert N * Nd = Nm
        report "tile_unit: generic mismatch, N * Nd must equal Nm"
        severity failure;

    -- Combinational slicing: core d receives columns d*N .. d*N+N-1 of both
    -- m1_chunk and m2_chunk (m2_chunk holds M2 in transposed view, so its
    -- rows correspond to output columns).
    slicer_proc : process (m1_chunk, m2_chunk)
    begin
        for d in 0 to Nd-1 loop
            -- M1 slice: Nv rows x N columns.
            for i in 0 to Nv-1 loop
                for k in 0 to N-1 loop
                    m1_slices(d)(i, k) <= m1_chunk(i, d*N + k);
                end loop;
            end loop;

            -- M2 slice: Nh rows x N columns.
            for j in 0 to Nh-1 loop
                for k in 0 to N-1 loop
                    m2_slices(d)(j, k) <= m2_chunk(j, d*N + k);
                end loop;
            end loop;
        end loop;
    end process slicer_proc;

    -- One dptc core per slice; all cores share clk, reset_n and enable.
    dptc_gen : for d in 0 to Nd-1 generate
        dptc_inst : dptc
            generic map (
                Nv => Nv,
                Nh => Nh,
                N  => N
            )
            port map (
                clk           => clk,
                reset_n       => reset_n,
                enable        => enable,
                x_matrix      => m1_slices(d),
                y_matrix      => m2_slices(d),
                out_valid     => valid_signals(d),
                result_matrix => dptc_outputs(d)
            );
    end generate dptc_gen;

    -- Element-wise sum of the Nd partial results.
    -- Asynchronous active-low reset clears the result and the valid flag.
    -- On a rising edge, the sum is registered only when every core's valid
    -- flag is '1'; the registered result then appears on result_out one
    -- cycle later, together with out_valid = '1'.
    sum_proc : process (clk, reset_n)
        variable acc       : real_matrix(0 to Nv-1, 0 to Nh-1);
        variable all_valid : boolean;
    begin
        if reset_n = '0' then
            sum_result_reg <= (others => (others => ZERO_REAL));
            sum_valid_reg  <= '0';

        elsif rising_edge(clk) then
            -- All cores must report valid in the same cycle.
            all_valid := true;
            for d in 0 to Nd-1 loop
                if valid_signals(d) /= '1' then
                    all_valid := false;
                end if;
            end loop;

            if all_valid then
                acc := (others => (others => ZERO_REAL));
                for d in 0 to Nd-1 loop
                    for i in 0 to Nv-1 loop
                        for j in 0 to Nh-1 loop
                            acc(i, j) := acc(i, j) + dptc_outputs(d)(i, j);
                        end loop;
                    end loop;
                end loop;

                sum_result_reg <= acc;
                sum_valid_reg  <= '1';
            else
                sum_valid_reg <= '0';
            end if;
        end if;
    end process sum_proc;

    -- Outputs are driven straight from the registers.
    result_out <= sum_result_reg;
    out_valid  <= sum_valid_reg;

end architecture Behavioral;