// Verilog source listing (289 lines, 12 KiB as originally published)
//-----------------------------------------------------------------------------
// matrix_multiplication_unit
//
// FSM-driven matrix multiplier computing C[i][j] = sum_p A[i][p] * B[p][j].
// For each output element it fetches one full row of A (full_row_A) and one
// full row of the B memory (full_row_B). It then forms the COLS_USED-lane
// unsigned dot product in a single COMPUTE cycle and writes the low
// DATA_WIDTH bits of it to the output memory.
//
// The dot product pairs lane p of full_row_A with lane p of full_row_B, so
// the B memory is expected to hold B transposed (row j of the B memory is
// column j of B).
//
// Loop order by mode:
//   2'b00 / 2'b01 : j (B column) is the inner loop. The A row is requested
//                   once per output row. 01 currently follows the same
//                   schedule as 00.
//   2'b10         : i (A row) is the inner loop. The B row is requested once
//                   per output column.
//   2'b11         : unsupported. It is reported and the simulation is stopped.
//
// Handshake: read_full_row_A/B is a one-cycle request strobe. The FSM then
// waits for valid_mem_input_A/B. full_row_A/B are NOT latched here, so the
// memory side must keep each row stable until that operand is requested again.
//
// rows_size_reading_A and cols_size_reading_B are inclusive LAST indices. The
// run ends after writing C[rows_size_reading_A][cols_size_reading_B], so a
// 2x2 result uses 1 and 1. DONE is terminal until rst.
//-----------------------------------------------------------------------------
module matrix_multiplication_unit #(
    parameter DATA_WIDTH = 16,  // element width in bits
    parameter ROWS       = 16,  // sets the row-address width ($clog2(ROWS))
    parameter COLS       = 32,  // sets the column-address width ($clog2(COLS))
    parameter COLS_USED  = 4    // lanes per full row = dot-product length
) (
    input                                 clk,
    input                                 rst,               // asynchronous, active high
    input                                 enable,            // leaves IDLE when high
    input      [1:0]                      mode,              // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary

    input      [DATA_WIDTH-1:0]           data_input_A,      // unused (legacy element-wise path)
    input      [DATA_WIDTH-1:0]           data_input_B,      // unused (legacy element-wise path)

    input      [DATA_WIDTH*COLS_USED-1:0] full_row_A,        // lane p = A[i][p]
    input      [DATA_WIDTH*COLS_USED-1:0] full_row_B,        // lane p = B[p][j] (B stored transposed)

    input                                 valid_mem_input_A, // full_row_A holds the requested row
    input                                 valid_mem_input_B, // full_row_B holds the requested row

    input      [$clog2(ROWS)-1:0]         rows_start_add_reading_A, // base row of A in memory
    input      [$clog2(COLS)-1:0]         cols_start_add_reading_A, // column address driven for A rows
    input      [$clog2(ROWS)-1:0]         rows_start_add_reading_B, // base row of (transposed) B in memory
    input      [$clog2(COLS)-1:0]         cols_start_add_reading_B, // column address driven for B rows
    input      [$clog2(ROWS)-1:0]         rows_start_add_writing,   // base row of C in output memory
    input      [$clog2(COLS)-1:0]         cols_start_add_writing,   // base column of C in output memory

    input      [$clog2(ROWS)-1:0]         rows_size_reading_A, // last row index of A / C (inclusive)
    input      [$clog2(COLS)-1:0]         cols_size_reading_A, // not used by the schedule; should match rows of B
    input      [$clog2(ROWS)-1:0]         rows_size_reading_B, // not used by the schedule
    input      [$clog2(COLS)-1:0]         cols_size_reading_B, // last column index of B / C (inclusive)

    output reg                            done,

    output reg                            read_full_row_A,   // one-cycle request for row i of A
    output reg                            read_full_row_B,   // one-cycle request for row j of B memory

    // Memory interface
    output reg [$clog2(ROWS)-1:0]         row_addr_A,
    output reg [$clog2(COLS)-1:0]         col_addr_A,
    output reg [$clog2(ROWS)-1:0]         row_addr_B,
    output reg [$clog2(COLS)-1:0]         col_addr_B,
    output reg [$clog2(ROWS)-1:0]         row_addr_out,
    output reg [$clog2(COLS)-1:0]         col_addr_out,
    output reg                            read_enable_A,     // NOTE(review): never asserted by this FSM; held at 0
    output reg                            read_enable_B,     // NOTE(review): never asserted by this FSM; held at 0
    output reg                            write_enable_out,  // one-cycle write strobe
    output reg [DATA_WIDTH-1:0]           data_out
);

    // FSM state encoding (values unchanged from the original design).
    localparam [3:0] IDLE                = 4'd0,
                     LOAD_ROW_A          = 4'd1,
                     WAIT_VALID_ROW_A    = 4'd2,
                     LOAD_COL_B          = 4'd3,
                     WAIT_VALID_COL_B    = 4'd4,
                     COMPUTE             = 4'd5,
                     WAIT_OUT_COMPUTE    = 4'd6,   // reserved: unreachable (legacy serial MAC)
                     WRITE_BACK          = 4'd7,
                     LOAD_NEW_ROW_OR_COL = 4'd8,
                     DONE                = 4'd9,
                     COMPUTE_ACC         = 4'd10;  // reserved: unreachable (legacy serial MAC)

    // j must be able to reach cols_size_reading_B ($clog2(COLS) bits), and it
    // also feeds the B row address ($clog2(ROWS) bits), so take the wider one.
    localparam J_W = ($clog2(COLS) > $clog2(ROWS)) ? $clog2(COLS) : $clog2(ROWS);

    reg [3:0] current_state, next_state;

    reg [$clog2(ROWS)-1:0] i;    // current row of A / C
    reg [J_W-1:0]          j;    // current column of B / C
    reg [2*DATA_WIDTH-1:0] acc;  // dot product registered in COMPUTE, written in WRITE_BACK

    // Mode decode shared by the next-state and counter logic.
    wire j_inner      = (mode == 2'b00) || (mode == 2'b01);  // j is the inner loop
    wire i_inner      = (mode == 2'b10);                     // i is the inner loop
    wire last_element = (i == rows_size_reading_A) && (j == cols_size_reading_B);

    //------------------------------------
    // Dot product of the two fetched rows
    //------------------------------------
    // Unsigned, COLS_USED lanes. The operands are sized to 2*DATA_WIDTH by the
    // assignment context, and the sum wraps modulo 2^(2*DATA_WIDTH).
    reg [2*DATA_WIDTH-1:0] dot_product;
    integer                lane;

    always @(*) begin
        dot_product = {2*DATA_WIDTH{1'b0}};
        for (lane = 0; lane < COLS_USED; lane = lane + 1)
            dot_product = dot_product +
                          full_row_A[lane*DATA_WIDTH +: DATA_WIDTH] *
                          full_row_B[lane*DATA_WIDTH +: DATA_WIDTH];
    end

    //------------------------------------
    // FSM Sequential
    //------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst)
            current_state <= IDLE;
        else
            current_state <= next_state;
    end

    //------------------------------------
    // FSM Next-State Logic
    //------------------------------------
    always @(*) begin
        // Default assignment: unsupported/unknown modes park in IDLE and no
        // latch is inferred for next_state.
        next_state = IDLE;

        if (j_inner) begin
            case (current_state)
                IDLE:                next_state = enable ? LOAD_ROW_A : IDLE;
                LOAD_ROW_A:          next_state = WAIT_VALID_ROW_A;
                WAIT_VALID_ROW_A:    next_state = valid_mem_input_A ? LOAD_COL_B : WAIT_VALID_ROW_A;
                LOAD_COL_B:          next_state = WAIT_VALID_COL_B;
                WAIT_VALID_COL_B:    next_state = valid_mem_input_B ? COMPUTE : WAIT_VALID_COL_B;
                COMPUTE:             next_state = WRITE_BACK;
                WRITE_BACK:          next_state = last_element ? DONE : LOAD_NEW_ROW_OR_COL;
                // End of a row of C: fetch the next A row, otherwise reuse it.
                LOAD_NEW_ROW_OR_COL: next_state = (j == cols_size_reading_B) ? LOAD_ROW_A : LOAD_COL_B;
                DONE:                next_state = DONE;
                default:             next_state = IDLE;
            endcase
        end else if (i_inner) begin
            case (current_state)
                IDLE:                next_state = enable ? LOAD_COL_B : IDLE;
                LOAD_COL_B:          next_state = WAIT_VALID_COL_B;
                WAIT_VALID_COL_B:    next_state = valid_mem_input_B ? LOAD_ROW_A : WAIT_VALID_COL_B;
                LOAD_ROW_A:          next_state = WAIT_VALID_ROW_A;
                WAIT_VALID_ROW_A:    next_state = valid_mem_input_A ? COMPUTE : WAIT_VALID_ROW_A;
                COMPUTE:             next_state = WRITE_BACK;
                WRITE_BACK:          next_state = last_element ? DONE : LOAD_NEW_ROW_OR_COL;
                // End of a column of C: fetch the next B row, otherwise reuse it.
                LOAD_NEW_ROW_OR_COL: next_state = (i == rows_size_reading_A) ? LOAD_COL_B : LOAD_ROW_A;
                DONE:                next_state = DONE;
                default:             next_state = IDLE;
            endcase
        end else if (mode == 2'b11) begin
            // Unsupported mode: stop the simulation loudly (ignored by synthesis).
            $display("ERROR: matrix_multiplication_unit: unsupported mode %b", mode);
            $finish;
        end
    end

    //------------------------------------
    // Sequential Logic
    //------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            i                <= 0;
            j                <= 0;
            acc              <= 0;
            done             <= 1'b0;
            read_full_row_A  <= 1'b0;
            read_full_row_B  <= 1'b0;
            read_enable_A    <= 1'b0;
            read_enable_B    <= 1'b0;
            write_enable_out <= 1'b0;
            row_addr_A       <= 0;
            col_addr_A       <= 0;
            row_addr_B       <= 0;
            col_addr_B       <= 0;
            row_addr_out     <= 0;
            col_addr_out     <= 0;
            data_out         <= 0;
        end else begin
            // Request/write strobes are one-cycle pulses. They default low
            // every cycle and are raised only by the states below. This also
            // keeps write_enable_out from sticking high in DONE.
            read_full_row_A  <= 1'b0;
            read_full_row_B  <= 1'b0;
            write_enable_out <= 1'b0;

            case (current_state)
                IDLE: begin
                    i    <= 0;
                    j    <= 0;
                    acc  <= 0;
                    done <= 1'b0;
                    $display("---- IDLE: Waiting for enable ----");
                end

                LOAD_ROW_A: begin
                    row_addr_A      <= i + rows_start_add_reading_A;
                    col_addr_A      <= cols_start_add_reading_A;
                    read_full_row_A <= 1'b1;
                    $display("LOAD_ROW_A: Reading A[%0d][x]", i);
                end

                WAIT_VALID_ROW_A: begin
                    if (valid_mem_input_A)
                        $display("WAIT_VALID_ROW_A: A[%0d][x] = %0h", i, full_row_A);
                end

                LOAD_COL_B: begin
                    row_addr_B      <= j + rows_start_add_reading_B;  // row j of B^T = column j of B
                    col_addr_B      <= cols_start_add_reading_B;
                    read_full_row_B <= 1'b1;
                    $display("LOAD_COL_B: Reading B[%0d][x] ", j);
                end

                WAIT_VALID_COL_B: begin
                    if (valid_mem_input_B)
                        $display("WAIT_VALID_COL_B: B[%0d][x] = %0h", j, full_row_B);
                end

                COMPUTE: begin
                    acc <= dot_product;
                    // Print dot_product: acc only takes the new value after this edge.
                    $display("COMPUTE: Dot Product A[%0d][*] . B[%0d][*] {A = %0h, B = %0h} = %0h",
                             i, j, full_row_A, full_row_B, dot_product);
                end

                WRITE_BACK: begin
                    row_addr_out     <= i + rows_start_add_writing;
                    col_addr_out     <= j + cols_start_add_writing;
                    data_out         <= acc[DATA_WIDTH-1:0];  // low DATA_WIDTH bits only
                    write_enable_out <= 1'b1;
                    $display("WRITE_BACK: Writing Result[%0d][%0d] = %0h ; write_enable => %d ; acc =%0h ", i, j, acc[DATA_WIDTH-1:0], write_enable_out, acc);
                end

                LOAD_NEW_ROW_OR_COL: begin
                    if (i_inner) begin
                        // i innermost: sweep rows of A, then advance the B column.
                        if (i == rows_size_reading_A) begin
                            i <= 0;
                            j <= j + 1'b1;
                        end else begin
                            i <= i + 1'b1;
                        end
                    end else if (j_inner) begin
                        // j innermost: sweep columns of B, then advance the A row.
                        if (j == cols_size_reading_B) begin
                            j <= 0;
                            i <= i + 1'b1;
                        end else begin
                            j <= j + 1'b1;
                        end
                    end
                    $display("LOAD_NEW_ROW_OR_COL: Writing [i = %0d; j = %0d]", i, j);
                end

                DONE: begin
                    done <= 1'b1;
                    $display("DONE: All matrix multiplication complete.");
                end

                default: ;  // strobes already defaulted low above
            endcase
        end
    end

endmodule
|