//------------------------------------------------------------------------------
// matrix_multiplication_unit
//
// FSM-driven dot-product engine. For every output element (i, j) it
//   1. requests one packed row of A   (read_full_row_A strobe, row_addr_A = i + offset),
//   2. requests one packed row of B   (read_full_row_B strobe, row_addr_B = j + offset;
//      j selects a ROW of the B memory, i.e. B is addressed transposed),
//   3. multiplies the COLS_USED lanes of full_row_A / full_row_B pairwise and sums them,
//   4. writes the low DATA_WIDTH bits of the sum to (row_addr_out, col_addr_out).
//
// Loop bounds: the *_size_reading_* inputs are compared with the counters using
// '==' AFTER the element has been written, so they act as the LAST index
// (inclusive), not as an element count. Results are produced for
// i = 0..rows_size_reading_A and j = 0..cols_size_reading_B.
//
// mode:
//   2'b00 / 2'b01 : row of A is fetched once, then j sweeps over B (A row reused).
//   2'b10         : row of B is fetched once, then i sweeps over A (B row reused).
//   2'b11         : unsupported; the FSM is held in IDLE.
//
// DONE is terminal: once reached, only rst returns the FSM to IDLE.
//
// NOTE(review): COMPUTE samples full_row_A / full_row_B directly from the ports
// and the operand that is being reused is not re-read; this presumably relies on
// the memory holding its last read data stable -- confirm against the memory model.
// NOTE(review): data_input_A, data_input_B, valid-independent element reads,
// cols_size_reading_A and rows_size_reading_B are not used by this FSM.
//------------------------------------------------------------------------------
module matrix_multiplication_unit #(
    parameter DATA_WIDTH = 16,  // width of one matrix element
    parameter ROWS       = 16,  // sizes the row address / row counter ($clog2(ROWS) bits)
    parameter COLS       = 32,  // sizes the column address / column counter ($clog2(COLS) bits)
    parameter COLS_USED  = 4    // number of DATA_WIDTH lanes packed into full_row_A / full_row_B
) (
    input                               clk,
    input                               rst,                      // asynchronous, active high
    input                               enable,                   // starts a run when sampled high in IDLE
    input      [1:0]                    mode,                     // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary
    input      [DATA_WIDTH-1:0]         data_input_A,             // matrix A (single element, unused here)
    input      [DATA_WIDTH-1:0]         data_input_B,             // matrix B (single element, unused here)
    input      [DATA_WIDTH*COLS_USED-1:0] full_row_A,             // packed row of A, lane 0 in the LSBs
    input      [DATA_WIDTH*COLS_USED-1:0] full_row_B,             // packed row of B, lane 0 in the LSBs
    input                               valid_mem_input_A,        // full_row_A is valid
    input                               valid_mem_input_B,        // full_row_B is valid
    input      [$clog2(ROWS)-1:0]       rows_start_add_reading_A, // offset of A for reading from memory
    input      [$clog2(COLS)-1:0]       cols_start_add_reading_A,
    input      [$clog2(ROWS)-1:0]       rows_start_add_reading_B, // offset of B for reading from memory
    input      [$clog2(COLS)-1:0]       cols_start_add_reading_B,
    input      [$clog2(ROWS)-1:0]       rows_start_add_writing,   // offset for writing to memory
    input      [$clog2(COLS)-1:0]       cols_start_add_writing,
    input      [$clog2(ROWS)-1:0]       rows_size_reading_A,      // last row index of A to process (inclusive)
    input      [$clog2(COLS)-1:0]       cols_size_reading_A,      // unused by this FSM
    input      [$clog2(ROWS)-1:0]       rows_size_reading_B,      // unused by this FSM
    input      [$clog2(COLS)-1:0]       cols_size_reading_B,      // last output column index (inclusive)
    output reg                          done,
    output reg                          read_full_row_A,          // one-cycle request for a packed row of A
    output reg                          read_full_row_B,          // one-cycle request for a packed row of B
    // Memory interface
    output reg [$clog2(ROWS)-1:0]       row_addr_A,
    output reg [$clog2(COLS)-1:0]       col_addr_A,
    output reg [$clog2(ROWS)-1:0]       row_addr_B,
    output reg [$clog2(COLS)-1:0]       col_addr_B,
    output reg [$clog2(ROWS)-1:0]       row_addr_out,
    output reg [$clog2(COLS)-1:0]       col_addr_out,
    output reg                          read_enable_A,            // never asserted by this FSM; held at 0
    output reg                          read_enable_B,            // never asserted by this FSM; held at 0
    output reg                          write_enable_out,
    output reg [DATA_WIDTH-1:0]         data_out
);

    //------------------------------------
    // FSM state encoding
    // Encodings 6 and 10 belonged to states that no transition ever entered;
    // they are left unused so the remaining values match earlier waveforms.
    //------------------------------------
    localparam [3:0] IDLE                = 4'd0,
                     LOAD_ROW_A          = 4'd1,
                     WAIT_VALID_ROW_A    = 4'd2,
                     LOAD_COL_B          = 4'd3,
                     WAIT_VALID_COL_B    = 4'd4,
                     COMPUTE             = 4'd5,
                     WRITE_BACK          = 4'd7,
                     LOAD_NEW_ROW_OR_COL = 4'd8,
                     DONE                = 4'd9;

    localparam [1:0] MODE_WEIGHT_STATIONARY = 2'b10,
                     MODE_UNSUPPORTED       = 2'b11;

    reg [3:0] current_state, next_state;

    // Output element indices. j is compared with cols_size_reading_B and added to
    // cols_start_add_writing, so it needs the column width (it used to be only
    // $clog2(ROWS) bits wide and could never match column sizes >= ROWS).
    reg [$clog2(ROWS)-1:0] i;
    reg [$clog2(COLS)-1:0] j;

    reg [2*DATA_WIDTH-1:0] acc;  // full-precision dot product of the current element

    wire weight_stationary = (mode == MODE_WEIGHT_STATIONARY);
    wire on_last_row       = (i == rows_size_reading_A);
    wire on_last_col       = (j == cols_size_reading_B);

    //------------------------------------
    // Lane-wise multiply-accumulate over all COLS_USED lanes.
    // Every product is evaluated at 2*DATA_WIDTH bits and the sum wraps modulo
    // 2**(2*DATA_WIDTH), exactly like the former hand-unrolled 4-lane expression.
    //------------------------------------
    function [2*DATA_WIDTH-1:0] dot_product;
        input [DATA_WIDTH*COLS_USED-1:0] row_a;
        input [DATA_WIDTH*COLS_USED-1:0] row_b;
        integer lane;
        begin
            dot_product = {(2*DATA_WIDTH){1'b0}};
            for (lane = 0; lane < COLS_USED; lane = lane + 1)
                dot_product = dot_product
                            + row_a[lane*DATA_WIDTH +: DATA_WIDTH]
                            * row_b[lane*DATA_WIDTH +: DATA_WIDTH];
        end
    endfunction

    //------------------------------------
    // FSM Sequential
    //------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst)
            current_state <= IDLE;
        else
            current_state <= next_state;
    end

    //------------------------------------
    // FSM Next-State Logic
    //------------------------------------
    always @(*) begin
        // Default keeps the block purely combinational (no latch) and parks the
        // FSM in IDLE for the unsupported mode or an unknown state.
        next_state = IDLE;

        if (mode != MODE_UNSUPPORTED) begin
            case (current_state)
                IDLE:
                    if (enable)
                        next_state = weight_stationary ? LOAD_COL_B : LOAD_ROW_A;

                LOAD_ROW_A:
                    next_state = WAIT_VALID_ROW_A;

                // One full-row read delivers every lane, so a single valid is enough.
                WAIT_VALID_ROW_A:
                    if (valid_mem_input_A)
                        next_state = weight_stationary ? COMPUTE : LOAD_COL_B;
                    else
                        next_state = WAIT_VALID_ROW_A;

                LOAD_COL_B:
                    next_state = WAIT_VALID_COL_B;

                // Always wait on B's valid here: only the B read was requested.
                WAIT_VALID_COL_B:
                    if (valid_mem_input_B)
                        next_state = weight_stationary ? LOAD_ROW_A : COMPUTE;
                    else
                        next_state = WAIT_VALID_COL_B;

                COMPUTE:
                    next_state = WRITE_BACK;

                WRITE_BACK:
                    next_state = (on_last_row && on_last_col) ? DONE : LOAD_NEW_ROW_OR_COL;

                // Re-fetch only the operand whose index changes; on wrap of the
                // inner index the outer operand is fetched first.
                LOAD_NEW_ROW_OR_COL:
                    if (weight_stationary)
                        next_state = on_last_row ? LOAD_COL_B : LOAD_ROW_A;
                    else
                        next_state = on_last_col ? LOAD_ROW_A : LOAD_COL_B;

                DONE:
                    next_state = DONE;

                default:
                    next_state = IDLE;
            endcase
        end
    end

    //------------------------------------
    // Sequential Logic (counters, memory interface, accumulator)
    //------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            i                <= 0;
            j                <= 0;
            acc              <= 0;
            done             <= 0;
            read_full_row_A  <= 0;
            read_full_row_B  <= 0;
            write_enable_out <= 0;
            // Outputs below were previously left at X after reset.
            read_enable_A    <= 0;
            read_enable_B    <= 0;
            row_addr_A       <= 0;
            col_addr_A       <= 0;
            row_addr_B       <= 0;
            col_addr_B       <= 0;
            row_addr_out     <= 0;
            col_addr_out     <= 0;
            data_out         <= 0;
        end else begin
            case (current_state)
                IDLE: begin
                    i                <= 0;
                    j                <= 0;
                    acc              <= 0;
                    done             <= 0;
                    read_full_row_A  <= 0;
                    read_full_row_B  <= 0;
                    write_enable_out <= 0;
                    $display("---- IDLE: Waiting for enable ----");
                    if (enable && mode == MODE_UNSUPPORTED)
                        $display("ERROR: mode 2'b11 is not supported; staying in IDLE");
                end

                LOAD_ROW_A: begin
                    row_addr_A      <= i + rows_start_add_reading_A;
                    col_addr_A      <= cols_start_add_reading_A;
                    read_full_row_A <= 1;
                    $display("LOAD_ROW_A: Reading A[%0d][x]", i);
                end

                WAIT_VALID_ROW_A: begin
                    read_full_row_A <= 0;  // request strobe is a single cycle
                    acc             <= 0;
                    if (valid_mem_input_A)
                        $display("WAIT_VALID_ROW_A: A[%0d][x] = %0h", i, full_row_A);
                end

                LOAD_COL_B: begin
                    // B is accessed transposed: j picks a row of the B memory.
                    // The sum is truncated to the row address width.
                    row_addr_B      <= j + rows_start_add_reading_B;
                    col_addr_B      <= cols_start_add_reading_B;
                    read_full_row_B <= 1;
                    $display("LOAD_COL_B: Reading B[%0d][x] ", j);
                end

                WAIT_VALID_COL_B: begin
                    read_full_row_B <= 0;  // request strobe is a single cycle
                    acc             <= 0;
                    if (valid_mem_input_B)
                        $display("WAIT_VALID_COL_B: B[%0d][x] = %0h", j, full_row_B);
                end

                COMPUTE: begin
                    acc <= dot_product(full_row_A, full_row_B);
                    // Print the value being stored (acc itself still holds the
                    // old value at this point) and the packed rows in hex so the
                    // message is valid for any COLS_USED.
                    $display("COMPUTE: Dot Product A[%0d][*] . B[%0d][*] {A = %0h, B = %0h} = %0h",
                             i, j, full_row_A, full_row_B,
                             dot_product(full_row_A, full_row_B));
                end

                WRITE_BACK: begin
                    row_addr_out     <= i + rows_start_add_writing;
                    col_addr_out     <= j + cols_start_add_writing;  // Adjust output memory offset
                    data_out         <= acc[DATA_WIDTH-1:0];         // upper half of acc is dropped
                    write_enable_out <= 1;
                    $display("WRITE_BACK: Writing Result[%0d][%0d] = %0h ; write_enable => %d ; acc =%0h ",
                             i, j, acc[DATA_WIDTH-1:0], write_enable_out, acc);
                end

                LOAD_NEW_ROW_OR_COL: begin
                    acc              <= 0;
                    write_enable_out <= 0;  // write strobe lasted exactly one cycle
                    if (weight_stationary) begin
                        // i is the inner loop; wrap it and advance j.
                        if (on_last_row) begin
                            i <= 0;
                            j <= j + 1;
                        end else begin
                            i <= i + 1;
                        end
                    end else if (mode != MODE_UNSUPPORTED) begin
                        // j is the inner loop; wrap it and advance i.
                        if (on_last_col) begin
                            j <= 0;
                            i <= i + 1;
                        end else begin
                            j <= j + 1;
                        end
                    end
                    $display("LOAD_NEW_ROW_OR_COL: Writing [i = %0d; j = %0d]", i, j);
                end

                DONE: begin
                    done             <= 1;
                    // End the final write strobe; it used to stay high forever.
                    write_enable_out <= 0;
                    $display("DONE: All matrix multiplication complete.");
                end

                default: begin
                    write_enable_out <= 0;
                    read_full_row_A  <= 0;
                    read_full_row_B  <= 0;
                end
            endcase
        end
    end

endmodule