// AYAKA_Transformer/rtl/hdpe_unit.v
//
// Source-listing metadata (not Verilog): 289 lines, 12 KiB, language: Verilog.

// matrix_multiplication_unit
//
// FSM-driven matrix-multiply engine. For each output element (i, j) it:
//   1. requests one packed row of A (read_full_row_A pulse, address row_addr_A/col_addr_A),
//   2. requests one packed row of the B operand (read_full_row_B pulse, address row_addr_B/col_addr_B;
//      the LOAD_COL_B comment suggests B is stored transposed - confirm against the memory model),
//   3. waits for the matching valid_mem_input_* strobe,
//   4. forms the dot product of the COLS_USED lanes of full_row_A and full_row_B
//      (lane n occupies bits [n*DATA_WIDTH +: DATA_WIDTH]),
//   5. writes the low DATA_WIDTH bits of the 2*DATA_WIDTH-bit result to
//      (i + rows_start_add_writing, j + cols_start_add_writing) with a one-cycle write_enable_out pulse.
// The run ends when the element with i == rows_size_reading_A and j == cols_size_reading_B has been
// written. `done` is then set and the FSM stays in DONE until rst.
//
// mode: 00 Output-Stationary, 01 Input-Stationary, 10 Weight-Stationary.
//       11 (or X/Z) is unsupported and keeps the FSM in IDLE.
//
// NOTE(review): the *_size_* inputs are compared with '==' against 0-based indices, so they look like
// inclusive "last index" values (0 => one row/column) - confirm with the driving testbench.
module matrix_multiplication_unit #(
    parameter DATA_WIDTH = 16,
    parameter ROWS       = 16,
    parameter COLS       = 32,
    parameter COLS_USED  = 4
) (
    input                             clk,
    input                             rst,
    input                             enable,
    input      [1:0]                  mode,                     // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary
    input      [DATA_WIDTH-1:0]       data_input_A,             // matrix A element (not used by the current FSM)
    input      [DATA_WIDTH-1:0]       data_input_B,             // matrix B element (not used by the current FSM)
    input      [DATA_WIDTH*COLS_USED-1:0] full_row_A,           // COLS_USED packed lanes of A, lane 0 in the LSBs
    input      [DATA_WIDTH*COLS_USED-1:0] full_row_B,           // COLS_USED packed lanes of B, lane 0 in the LSBs
    input                             valid_mem_input_A,        // read data for A is valid
    input                             valid_mem_input_B,        // read data for B is valid
    input      [$clog2(ROWS)-1:0]     rows_start_add_reading_A, // offset of A for reading from memory
    input      [$clog2(COLS)-1:0]     cols_start_add_reading_A,
    input      [$clog2(ROWS)-1:0]     rows_start_add_reading_B, // offset of B for reading from memory (author note "[+3]" - meaning unclear, confirm)
    input      [$clog2(COLS)-1:0]     cols_start_add_reading_B,
    input      [$clog2(ROWS)-1:0]     rows_start_add_writing,   // offset for writing to memory
    input      [$clog2(COLS)-1:0]     cols_start_add_writing,
    input      [$clog2(ROWS)-1:0]     rows_size_reading_A,      // size of A for reading from memory
    input      [$clog2(COLS)-1:0]     cols_size_reading_A,      // should be equal to rows of B
    input      [$clog2(ROWS)-1:0]     rows_size_reading_B,      // size of B for reading from memory
    input      [$clog2(COLS)-1:0]     cols_size_reading_B,
    output reg                        done,
    output reg                        read_full_row_A,
    output reg                        read_full_row_B,
    // Memory interface
    output reg [$clog2(ROWS)-1:0]     row_addr_A,
    output reg [$clog2(COLS)-1:0]     col_addr_A,
    output reg [$clog2(ROWS)-1:0]     row_addr_B,
    output reg [$clog2(COLS)-1:0]     col_addr_B,
    output reg [$clog2(ROWS)-1:0]     row_addr_out,
    output reg [$clog2(COLS)-1:0]     col_addr_out,
    output reg                        read_enable_A,            // held low: the FSM uses read_full_row_A as its read strobe
    output reg                        read_enable_B,            // held low: the FSM uses read_full_row_B as its read strobe
    output reg                        write_enable_out,
    output reg [DATA_WIDTH-1:0]       data_out
);
    // FSM state encoding (localparam: internal, not meant to be overridden at instantiation).
    localparam IDLE                = 0,
               LOAD_ROW_A          = 1,
               WAIT_VALID_ROW_A    = 2,
               LOAD_COL_B          = 3,
               WAIT_VALID_COL_B    = 4,
               COMPUTE             = 5,
               WAIT_OUT_COMPUTE    = 6,   // no transition targets this state (unreachable)
               WRITE_BACK          = 7,
               LOAD_NEW_ROW_OR_COL = 8,
               DONE                = 9,
               COMPUTE_ACC         = 10;  // no transition targets this state (unreachable)

    reg [3:0] current_state, next_state;

    // Output-element indices: i selects the row of A, j selects the B operand row.
    // NOTE(review): j is $clog2(ROWS) bits wide but is compared with cols_size_reading_B ($clog2(COLS) bits).
    // If cols_size_reading_B >= ROWS, `j == cols_size_reading_B` can never hold and the run never finishes.
    reg [$clog2(ROWS)-1:0] i, j;

    // Per-lane operand buffers; only read by the unreachable COMPUTE_ACC state and never written.
    reg [DATA_WIDTH-1:0] A_row [0:$clog2(COLS_USED)+1];
    reg [DATA_WIDTH-1:0] B_col [0:$clog2(COLS_USED)+1];
    reg [2*DATA_WIDTH-1:0] acc;

    // Load counters compared in modes 01/10.
    // NOTE(review): idx and idx2 are only ever cleared, never incremented. Modes 01/10 therefore only
    // leave their load loops when cols_size_reading_A == 0 (author note: "remove all idx").
    reg [$clog2(COLS_USED)+1:0] idx, idx2, idx3;

    //------------------------------------
    // Combinational dot product over all COLS_USED lanes
    //------------------------------------
    // Evaluated at 2*DATA_WIDTH bits (the width of dot_product), so each lane product is full
    // precision and the sum wraps modulo 2^(2*DATA_WIDTH), exactly like the original 4-term expression.
    reg [2*DATA_WIDTH-1:0] dot_product;
    integer lane;
    always @(*) begin
        dot_product = {(2*DATA_WIDTH){1'b0}};
        for (lane = 0; lane < COLS_USED; lane = lane + 1)
            dot_product = dot_product
                        + full_row_A[lane*DATA_WIDTH +: DATA_WIDTH]
                        * full_row_B[lane*DATA_WIDTH +: DATA_WIDTH];
    end

    //------------------------------------
    // FSM Sequential
    //------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst)
            current_state <= IDLE;
        else
            current_state <= next_state;
    end

    //------------------------------------
    // FSM Next-State Logic
    //------------------------------------
    always @(*) begin
        // Default assignment: no latch is inferred, and an unsupported or unknown mode/state
        // parks the FSM in IDLE instead of ending the simulation.
        next_state = IDLE;
        case (mode)
        2'b00: begin // Output-Stationary: one A row and one B row per output element
            case (current_state)
            IDLE:                next_state = enable ? LOAD_ROW_A : IDLE;
            LOAD_ROW_A:          next_state = WAIT_VALID_ROW_A;
            WAIT_VALID_ROW_A:    next_state = valid_mem_input_A ? LOAD_COL_B : WAIT_VALID_ROW_A;
            LOAD_COL_B:          next_state = WAIT_VALID_COL_B;
            WAIT_VALID_COL_B:    next_state = valid_mem_input_B ? COMPUTE : WAIT_VALID_COL_B;
            COMPUTE:             next_state = WRITE_BACK;
            WRITE_BACK:          next_state = (i == rows_size_reading_A && j == cols_size_reading_B) ? DONE : LOAD_NEW_ROW_OR_COL;
            LOAD_NEW_ROW_OR_COL: next_state = (j == cols_size_reading_B) ? LOAD_ROW_A : LOAD_COL_B;
            DONE:                next_state = DONE;
            default:             next_state = IDLE;
            endcase
        end
        2'b01: begin // Input-Stationary
            case (current_state)
            IDLE:                next_state = enable ? LOAD_ROW_A : IDLE;
            LOAD_ROW_A:          next_state = WAIT_VALID_ROW_A;
            WAIT_VALID_ROW_A:    next_state = valid_mem_input_A ? ((idx == cols_size_reading_A) ? LOAD_COL_B : LOAD_ROW_A) : WAIT_VALID_ROW_A;
            LOAD_COL_B:          next_state = WAIT_VALID_COL_B;
            WAIT_VALID_COL_B:    next_state = valid_mem_input_B ? ((idx2 == cols_size_reading_A) ? COMPUTE : LOAD_COL_B) : WAIT_VALID_COL_B;
            COMPUTE:             next_state = WRITE_BACK;
            WRITE_BACK:          next_state = (i == rows_size_reading_A && j == cols_size_reading_B) ? DONE : LOAD_NEW_ROW_OR_COL;
            LOAD_NEW_ROW_OR_COL: next_state = (j == cols_size_reading_B) ? LOAD_ROW_A : LOAD_COL_B;
            DONE:                next_state = DONE;
            default:             next_state = IDLE;
            endcase
        end
        2'b10: begin // Weight-Stationary: B is loaded first, then A rows stream past it
            case (current_state)
            IDLE:                next_state = enable ? LOAD_COL_B : IDLE;
            LOAD_ROW_A:          next_state = WAIT_VALID_ROW_A;
            WAIT_VALID_ROW_A:    next_state = valid_mem_input_A ? ((idx == cols_size_reading_A) ? COMPUTE : LOAD_ROW_A) : WAIT_VALID_ROW_A;
            LOAD_COL_B:          next_state = WAIT_VALID_COL_B;
            // Waits for B's valid strobe (was mistakenly valid_mem_input_A).
            WAIT_VALID_COL_B:    next_state = valid_mem_input_B ? ((idx2 == cols_size_reading_A) ? LOAD_ROW_A : LOAD_COL_B) : WAIT_VALID_COL_B;
            COMPUTE:             next_state = WRITE_BACK;
            WRITE_BACK:          next_state = (i == rows_size_reading_A && j == cols_size_reading_B) ? DONE : LOAD_NEW_ROW_OR_COL;
            LOAD_NEW_ROW_OR_COL: next_state = (i == rows_size_reading_A) ? LOAD_COL_B : LOAD_ROW_A;
            DONE:                next_state = DONE;
            default:             next_state = IDLE;
            endcase
        end
        default: next_state = IDLE; // 2'b11 or X/Z: unsupported dataflow
        endcase
    end

    //------------------------------------
    // Sequential Logic (datapath + outputs)
    //------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            i                <= 0;
            j                <= 0;
            idx              <= 0;
            idx2             <= 0;
            idx3             <= 0;
            acc              <= 0;
            done             <= 0;
            read_full_row_A  <= 0;
            read_full_row_B  <= 0;
            read_enable_A    <= 0;
            read_enable_B    <= 0;
            write_enable_out <= 0;
            // Reset every register in this async-reset block, so none is left X and
            // none needs a reset-dependent hold path.
            row_addr_A       <= 0;
            col_addr_A       <= 0;
            row_addr_B       <= 0;
            col_addr_B       <= 0;
            row_addr_out     <= 0;
            col_addr_out     <= 0;
            data_out         <= 0;
        end else begin
            case (current_state)
            IDLE: begin
                i                <= 0;
                j                <= 0;
                idx              <= 0;
                idx2             <= 0;
                idx3             <= 0;
                acc              <= 0;
                done             <= 0;
                read_full_row_A  <= 0;
                read_full_row_B  <= 0;
                write_enable_out <= 0;
                $display("---- IDLE: Waiting for enable ----");
            end
            LOAD_ROW_A: begin
                // Issue a one-cycle full-row read of A at row i (plus offset).
                row_addr_A      <= i + rows_start_add_reading_A;
                col_addr_A      <= cols_start_add_reading_A;
                read_full_row_A <= 1;
                $display("LOAD_ROW_A: Reading A[%0d][x]", i);
            end
            WAIT_VALID_ROW_A: begin
                read_full_row_A <= 0;
                idx2            <= 0;
                acc             <= 0;
                idx3            <= 0;
                if (valid_mem_input_A)
                    $display("WAIT_VALID_ROW_A: A[%0d][x] = %0h", i, full_row_A);
            end
            LOAD_COL_B: begin
                // Issue a one-cycle full-row read of the B operand at row j (plus offset).
                row_addr_B      <= j + rows_start_add_reading_B;
                col_addr_B      <= cols_start_add_reading_B;
                read_full_row_B <= 1;
                $display("LOAD_COL_B: Reading B[%0d][x] ", j);
            end
            WAIT_VALID_COL_B: begin
                read_full_row_B <= 0;
                acc             <= 0;
                idx             <= 0;
                idx3            <= 0;
                if (valid_mem_input_B)
                    $display("WAIT_VALID_COL_B: B[%0d][x] = %0h", j, full_row_B);
            end
            COMPUTE: begin
                acc <= dot_product;
                // Print dot_product, not acc: acc only takes the new value after this clock edge.
                $display("COMPUTE: Dot Product A[%0d][*] . B[%0d][*] {A = %0h, B = %0h} = %0h",
                         i, j, full_row_A, full_row_B, dot_product);
            end
            COMPUTE_ACC: begin
                acc <= (A_row[idx3] * B_col[idx3]) + acc;
            end
            WAIT_OUT_COMPUTE: begin
                idx3 <= idx3 + 1;
                $display("WAIT_OUT_COMPUTE: idx3 = %d ", idx3);
            end
            WRITE_BACK: begin
                // Write the low DATA_WIDTH bits of the result; cleared again in
                // LOAD_NEW_ROW_OR_COL or DONE, so the write is a one-cycle pulse.
                row_addr_out     <= i + rows_start_add_writing;
                col_addr_out     <= j + cols_start_add_writing;
                data_out         <= acc[DATA_WIDTH-1:0];
                write_enable_out <= 1;
                $display("WRITE_BACK: Writing Result[%0d][%0d] = %0h ; write_enable => %d ; acc =%0h ", i, j, acc[DATA_WIDTH-1:0], write_enable_out, acc);
            end
            LOAD_NEW_ROW_OR_COL: begin
                acc <= 0;
                if (mode == 2'b00) begin
                    // j is the inner loop: advance j, wrap to the next row of A.
                    idx2             <= 0;
                    idx              <= 0;
                    write_enable_out <= 0;
                    if (j == cols_size_reading_B) begin
                        j <= 0;
                        i <= i + 1;
                    end else begin
                        j <= j + 1;
                    end
                end else if (mode == 2'b01) begin
                    idx2             <= 0;
                    write_enable_out <= 0;
                    if (j == cols_size_reading_B) begin
                        idx <= 0;
                        j   <= 0;
                        i   <= i + 1;
                    end else begin
                        j <= j + 1;
                    end
                end else if (mode == 2'b10) begin
                    // i is the inner loop: advance i, wrap to the next B operand.
                    idx              <= 0;
                    write_enable_out <= 0;
                    if (i == rows_size_reading_A) begin
                        idx2 <= 0;
                        i    <= 0;
                        j    <= j + 1;
                    end else begin
                        i <= i + 1;
                    end
                end
                $display("LOAD_NEW_ROW_OR_COL: Writing [i = %0d; j = %0d] (idx = %0d; idx2 =%0d)", i, j, idx, idx2);
            end
            DONE: begin
                done             <= 1;
                // WRITE_BACK jumps straight here for the last element; drop the write strobe
                // so the final result is written once rather than on every cycle.
                write_enable_out <= 0;
                $display("DONE: All matrix multiplication complete.");
            end
            default: begin
                write_enable_out <= 0;
                read_full_row_A  <= 0;
                read_full_row_B  <= 0;
            end
            endcase
        end
    end
endmodule