`timescale 1ns / 1ps

// top_module_mask
// ---------------
// Row-by-row top-K mask generator driven by a small FSM.
//
// For each row r = 0 .. rows_size_reading (inclusive):
//   1. READ_ROW/WAIT_READ: assert read_enable_full_row at read address
//      (r + rows_start_add_reading, cols_start_add_reading) and, when
//      read_valid is high, latch columns 0 .. cols_size_reading of
//      data_input_rows into A.
//   2. PRE_COMPUTE: pick the TOP_K largest strictly-positive entries of A
//      (ties keep the lower column index).
//   3. COMPUTE_MASK: maskA = one-hot bits of those columns, maskQ = |maskA,
//      maskKV |= maskA (maskKV accumulates across rows).
//   4. WRITE_MASKA: write maskA packed DATA_WIDTH bits per word to row
//      r + rows_start_add_writing, columns cols_start_add_writing + 0 ..
//      + (cols_size_reading+1)/DATA_WIDTH (bits past COLS_READING are zero).
//   5. WRITE_MASKQ: write maskQ (in bit 0) to the following column.
// After the last row, WRITE_MASKKV writes maskKV one bit per cycle as a column
// vector (row = column index + rows_start_add_writing) at the column after
// maskQ. DONE then holds valid_result high until rst.
//
// NOTE(review): THRESHOLD, rows_size_writing and cols_size_writing are not
// used by the logic below.
module top_module_mask #(
    parameter ROWS_READING = 16,
    parameter COLS_READING = 12,
    parameter ROWS_WRITING = 16,
    parameter COLS_WRITING = 32,
    parameter DATA_WIDTH   = 16,
    parameter THRESHOLD    = 10,   // currently unused
    parameter TOP_K        = 4
)(
    input  clk,
    input  rst,
    input  enable_mask,
    input  read_valid,
    input  [$clog2(ROWS_READING)-1:0] rows_start_add_reading,
    input  [$clog2(COLS_READING)-1:0] cols_start_add_reading,
    input  [$clog2(ROWS_WRITING)-1:0] rows_start_add_writing,
    input  [$clog2(COLS_WRITING)-1:0] cols_start_add_writing,
    input  [$clog2(ROWS_READING)-1:0] rows_size_reading,   // index of the last row processed
    input  [$clog2(COLS_READING)-1:0] cols_size_reading,   // index of the last column processed
    input  [$clog2(ROWS_WRITING)-1:0] rows_size_writing,   // unused
    input  [$clog2(COLS_WRITING)-1:0] cols_size_writing,   // unused
    input  [DATA_WIDTH*((ROWS_READING>COLS_READING)?ROWS_READING:COLS_READING)-1:0] data_input_rows,
    output reg [DATA_WIDTH - 1:0] data_output,
    output reg write_enable,
    output reg [$clog2(ROWS_READING)-1:0] row_addr_out_read,
    output reg [$clog2(COLS_READING)-1:0] col_addr_out_read,
    output reg [$clog2(ROWS_WRITING)-1:0] row_addr_out_write,
    output reg [$clog2(COLS_WRITING)-1:0] col_addr_out_write,
    output reg read_enable_full_row,
    output reg valid_result
);

    // FSM States
    parameter IDLE = 0, READ_ROW = 1, WAIT_READ = 2, PRE_COMPUTE = 3, COMPUTE_MASK = 4,
              WRITE_MASKA = 5, WRITE_MASKQ = 6, WRITE_MASKKV = 7, DONE = 8;

    reg [3:0] state, next_state;

    reg [DATA_WIDTH-1:0]   A [0:COLS_READING-1];   // latched input row
    reg [COLS_READING-1:0] maskA;                  // top-K columns of the current row
    reg                    maskQ;                  // 1 if the current row selected any column
    reg [COLS_READING-1:0] maskKV;                 // OR of maskA over all processed rows
    reg [$clog2(ROWS_READING)-1:0] current_row;
    // Column index while emitting maskKV; sized by COLS_READING because it
    // indexes maskKV and is compared against cols_size_reading.
    reg [$clog2(COLS_READING)-1:0] current_col;
    reg [$clog2(COLS_READING)-1:0] maskA_count;    // maskA word being written
    reg [DATA_WIDTH-1:0]   masked_data;            // combinational temp (blocking only)
    reg [COLS_READING-1:0] next_maskA;             // combinational temp (blocking only)
    // Bit offset of the maskA word being written. An integer, so
    // maskA_count*DATA_WIDTH cannot wrap the way a $clog2(COLS_READING)-bit reg did.
    integer index;
    integer i;

    // Top-K working storage (blocking-only temporaries, filled in PRE_COMPUTE
    // and consumed in COMPUTE_MASK). A slot whose value is 0 is empty, because
    // only entries strictly greater than a slot's value are inserted.
    reg [$clog2(COLS_READING)-1:0] topk_indices [0:TOP_K-1];
    reg [DATA_WIDTH-1:0]           topk_values  [0:TOP_K-1];
    reg found;
    integer j, k;

    // FSM Transition
    always @(*) begin
        next_state = state;
        case (state)
            IDLE:         if (enable_mask) next_state = READ_ROW;
            READ_ROW:     next_state = WAIT_READ;
            WAIT_READ:    if (read_valid) next_state = PRE_COMPUTE;
            PRE_COMPUTE:  next_state = COMPUTE_MASK;
            COMPUTE_MASK: next_state = WRITE_MASKA;
            // maskA_count is the word written in THIS cycle; stay only while more
            // words remain. Words 0 .. (cols_size_reading+1)/DATA_WIDTH are
            // written, which ends right before the maskQ column used in
            // WRITE_MASKQ (the old "<=" wrote one extra word that maskQ overwrote).
            WRITE_MASKA:  next_state = (maskA_count < ((cols_size_reading + 1) / DATA_WIDTH))
                                       ? WRITE_MASKA : WRITE_MASKQ;
            WRITE_MASKQ:  next_state = (current_row == rows_size_reading) ? WRITE_MASKKV : READ_ROW;
            WRITE_MASKKV: next_state = (current_col == cols_size_reading) ? DONE : WRITE_MASKKV;
            DONE:         next_state = DONE;   // terminal until rst
            default:      next_state = IDLE;   // recover from unused state encodings
        endcase
    end

    // FSM Output Logic
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            state                <= IDLE;
            current_row          <= 0;
            current_col          <= 0;
            maskA_count          <= 0;
            maskA                <= 0;
            maskQ                <= 0;
            maskKV               <= 0;
            valid_result         <= 0;
            write_enable         <= 0;
            read_enable_full_row <= 0;
            data_output          <= 0;
            row_addr_out_read    <= 0;
            col_addr_out_read    <= 0;
            row_addr_out_write   <= 0;
            col_addr_out_write   <= 0;
        end else begin
            state <= next_state;
            case (state)
                IDLE: begin
                    current_row  <= 0;
                    valid_result <= 0;
                    maskA_count  <= 0;
                    maskA        <= 0;
                    maskKV       <= 0;
                    data_output  <= 0;
                    current_col  <= 0;
                end

                READ_ROW: begin
                    $display("[READ_ROW] current_row=%0d, row_addr_out_read=%0d, col_addr_out_read=%0d, time=%0t", current_row, row_addr_out_read, col_addr_out_read, $time);
                    read_enable_full_row <= 1;
                    row_addr_out_read    <= current_row + rows_start_add_reading;
                    col_addr_out_read    <= cols_start_add_reading;
                    write_enable         <= 0;
                    data_output          <= 0;
                end

                WAIT_READ: begin
                    $display("[WAIT_READ] read_valid=%b, time=%0t", read_valid, $time);
                    if (read_valid) begin
                        for (i = 0; i <= cols_size_reading; i = i + 1)
                            A[i] <= data_input_rows[i*DATA_WIDTH +: DATA_WIDTH];
                        read_enable_full_row <= 0;
                    end
                end

                PRE_COMPUTE: begin
                    $display("[ PRE_COMPUTE-TOPK] Finding top %0d", TOP_K);
                    maskA_count <= 0;
                    // Insertion into a descending TOP_K list; value 0 marks an empty slot.
                    for (j = 0; j < TOP_K; j = j + 1) begin
                        topk_values[j]  = 0;
                        topk_indices[j] = 0;
                    end
                    for (i = 0; i <= cols_size_reading; i = i + 1) begin
                        found = 0;
                        for (j = 0; j < TOP_K; j = j + 1) begin
                            if (!found && A[i] > topk_values[j]) begin
                                // Shift lower-ranked entries down one slot.
                                for (k = TOP_K-1; k > j; k = k - 1) begin
                                    topk_values[k]  = topk_values[k-1];
                                    topk_indices[k] = topk_indices[k-1];
                                end
                                topk_values[j]  = A[i];
                                topk_indices[j] = i;
                                found = 1; // simulates break
                            end
                        end
                    end
                end

                COMPUTE_MASK: begin
                    $display("[COMPUTE_MASK-TOPK] Finding top %0d", TOP_K);
                    // Build the row mask in a blocking temp, then register it once,
                    // so maskA/maskQ/maskKV only ever see nonblocking assignments.
                    next_maskA = {COLS_READING{1'b0}};
                    for (j = 0; j < TOP_K; j = j + 1) begin
                        // Empty slots still hold the placeholder index 0; skip them
                        // so column 0 is not selected spuriously.
                        if (topk_values[j] != 0)
                            next_maskA[topk_indices[j]] = 1'b1;
                        $display("Top %0d: index=%0d, value=%0d", j, topk_indices[j], topk_values[j]);
                    end
                    maskA  <= next_maskA;
                    maskQ  <= |next_maskA;
                    maskKV <= maskKV | next_maskA;
                end

                WRITE_MASKA: begin
                    $display("[WRITE_MASKA] row=%0d, col=%0d, maskA[%0d]=%b, data_output=%b, time=%0t", row_addr_out_write, col_addr_out_write, maskA_count, maskA, {{(DATA_WIDTH-1){1'b0}}, maskA[maskA_count]}, $time);
                    row_addr_out_write <= current_row + rows_start_add_writing;
                    col_addr_out_write <= maskA_count + cols_start_add_writing;
                    // Pack maskA[index +: DATA_WIDTH] into one word, zero-filling
                    // bit positions at or beyond COLS_READING.
                    index = maskA_count * DATA_WIDTH;
                    masked_data = {DATA_WIDTH{1'b0}};
                    for (i = 0; i < DATA_WIDTH; i = i + 1) begin
                        if ((index + i) < COLS_READING)
                            masked_data[i] = maskA[index + i];
                    end
                    write_enable <= 1;
                    data_output  <= masked_data;
                    maskA_count  <= maskA_count + 1;
                end

                WRITE_MASKQ: begin
                    $display("[WRITE_MASKQ] maskQ=%b, data_output=%b, col_addr_out_write = %0d(%0d), time=%0t", maskQ, {{(DATA_WIDTH-1){1'b0}}, maskQ}, col_addr_out_write, ((cols_size_reading+1)/DATA_WIDTH), $time);
                    row_addr_out_write <= current_row + rows_start_add_writing;
                    col_addr_out_write <= cols_start_add_writing + ((cols_size_reading+1)/DATA_WIDTH) + 1; // column after maskA words
                    data_output        <= {{(DATA_WIDTH-1){1'b0}}, maskQ};
                    write_enable       <= 1;
                    current_row        <= current_row + 1;
                end

                WRITE_MASKKV: begin
                    $display("[WRITE_MASKKV] maskKV(%0d)=%b, data_output=%b, time=%0t", current_col, maskKV, {{(DATA_WIDTH-1){1'b0}}, maskKV[current_col]}, $time);
                    // maskKV is emitted as a column vector: one bit per row.
                    row_addr_out_write <= current_col + rows_start_add_writing;
                    col_addr_out_write <= cols_start_add_writing + ((cols_size_reading+1)/DATA_WIDTH) + 2; // column after maskQ
                    data_output        <= {{(DATA_WIDTH-1){1'b0}}, maskKV[current_col]};
                    write_enable       <= 1;
                    current_col        <= current_col + 1;
                end

                DONE: begin
                    $display("[DONE] current_row=%0d, result valid, time=%0t", current_row, $time);
                    valid_result         <= 1;
                    write_enable         <= 0;
                    read_enable_full_row <= 0;
                end
            endcase
        end
    end
endmodule