// top_module_mask — Verilog source (originally listed as 265 lines, 9.2 KiB)
`timescale 1ns / 1ps

// -----------------------------------------------------------------------------
// top_module_mask
//
// Row-by-row top-K mask generator.
//
// For every source row r = 0 .. rows_size_reading (inclusive):
//   READ_ROW     : drive read_enable_full_row = 1 and the read address
//                  (rows_start_add_reading + r, cols_start_add_reading).
//   WAIT_READ    : once read_valid is high, latch elements
//                  0 .. cols_size_reading of data_input_rows into A[] and
//                  drop read_enable_full_row.
//   PRE_COMPUTE  : select the TOP_K largest elements of A[] that are strictly
//                  greater than zero (ties keep the lower column index).
//   COMPUTE_MASK : maskA  = one bit per column, set for each selected element;
//                  maskQ  = 1 if any maskA bit is set;
//                  maskKV = maskKV | maskA (accumulated over all rows).
//   WRITE_MASKA  : write maskA, DATA_WIDTH bits per word, to row
//                  rows_start_add_writing + r, columns
//                  cols_start_add_writing + 0 .. maskA_words-1, where
//                  maskA_words = ceil((cols_size_reading + 1) / DATA_WIDTH).
//   WRITE_MASKQ  : write maskQ (zero-extended) to the same row, column
//                  cols_start_add_writing + (cols_size_reading + 1)/DATA_WIDTH + 1.
// After the last row, WRITE_MASKKV writes maskKV[c] (zero-extended) to row
// rows_start_add_writing + c, column
// cols_start_add_writing + (cols_size_reading + 1)/DATA_WIDTH + 2, for
// c = 0 .. cols_size_reading, then the FSM stays in DONE (valid_result = 1)
// until rst is asserted.
//
// All outputs are registered: values chosen while in state S appear on the
// ports during the following clock cycle.
// rows_size_writing, cols_size_writing and THRESHOLD are currently unused.
// -----------------------------------------------------------------------------
module top_module_mask #(
    parameter ROWS_READING = 16,
    parameter COLS_READING = 12,
    parameter ROWS_WRITING = 16,
    parameter COLS_WRITING = 32,
    parameter DATA_WIDTH   = 16,
    parameter THRESHOLD    = 10,   // currently unused
    parameter TOP_K        = 4
)(
    input                                 clk,
    input                                 rst,          // asynchronous, active high
    input                                 enable_mask,  // leaves IDLE when high
    input                                 read_valid,   // data_input_rows holds the requested row

    input      [$clog2(ROWS_READING)-1:0] rows_start_add_reading,
    input      [$clog2(COLS_READING)-1:0] cols_start_add_reading,
    input      [$clog2(ROWS_WRITING)-1:0] rows_start_add_writing,
    input      [$clog2(COLS_WRITING)-1:0] cols_start_add_writing,

    input      [$clog2(ROWS_READING)-1:0] rows_size_reading,  // index of last row to process (inclusive)
    input      [$clog2(COLS_READING)-1:0] cols_size_reading,  // index of last column to process (inclusive)
    input      [$clog2(ROWS_WRITING)-1:0] rows_size_writing,  // currently unused
    input      [$clog2(COLS_WRITING)-1:0] cols_size_writing,  // currently unused

    // Element i of the row occupies bits [i*DATA_WIDTH +: DATA_WIDTH].
    input      [DATA_WIDTH*((ROWS_READING>COLS_READING)?ROWS_READING:COLS_READING)-1:0] data_input_rows,

    output reg [DATA_WIDTH-1:0]           data_output,
    output reg                            write_enable,

    output reg [$clog2(ROWS_READING)-1:0] row_addr_out_read,
    output reg [$clog2(COLS_READING)-1:0] col_addr_out_read,
    output reg [$clog2(ROWS_WRITING)-1:0] row_addr_out_write,
    output reg [$clog2(COLS_WRITING)-1:0] col_addr_out_write,

    output reg                            read_enable_full_row,
    output reg                            valid_result
);

    // FSM state encoding. Because the module has a parameter port list these
    // were already local parameters; localparam makes that explicit.
    localparam IDLE         = 4'd0,
               READ_ROW     = 4'd1,
               WAIT_READ    = 4'd2,
               PRE_COMPUTE  = 4'd3,
               COMPUTE_MASK = 4'd4,
               WRITE_MASKA  = 4'd5,
               WRITE_MASKQ  = 4'd6,
               WRITE_MASKKV = 4'd7,
               DONE         = 4'd8;

    reg [3:0] state, next_state;

    // Latched copy of the current source row.
    reg [DATA_WIDTH-1:0] A [0:COLS_READING-1];

    reg [COLS_READING-1:0] maskA;   // top-K columns of the current row
    reg                    maskQ;   // 1 if maskA has any bit set
    reg [COLS_READING-1:0] maskKV;  // OR of maskA over all processed rows

    reg [$clog2(ROWS_READING)-1:0] current_row;
    reg [$clog2(COLS_READING)-1:0] current_col;   // column index written in WRITE_MASKKV
    reg [$clog2(COLS_READING)-1:0] maskA_count;   // maskA word written in WRITE_MASKA

    // Top-K working set: filled in PRE_COMPUTE, consumed in COMPUTE_MASK.
    // Slots are kept sorted by value, largest first. Only the first
    // topk_count slots hold real entries; the rest keep their (0, 0) init
    // and must not be turned into mask bits.
    reg [$clog2(COLS_READING)-1:0] topk_indices [0:TOP_K-1];
    reg [DATA_WIDTH-1:0]           topk_values  [0:TOP_K-1];
    reg [$clog2(TOP_K+1)-1:0]      topk_count;
    reg                            found;

    // Scratch values computed and consumed within a single clock cycle.
    reg [COLS_READING-1:0] maskA_next;
    reg [DATA_WIDTH-1:0]   masked_data;
    integer                bit_base;   // integer: maskA_count*DATA_WIDTH can exceed COLS_READING

    integer i, j, k, b;

    // Number of DATA_WIDTH-bit words needed for cols_size_reading+1 mask bits
    // (ceiling division), and the column offset of the maskQ word relative to
    // cols_start_add_writing (the maskKV column is one further).
    wire [31:0] maskA_words      = (cols_size_reading + DATA_WIDTH) / DATA_WIDTH;
    wire [31:0] maskQ_col_offset = (cols_size_reading + 1) / DATA_WIDTH + 1;

    // -------------------------------------------------------------------------
    // Next-state logic
    // -------------------------------------------------------------------------
    always @(*) begin
        next_state = state;
        case (state)
            IDLE:         if (enable_mask) next_state = READ_ROW;
            READ_ROW:     next_state = WAIT_READ;
            WAIT_READ:    if (read_valid) next_state = PRE_COMPUTE;
            PRE_COMPUTE:  next_state = COMPUTE_MASK;
            COMPUTE_MASK: next_state = WRITE_MASKA;
            // maskA_count is the word being written this cycle; leave after
            // the last word (maskA_words-1) so no extra word is emitted.
            WRITE_MASKA:  next_state = (maskA_count + 1 < maskA_words) ? WRITE_MASKA : WRITE_MASKQ;
            WRITE_MASKQ:  next_state = (current_row == rows_size_reading) ? WRITE_MASKKV : READ_ROW;
            WRITE_MASKKV: next_state = (current_col == cols_size_reading) ? DONE : WRITE_MASKKV;
            DONE:         next_state = DONE;   // left only through rst
            default:      next_state = IDLE;   // recover from unused encodings
        endcase
    end

    // -------------------------------------------------------------------------
    // State register and registered outputs
    // -------------------------------------------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            state                <= IDLE;
            current_row          <= 0;
            current_col          <= 0;
            maskA_count          <= 0;
            valid_result         <= 0;
            write_enable         <= 0;
            read_enable_full_row <= 0;
            maskA                <= 0;
            maskQ                <= 0;
            maskKV               <= 0;
            data_output          <= 0;
            row_addr_out_read    <= 0;
            col_addr_out_read    <= 0;
            row_addr_out_write   <= 0;
            col_addr_out_write   <= 0;
        end else begin
            state <= next_state;

            case (state)
                IDLE: begin
                    current_row  <= 0;
                    current_col  <= 0;
                    valid_result <= 0;
                    maskA_count  <= 0;
                    maskA        <= 0;
                    maskQ        <= 0;
                    maskKV       <= 0;
                    data_output  <= 0;
                end

                READ_ROW: begin
                    $display("[READ_ROW] current_row=%0d, row_addr_out_read=%0d, col_addr_out_read=%0d, time=%0t",
                             current_row, row_addr_out_read, col_addr_out_read, $time);
                    read_enable_full_row <= 1;
                    row_addr_out_read    <= current_row + rows_start_add_reading;
                    col_addr_out_read    <= cols_start_add_reading;
                    write_enable         <= 0;
                    data_output          <= 0;
                end

                WAIT_READ: begin
                    $display("[WAIT_READ] read_valid=%b, time=%0t", read_valid, $time);
                    if (read_valid) begin
                        // Constant bound keeps A[] in range and is synthesizable;
                        // only columns 0..cols_size_reading are copied.
                        for (i = 0; i < COLS_READING; i = i + 1) begin
                            if (i <= cols_size_reading)
                                A[i] <= data_input_rows[i*DATA_WIDTH +: DATA_WIDTH];
                        end
                        read_enable_full_row <= 0;
                    end
                end

                PRE_COMPUTE: begin
                    $display("[ PRE_COMPUTE-TOPK] Finding top %0d", TOP_K);

                    maskA_count <= 0;

                    // Start with every slot empty (value 0, index 0).
                    topk_count = 0;
                    for (j = 0; j < TOP_K; j = j + 1) begin
                        topk_values[j]  = {DATA_WIDTH{1'b0}};
                        topk_indices[j] = 0;
                    end

                    // Insertion into a descending list: each element goes into
                    // the first slot it strictly exceeds, shifting the tail down.
                    for (i = 0; i < COLS_READING; i = i + 1) begin
                        if (i <= cols_size_reading) begin
                            found = 1'b0;
                            for (j = 0; j < TOP_K; j = j + 1) begin
                                if (!found && A[i] > topk_values[j]) begin
                                    for (k = TOP_K-1; k > j; k = k - 1) begin
                                        topk_values[k]  = topk_values[k-1];
                                        topk_indices[k] = topk_indices[k-1];
                                    end
                                    topk_values[j]  = A[i];
                                    topk_indices[j] = i;
                                    found = 1'b1;   // emulates break
                                end
                            end
                            // Each insertion fills one more slot until all TOP_K are used.
                            if (found && topk_count < TOP_K)
                                topk_count = topk_count + 1;
                        end
                    end
                end

                COMPUTE_MASK: begin
                    $display("[COMPUTE_MASK-TOPK] Finding top %0d", TOP_K);

                    // Build the new mask from the filled slots only, so an
                    // empty slot's default index 0 never sets a mask bit.
                    maskA_next = {COLS_READING{1'b0}};
                    for (j = 0; j < TOP_K; j = j + 1) begin
                        if (j < topk_count)
                            maskA_next[topk_indices[j]] = 1'b1;
                        $display("Top %0d: index=%0d, value=%0d", j, topk_indices[j], topk_values[j]);
                    end

                    maskA  <= maskA_next;
                    maskQ  <= |maskA_next;
                    maskKV <= maskKV | maskA_next;
                end

                WRITE_MASKA: begin
                    $display("[WRITE_MASKA] row=%0d, col=%0d, maskA[%0d]=%b, data_output=%b, time=%0t",
                             row_addr_out_write, col_addr_out_write, maskA_count, maskA,
                             {{(DATA_WIDTH-1){1'b0}}, maskA[maskA_count]}, $time);

                    row_addr_out_write <= current_row + rows_start_add_writing;
                    col_addr_out_write <= maskA_count + cols_start_add_writing;

                    // Pack maskA bits [bit_base +: DATA_WIDTH] into one word,
                    // zero-filling past the end of maskA.
                    bit_base = maskA_count * DATA_WIDTH;
                    for (b = 0; b < DATA_WIDTH; b = b + 1)
                        masked_data[b] = (bit_base + b < COLS_READING) ? maskA[bit_base + b] : 1'b0;

                    write_enable <= 1;
                    data_output  <= masked_data;
                    maskA_count  <= maskA_count + 1;
                end

                WRITE_MASKQ: begin
                    $display("[WRITE_MASKQ] maskQ=%b, data_output=%b, col_addr_out_write = %0d(%0d), time=%0t",
                             maskQ, {{(DATA_WIDTH-1){1'b0}}, maskQ}, col_addr_out_write, ((cols_size_reading+1)/DATA_WIDTH), $time);

                    row_addr_out_write <= current_row + rows_start_add_writing;
                    col_addr_out_write <= cols_start_add_writing + maskQ_col_offset;
                    data_output        <= {{(DATA_WIDTH-1){1'b0}}, maskQ};
                    write_enable       <= 1;
                    current_row        <= current_row + 1;
                end

                WRITE_MASKKV: begin
                    $display("[WRITE_MASKKV] maskKV(%0d)=%b, data_output=%b, time=%0t",
                             current_col, maskKV, {{(DATA_WIDTH-1){1'b0}}, maskKV[current_col]}, $time);

                    // maskKV is written one bit per row: column current_col of
                    // the source maps to output row current_col.
                    row_addr_out_write <= current_col + rows_start_add_writing;
                    col_addr_out_write <= cols_start_add_writing + maskQ_col_offset + 1;
                    data_output        <= {{(DATA_WIDTH-1){1'b0}}, maskKV[current_col]};
                    write_enable       <= 1;
                    current_col        <= current_col + 1;
                end

                DONE: begin
                    $display("[DONE] current_row=%0d, result valid, time=%0t", current_row, $time);
                    valid_result         <= 1;
                    write_enable         <= 0;
                    read_enable_full_row <= 0;
                end

                default: ;
            endcase
        end
    end

endmodule