// File: AYAKA_Transformer/rtl/mask_A_Q_KV.v
//
// Repository listing metadata: 265 lines, 9.2 KiB, Verilog.

`timescale 1ns / 1ps
// -----------------------------------------------------------------------------
// top_module_mask
//
// Row-by-row top-K masking engine.
//   * For each row r in [0, rows_size_reading] the FSM requests a full row
//     (read_enable_full_row + row/col read address), waits for read_valid and
//     latches columns [0, cols_size_reading] of data_input_rows into A[].
//   * The TOP_K largest strictly positive values of the row (unsigned compare,
//     earlier column wins ties) are selected; their columns form maskA.
//   * maskA is written as ceil((cols_size_reading+1)/DATA_WIDTH) packed words at
//     columns cols_start_add_writing + {0, 1, ...} of write row r.
//   * maskQ (1 if the row selected any column) is written in the column right
//     after the maskA words.
//   * maskKV is the OR of all row masks; after the last row it is written one
//     bit per write row (row index = column index) in the column after maskQ.
//   * DONE asserts valid_result and stays there until rst.
//
// NOTE(review): THRESHOLD, rows_size_writing and cols_size_writing are not used
// by this implementation.
// -----------------------------------------------------------------------------
module top_module_mask #(
    parameter ROWS_READING = 16,
    parameter COLS_READING = 12,
    parameter ROWS_WRITING = 16,
    parameter COLS_WRITING = 32,
    parameter DATA_WIDTH = 16,
    parameter THRESHOLD = 10,
    parameter TOP_K = 4
)(
    input clk,
    input rst,
    input enable_mask,
    input read_valid,
    input [$clog2(ROWS_READING)-1:0] rows_start_add_reading,
    input [$clog2(COLS_READING)-1:0] cols_start_add_reading,
    input [$clog2(ROWS_WRITING)-1:0] rows_start_add_writing,
    input [$clog2(COLS_WRITING)-1:0] cols_start_add_writing,
    input [$clog2(ROWS_READING)-1:0] rows_size_reading,
    input [$clog2(COLS_READING)-1:0] cols_size_reading,
    input [$clog2(ROWS_WRITING)-1:0] rows_size_writing,
    input [$clog2(COLS_WRITING)-1:0] cols_size_writing,
    input [DATA_WIDTH*((ROWS_READING>COLS_READING)?ROWS_READING:COLS_READING)-1:0] data_input_rows,
    output reg [DATA_WIDTH - 1:0] data_output,
    output reg write_enable,
    output reg [$clog2(ROWS_READING)-1:0] row_addr_out_read,
    output reg [$clog2(COLS_READING)-1:0] col_addr_out_read,
    output reg [$clog2(ROWS_WRITING)-1:0] row_addr_out_write,
    output reg [$clog2(COLS_WRITING)-1:0] col_addr_out_write,
    output reg read_enable_full_row,
    output reg valid_result
);
    // FSM states (body parameters are local when a parameter port list exists).
    localparam IDLE = 0, READ_ROW = 1, WAIT_READ = 2, PRE_COMPUTE = 3,
               COMPUTE_MASK = 4, WRITE_MASKA = 5,
               WRITE_MASKQ = 6, WRITE_MASKKV = 7, DONE = 8;

    reg [3:0] state, next_state;

    reg [DATA_WIDTH-1:0]           A [0:COLS_READING-1]; // latched values of the current row
    reg [COLS_READING-1:0]         maskA;                // top-K mask of the current row
    reg                            maskQ;                // 1 if the current row selected any column
    reg [COLS_READING-1:0]         maskKV;               // OR of the masks of all processed rows
    reg [$clog2(ROWS_READING)-1:0] current_row;
    reg [$clog2(COLS_READING)-1:0] current_col;          // maskKV bit (a column index) being written
    reg [$clog2(COLS_READING)-1:0] maskA_count;          // maskA word being written

    // Blocking-assigned scratch; only ever written with '=' inside this module.
    reg [COLS_READING-1:0]         row_mask;             // mask built in COMPUTE_MASK
    reg [DATA_WIDTH-1:0]           masked_data;          // packed maskA word built in WRITE_MASKA
    integer                        word_base;            // first maskA bit of that word
    integer i, j, k, b;

    // Top-K scratch: descending list of the best values and their columns.
    reg [$clog2(COLS_READING)-1:0] topk_indices [0:TOP_K-1];
    reg [DATA_WIDTH-1:0]           topk_values  [0:TOP_K-1];
    reg                            found;

    // Output layout. maskA needs ceil((cols_size_reading+1)/DATA_WIDTH) words,
    // so its last word index is floor(cols_size_reading/DATA_WIDTH).
    wire [31:0] last_maskA_word   = cols_size_reading / DATA_WIDTH;
    wire [31:0] maskQ_col_offset  = last_maskA_word + 1; // column after the maskA words
    wire [31:0] maskKV_col_offset = last_maskA_word + 2; // column after maskQ

    // FSM transition
    always @(*) begin
        next_state = state;
        case (state)
            IDLE:         if (enable_mask) next_state = READ_ROW;
            READ_ROW:     next_state = WAIT_READ;
            WAIT_READ:    if (read_valid) next_state = PRE_COMPUTE;
            PRE_COMPUTE:  next_state = COMPUTE_MASK;
            COMPUTE_MASK: next_state = WRITE_MASKA;
            // The word indexed by maskA_count is issued this cycle; leave once
            // the last one has been issued.
            WRITE_MASKA:  next_state = (maskA_count < last_maskA_word) ? WRITE_MASKA : WRITE_MASKQ;
            WRITE_MASKQ:  next_state = (current_row == rows_size_reading) ? WRITE_MASKKV : READ_ROW;
            WRITE_MASKKV: next_state = (current_col == cols_size_reading) ? DONE : WRITE_MASKKV;
            DONE:         next_state = DONE;
            default:      next_state = IDLE;
        endcase
    end

    // FSM datapath / registered outputs
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            state                <= IDLE;
            current_row          <= 0;
            current_col          <= 0;
            maskA_count          <= 0;
            valid_result         <= 1'b0;
            write_enable         <= 1'b0;
            read_enable_full_row <= 1'b0;
            maskA                <= 0;
            maskQ                <= 1'b0;
            maskKV               <= 0;
            data_output          <= 0;
            row_addr_out_read    <= 0;
            col_addr_out_read    <= 0;
            row_addr_out_write   <= 0;
            col_addr_out_write   <= 0;
        end else begin
            state <= next_state;
            case (state)
                IDLE: begin
                    current_row <= 0;
                    current_col <= 0;
                    valid_result <= 1'b0;
                    maskA_count <= 0;
                    maskA <= 0;
                    maskKV <= 0;
                    data_output <= 0;
                end

                READ_ROW: begin
                    $display("[READ_ROW] current_row=%0d, row_addr_out_read=%0d, col_addr_out_read=%0d, time=%0t",
                             current_row, row_addr_out_read, col_addr_out_read, $time);
                    read_enable_full_row <= 1'b1;
                    row_addr_out_read <= current_row + rows_start_add_reading;
                    col_addr_out_read <= cols_start_add_reading;
                    write_enable <= 1'b0;
                    data_output <= 0;
                end

                WAIT_READ: begin
                    $display("[WAIT_READ] read_valid=%b, time=%0t", read_valid, $time);
                    if (read_valid) begin
                        // Constant bound keeps the loop synthesizable; the guard keeps
                        // only the requested columns and never indexes past A[].
                        for (i = 0; i < COLS_READING; i = i + 1)
                            if (i <= cols_size_reading)
                                A[i] <= data_input_rows[i*DATA_WIDTH +: DATA_WIDTH];
                        read_enable_full_row <= 1'b0;
                    end
                end

                PRE_COMPUTE: begin
                    $display("[ PRE_COMPUTE-TOPK] Finding top %0d", TOP_K);
                    maskA_count <= 0;
                    // Insertion into a descending list. Blocking assignments are
                    // intentional: topk_* must see their own updates within the loop.
                    // Empty slots hold value 0, so only strictly positive values enter.
                    for (j = 0; j < TOP_K; j = j + 1) begin
                        topk_values[j]  = 0;
                        topk_indices[j] = 0;
                    end
                    for (i = 0; i < COLS_READING; i = i + 1) begin
                        if (i <= cols_size_reading) begin
                            found = 1'b0;
                            for (j = 0; j < TOP_K; j = j + 1) begin
                                if (!found && A[i] > topk_values[j]) begin
                                    // Shift lower-ranked entries down one slot.
                                    for (k = TOP_K-1; k > j; k = k - 1) begin
                                        topk_values[k]  = topk_values[k-1];
                                        topk_indices[k] = topk_indices[k-1];
                                    end
                                    topk_values[j]  = A[i];
                                    topk_indices[j] = i;
                                    found = 1'b1; // emulates 'break'
                                end
                            end
                        end
                    end
                end

                COMPUTE_MASK: begin
                    $display("[COMPUTE_MASK-TOPK] Finding top %0d", TOP_K);
                    row_mask = {COLS_READING{1'b0}};
                    for (j = 0; j < TOP_K; j = j + 1) begin
                        // A slot holds a real selection only if it was filled, i.e. its
                        // value is nonzero; empty slots keep the placeholder index 0
                        // and must not mark column 0.
                        if (topk_values[j] != 0)
                            row_mask[topk_indices[j]] = 1'b1;
                        $display("Top %0d: index=%0d, value=%0d", j, topk_indices[j], topk_values[j]);
                    end
                    maskA  <= row_mask;
                    maskKV <= maskKV | row_mask;
                    maskQ  <= |row_mask;
                end

                WRITE_MASKA: begin
                    // Pack maskA bits [word_base, word_base+DATA_WIDTH) into one word,
                    // zero-filling past COLS_READING. word_base is an integer so the
                    // product is never truncated to the narrow counter width.
                    word_base   = maskA_count * DATA_WIDTH;
                    masked_data = {DATA_WIDTH{1'b0}};
                    for (b = 0; b < DATA_WIDTH; b = b + 1)
                        if (word_base + b < COLS_READING)
                            masked_data[b] = maskA[word_base + b];
                    $display("[WRITE_MASKA] row=%0d, col=%0d, maskA[%0d]=%b, data_output=%b, time=%0t",
                             row_addr_out_write, col_addr_out_write, maskA_count, maskA,
                             masked_data, $time);
                    row_addr_out_write <= current_row + rows_start_add_writing;
                    col_addr_out_write <= maskA_count + cols_start_add_writing;
                    write_enable <= 1'b1;
                    data_output  <= masked_data;
                    maskA_count  <= maskA_count + 1;
                end

                WRITE_MASKQ: begin
                    $display("[WRITE_MASKQ] maskQ=%b, data_output=%b, col_addr_out_write = %0d(%0d), time=%0t",
                             maskQ, {{(DATA_WIDTH-1){1'b0}}, maskQ}, col_addr_out_write, maskQ_col_offset, $time);
                    row_addr_out_write <= current_row + rows_start_add_writing;
                    col_addr_out_write <= cols_start_add_writing + maskQ_col_offset;
                    data_output  <= {{(DATA_WIDTH-1){1'b0}}, maskQ};
                    write_enable <= 1'b1;
                    current_row  <= current_row + 1;
                end

                WRITE_MASKKV: begin
                    $display("[WRITE_MASKKV] maskKV(%0d)=%b, data_output=%b, time=%0t",
                             current_col, maskKV, {{(DATA_WIDTH-1){1'b0}}, maskKV[current_col]}, $time);
                    // One bit per write row: row index = column index of maskKV.
                    row_addr_out_write <= current_col + rows_start_add_writing;
                    col_addr_out_write <= cols_start_add_writing + maskKV_col_offset;
                    data_output  <= {{(DATA_WIDTH-1){1'b0}}, maskKV[current_col]};
                    write_enable <= 1'b1;
                    current_col  <= current_col + 1;
                end

                DONE: begin
                    $display("[DONE] current_row=%0d, result valid, time=%0t", current_row, $time);
                    valid_result <= 1'b1;
                    write_enable <= 1'b0;
                    read_enable_full_row <= 1'b0;
                end

                default: ;
            endcase
        end
    end
endmodule