// matrix_multiplication_unit_new
//
// Sequencing wrapper around a `pe_array` instance, driven by one FSM:
//   IDLE -> LOAD_ROW_A -> WAIT_VALID_ROW_A -> LOAD_COL_B -> WAIT_VALID_COL_B
//        -> BEFORE_COMPUTE -> COMPUTE
//   COMPUTE loops back to LOAD_ROW_A while read_counter <= cols_size_reading_A,
//   then waits for acc_output_valid_pe and moves to WRITE, then DONE.
//
// - LOAD_ROW_A / LOAD_COL_B drive a row/column address and raise
//   read_full_row_A / read_full_row_B.
// - The WAIT_* states drop that request once valid_mem_input_A / _B is seen.
// - BEFORE_COMPUTE latches the low OUTPUT_COL words of full_row_B onto the PE
//   north inputs and full_row_A (truncated to the port width) onto the west inputs.
// - WRITE emits acc_outputs on Full_row_out with write_full_row_out = 1, one
//   output row or column per cycle, addressed by write_counter.
// - DONE is terminal: `done` stays high until rst.
//
// NOTE(review): the operand loop runs cols_size_reading_A + 1 times and WRITE
// runs min(rows_size_reading_A, cols_size_reading_B) + 1 times. This suggests
// the *_size_* inputs are inclusive "last index" values -- confirm with the caller.
// NOTE(review): the full_row_* / Full_row_out port width is
// DATA_WIDTH*(max(MEM_ROWS,MEM_COLS)-1)+1 bits. That looks like a misplaced
// parenthesis for DATA_WIDTH*max(...)-1. Only low-order bits are consumed
// here, so the width is left unchanged to keep the port interface stable.
module matrix_multiplication_unit_new #(
    parameter DATA_WIDTH     = 16,
    parameter MEM_ROWS       = 20,  //20 ->5bits //16
    parameter MEM_COLS       = 80,  //80 ->7bits //32SS
    parameter PE_ROWS        = 16,
    parameter PE_COLS        = 32,
    parameter COMMON_ROW_COL = 4,
    parameter OUTPUT_COL     = 5,
    parameter OUTPUT_ROW     = 4
)(
    input                                clk,
    input                                rst,
    input                                enable,
    input      [1:0]                     mode,  // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary
    //input initialization;
    input      [DATA_WIDTH-1:0]          data_input_A,
    input      [DATA_WIDTH-1:0]          data_input_B,
    input      [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] full_row_A,
    input      [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] full_row_B,
    input                                valid_mem_input_A,
    input                                valid_mem_input_B,
    input      [$clog2(MEM_ROWS)-1:0]    rows_start_add_reading_A,
    input      [$clog2(MEM_COLS)-1:0]    cols_start_add_reading_A,
    input      [$clog2(MEM_ROWS)-1:0]    rows_start_add_reading_B,
    input      [$clog2(MEM_COLS)-1:0]    cols_start_add_reading_B,
    input      [$clog2(MEM_ROWS)-1:0]    rows_start_add_writing,
    input      [$clog2(MEM_COLS)-1:0]    cols_start_add_writing,
    input      [$clog2(MEM_ROWS)-1:0]    rows_size_reading_A,
    input      [$clog2(MEM_COLS)-1:0]    cols_size_reading_A,
    input      [$clog2(MEM_ROWS)-1:0]    rows_size_reading_B,
    input      [$clog2(MEM_COLS)-1:0]    cols_size_reading_B,
    output reg                           done,
    output reg                           read_full_row_A,
    output reg                           read_full_row_B,
    output reg                           write_full_row_out,
    output reg [$clog2(MEM_ROWS)-1:0]    row_addr_A,
    output reg [$clog2(MEM_COLS)-1:0]    col_addr_A,
    output reg [$clog2(MEM_ROWS)-1:0]    row_addr_B,
    output reg [$clog2(MEM_COLS)-1:0]    col_addr_B,
    output reg [$clog2(MEM_ROWS)-1:0]    row_addr_out,
    output reg [$clog2(MEM_COLS)-1:0]    col_addr_out,
    output reg                           read_enable_A,     // not driven by the FSM; held at 0 from reset
    output reg                           read_enable_B,     // not driven by the FSM; held at 0 from reset
    output reg                           write_enable_out,  // only ever cleared; writes use write_full_row_out
    output reg [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] Full_row_out,
    output reg [DATA_WIDTH-1:0]          data_out           // not driven by the FSM; held at 0 from reset
);

    // FSM state encoding. localparam: the encoding is internal and must not
    // be overridable from the instantiation's parameter list.
    localparam IDLE                   = 0,
               LOAD_ROW_A             = 1,
               WAIT_VALID_ROW_A       = 2,
               LOAD_COL_B             = 3,
               WAIT_VALID_COL_B       = 4,
               COMPUTE                = 5,
               WRITE                  = 6,
               DONE                   = 7,
               BEFORE_COMPUTE         = 8,
               LOAD_ROW_A_COL_B       = 9,   // currently unused
               WAIT_VALID_ROW_A_COL_B = 10;  // currently unused

    reg [3:0] current_state, next_state;

    // PE array interface
    reg                                enable_pe_array;
    wire [OUTPUT_ROW*OUTPUT_COL-1:0]   valid_pe_array;
    reg  [DATA_WIDTH*OUTPUT_COL-1:0]   north_inputs;
    reg  [DATA_WIDTH*PE_ROWS-1:0]      west_inputs;
    wire [OUTPUT_ROW*DATA_WIDTH-1:0]   acc_outputs;

    reg  [9:0] compute_counter;  // cycles spent in COMPUTE (accumulates across loop passes)
    reg  [9:0] write_counter;    // index of the output row/column being written
    reg  [9:0] read_counter;     // number of A/B operand pairs fetched so far
    reg        initialization;   // reset to 0 and never set in this module
    reg        output_enable;
    wire       acc_output_valid_pe;

    // Instantiate PE Array
    pe_array #(
        .DATA_WIDTH(DATA_WIDTH),
        .MEM_ROWS(MEM_ROWS),  //20 ->5bits //16
        .MEM_COLS(MEM_COLS),  //80 ->7bits //32SS
        .PE_ROWS(PE_ROWS),
        .PE_COLS(PE_COLS),
        .COMMON_ROW_COL(COMMON_ROW_COL),
        .OUTPUT_COL(OUTPUT_COL),
        .OUTPUT_ROW(OUTPUT_ROW)
    ) dut_pe_array (
        .clk(clk),
        .rst(rst),
        .mode(mode),
        .initialization(initialization),
        .north_inputs(north_inputs),
        .west_inputs(west_inputs),
        .output_enable(output_enable),
        .enable(enable_pe_array),
        .valid(valid_pe_array),
        .rows_size_PE(rows_size_reading_A),
        .cols_size_PE(cols_size_reading_B),
        .acc_outputs(acc_outputs),
        .acc_output_valid(acc_output_valid_pe)
    );

    // FSM sequential: state register with asynchronous active-high reset.
    always @(posedge clk or posedge rst) begin
        if (rst)
            current_state <= IDLE;
        else
            current_state <= next_state;
    end

    // FSM next-state logic
    always @(*) begin
        case (current_state)
            IDLE:             next_state = enable ? LOAD_ROW_A : IDLE;
            LOAD_ROW_A:       next_state = WAIT_VALID_ROW_A;
            WAIT_VALID_ROW_A: next_state = valid_mem_input_A ? LOAD_COL_B : WAIT_VALID_ROW_A;
            LOAD_COL_B:       next_state = WAIT_VALID_COL_B;
            WAIT_VALID_COL_B: next_state = valid_mem_input_B ? BEFORE_COMPUTE : WAIT_VALID_COL_B; //COMPUTE
            BEFORE_COMPUTE:   next_state = COMPUTE;
            // Fetch another operand pair while read_counter <= cols_size_reading_A;
            // otherwise wait for the PE array's accumulator-valid flag.
            COMPUTE:          next_state = (read_counter <= cols_size_reading_A) ? LOAD_ROW_A :
                                           (acc_output_valid_pe)                 ? WRITE      :
                                                                                   COMPUTE; // PE_ROWS+PE_COLS+4
            // Write until write_counter reaches min(rows_size_reading_A, cols_size_reading_B).
            WRITE:            next_state = (write_counter >= ((cols_size_reading_B > rows_size_reading_A) ?
                                                              rows_size_reading_A : cols_size_reading_B)) ?
                                           DONE : WRITE;
            DONE:             next_state = DONE;
            default:          next_state = IDLE;
        endcase
    end

    // FSM outputs (registered)
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            read_full_row_A    <= 0;
            read_full_row_B    <= 0;
            enable_pe_array    <= 0;
            compute_counter    <= 0;
            write_counter      <= 0;
            done               <= 0;
            write_enable_out   <= 0;
            read_counter       <= 0;
            initialization     <= 0;
            output_enable      <= 0;
            write_full_row_out <= 0;
            Full_row_out       <= 0;
            // Previously unreset: give addresses, PE inputs and the
            // never-driven outputs a defined value instead of X.
            row_addr_A         <= 0;
            col_addr_A         <= 0;
            row_addr_B         <= 0;
            col_addr_B         <= 0;
            row_addr_out       <= 0;
            col_addr_out       <= 0;
            north_inputs       <= 0;
            west_inputs        <= 0;
            read_enable_A      <= 0;
            read_enable_B      <= 0;
            data_out           <= 0;
        end else begin
            case (current_state)
                IDLE: begin
                    read_full_row_A    <= 0;
                    read_full_row_B    <= 0;
                    enable_pe_array    <= 0;
                    compute_counter    <= 0;
                    write_counter      <= 0;
                    done               <= 0;
                    write_enable_out   <= 0;
                    read_counter       <= 0;
                    output_enable      <= 0;
                    write_full_row_out <= 0;
                    Full_row_out       <= 0;
                    $display("[IDLE] Waiting for enable...");
                end

                LOAD_ROW_A: begin
                    row_addr_A      <= rows_start_add_reading_A;
                    col_addr_A      <= cols_start_add_reading_A + read_counter;
                    read_full_row_A <= 1;
                    enable_pe_array <= 0;
                    $display("[LOAD_ROW_A] Reading full row A.");
                end

                WAIT_VALID_ROW_A: begin
                    if (valid_mem_input_A) begin
                        read_full_row_A <= 0;
                        $display("[WAIT_VALID_ROW_A[x][%0d]] Row A received(read = %0h) = %0h", col_addr_A, read_full_row_A, full_row_A);
                    end
                end

                LOAD_COL_B: begin
                    row_addr_B      <= rows_start_add_reading_B;
                    col_addr_B      <= cols_start_add_reading_B + read_counter;
                    read_full_row_B <= 1;
                    $display("[LOAD_COL_B] Reading full row B (transpose col).");
                end

                WAIT_VALID_COL_B: begin
                    if (valid_mem_input_B) begin
                        read_full_row_B <= 0;
                        $display("[WAIT_VALID_COL_B] Row B received(read = %0h) = %0h", read_full_row_B, full_row_B);
                        // One more A/B operand pair has been fetched.
                        read_counter <= read_counter + 1;
                    end
                end

                BEFORE_COMPUTE: begin
                    north_inputs <= full_row_B[OUTPUT_COL*DATA_WIDTH-1:0];
                    west_inputs  <= full_row_A;  // implicitly truncated to DATA_WIDTH*PE_ROWS bits
                    $display("[BEFORE_COMPUTE]north=%0h, west=%0h ", full_row_B[OUTPUT_COL*DATA_WIDTH-1:0], full_row_A );
                end

                COMPUTE: begin
                    enable_pe_array <= 1;
                    compute_counter <= compute_counter + 1;
                    $display("[COMPUTE] Cycle %0d / %0d", compute_counter, rows_size_reading_A+cols_size_reading_B+cols_size_reading_A+4);
                    // $display("[COMPUTE] Cycle %0d / %0d", compute_counter, PE_ROWS+OUTPUT_COL-1);
                    if (compute_counter >= cols_size_reading_A+1) begin
                        output_enable <= 1;
                        // Nonblocking: this prints the value from before this edge.
                        $display("output_enable = %h", output_enable);
                    end
                end

                WRITE: begin
                    enable_pe_array <= 1;
                    if (rows_size_reading_A >= cols_size_reading_B) begin
                        // Walk along columns of a fixed output row.
                        row_addr_out <= rows_start_add_writing;
                        col_addr_out <= (write_counter ) + cols_start_add_writing;
                        $display("[WRITE] Writing output[%0d][%0d] = %0h | Valid = %b", (write_counter), 0, acc_outputs, valid_pe_array[write_counter]);
                    end else begin
                        // Walk along rows of a fixed output column.
                        row_addr_out <= (write_counter ) + rows_start_add_writing; //original
                        col_addr_out <= cols_start_add_writing;
                        $display("[WRITE] Writing output[%0d][%0d] = %0h | Valid = %b", 0, (write_counter), acc_outputs, valid_pe_array[write_counter]);
                    end
                    Full_row_out       <= acc_outputs;
                    write_full_row_out <= 1;
                    write_counter      <= write_counter + 1;
                end

                DONE: begin
                    done               <= 1;
                    enable_pe_array    <= 0;
                    compute_counter    <= 0;
                    write_full_row_out <= 0;
                    $display("[DONE] Matrix multiplication completed.");
                end

                default: begin
                    enable_pe_array  <= 0;
                    write_enable_out <= 0;
                end
            endcase
        end
    end

endmodule