Lightening_Transformer/tile.vhd
2025-06-18 13:05:14 +02:00

149 lines
4.6 KiB
VHDL
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

library ieee;
use ieee.std_logic_1164.all;
use ieee.numeric_std.all;
use work.real_vector_pkg.all; -- defines my_real, my_real_vector, my_real_matrix, etc.
-- tile_unit: distributes the shared (Nm) dimension of an M1 chunk and a
-- transposed M2 chunk across Nd DPTC cores, N columns per core, and reports
-- the element-wise sum of the cores' Nv x Nh partial results.
entity tile_unit is
    generic (
        N  : integer := 2;  -- columns of the shared dimension handed to each DPTC core
        Nv : integer := 2;  -- rows of the M1 chunk (= rows of the result tile)
        Nh : integer := 2;  -- rows of the transposed M2 chunk (= columns of the result tile)
        Nm : integer := 4;  -- shared inner dimension; the cores split it as Nd * N
        Nd : integer := 2   -- number of DPTC cores instantiated
    );
    port (
        clk        : in  std_logic;
        reset_n    : in  std_logic;                              -- active-low reset
        enable     : in  std_logic;                              -- forwarded to every DPTC core
        m1_chunk   : in  my_real_matrix(0 to Nv-1, 0 to Nm-1);  -- Nv x Nm piece of M1
        m2_chunk   : in  my_real_matrix(0 to Nh-1, 0 to Nm-1);  -- Nh x Nm piece of M2, stored transposed
        out_valid  : out std_logic;                              -- '1' when result_out holds a newly computed sum
        result_out : out real_matrix(0 to Nv-1, 0 to Nh-1)      -- Nv x Nh result tile
    );
end entity tile_unit;
architecture Behavioral of tile_unit is
    -- Dot-product tensor core; its implementation is not in this file, only
    -- the interface below is relied on.
    component dptc
        generic (
            Nv : integer := 2;
            Nh : integer := 2;
            N  : integer := 2
        );
        port (
            clk           : in  std_logic;
            reset_n       : in  std_logic;
            enable        : in  std_logic;
            x_matrix      : in  my_real_matrix(0 to Nv-1, 0 to N-1); -- N columns taken from the M1 chunk
            y_matrix      : in  my_real_matrix(0 to Nh-1, 0 to N-1); -- N columns taken from the (transposed) M2 chunk
            out_valid     : out std_logic;
            result_matrix : out real_matrix(0 to Nv-1, 0 to Nh-1)    -- per the original note: result(i,j) += sum over k of x(i,k) * y(j,k)
        );
    end component;

    -- Per-core partial results and their valid flags.
    type matrix_array is array (0 to Nd-1) of real_matrix(0 to Nv-1, 0 to Nh-1);
    signal dptc_outputs  : matrix_array;
    signal valid_signals : std_logic_vector(0 to Nd-1);

    -- Per-core input slices: core d receives columns d*N .. d*N+N-1 of each chunk.
    -- M1 slices are Nv x N but M2 slices are Nh x N, so they need distinct
    -- array types. Reusing the Nv x N type for M2 breaks whenever Nh /= Nv
    -- (out-of-range index in the slicer / size mismatch on y_matrix).
    type chunk_matrix_array    is array (0 to Nd-1) of my_real_matrix(0 to Nv-1, 0 to N-1);
    type m2_chunk_matrix_array is array (0 to Nd-1) of my_real_matrix(0 to Nh-1, 0 to N-1);
    signal m1_slices : chunk_matrix_array;
    signal m2_slices : m2_chunk_matrix_array;

    -- Registered element-wise sum of all core outputs, and its valid flag.
    signal sum_result_reg : real_matrix(0 to Nv-1, 0 to Nh-1);
    signal sum_valid_reg  : std_logic := '0';
begin
    -- The Nd cores each take N consecutive columns, so together they must
    -- cover exactly the Nm shared columns. With fewer, columns would be
    -- silently dropped from the dot product; with more, the slicer would
    -- index past the end of m1_chunk / m2_chunk.
    assert Nd * N = Nm
        report "tile_unit: generics must satisfy Nd * N = Nm"
        severity failure;

    -- Split both chunks column-wise into Nd slices of N columns each.
    slicer_proc : process (m1_chunk, m2_chunk)
    begin
        for d in 0 to Nd-1 loop
            -- M1: columns d*N .. d*N+N-1 of every row.
            for i in 0 to Nv-1 loop
                for k in 0 to N-1 loop
                    m1_slices(d)(i, k) <= m1_chunk(i, d*N + k);
                end loop;
            end loop;
            -- M2 is held transposed (one row per output column), so it is
            -- sliced along its second index exactly like M1.
            for j in 0 to Nh-1 loop
                for k in 0 to N-1 loop
                    m2_slices(d)(j, k) <= m2_chunk(j, d*N + k);
                end loop;
            end loop;
        end loop;
    end process;

    -- One DPTC core per slice pair.
    dptc_gen : for d in 0 to Nd-1 generate
        dptc_inst : dptc
            generic map (
                Nv => Nv,
                Nh => Nh,
                N  => N
            )
            port map (
                clk           => clk,
                reset_n       => reset_n,
                enable        => enable,
                x_matrix      => m1_slices(d),
                y_matrix      => m2_slices(d),
                out_valid     => valid_signals(d),
                result_matrix => dptc_outputs(d)
            );
    end generate;

    -- Sum the core outputs element-wise, only on a clock edge where every
    -- core reports valid. NOTE(review): this assumes all cores raise
    -- out_valid in the same cycle; confirm against dptc's valid protocol.
    sum_proc : process (clk, reset_n)
        variable acc       : real_matrix(0 to Nv-1, 0 to Nh-1);
        variable all_valid : std_logic;
    begin
        if reset_n = '0' then
            sum_result_reg <= (others => (others => ZERO_REAL));
            sum_valid_reg  <= '0';
        elsif rising_edge(clk) then
            all_valid := '1';
            for d in 0 to Nd-1 loop
                if valid_signals(d) /= '1' then
                    all_valid := '0';
                end if;
            end loop;

            if all_valid = '1' then
                acc := (others => (others => ZERO_REAL));
                for d in 0 to Nd-1 loop
                    for i in 0 to Nv-1 loop
                        for j in 0 to Nh-1 loop
                            acc(i, j) := acc(i, j) + dptc_outputs(d)(i, j);
                        end loop;
                    end loop;
                end loop;
                -- Registered: result_out / out_valid show this sum after
                -- this clock edge.
                sum_result_reg <= acc;
                sum_valid_reg  <= '1';
            else
                sum_valid_reg  <= '0';
            end if;
        end if;
    end process;

    -- Outputs
    result_out <= sum_result_reg;
    out_valid  <= sum_valid_reg;
end architecture;