AYAKA_Transformer/rtl/RPAS_unit.v

233 lines
7.5 KiB
Verilog

// -----------------------------------------------------------------------------
// top_module_rpas
//
// Control FSM that, for each control vector produced by random_vector_generator,
// walks the input rows one at a time, feeds each row together with the vector
// into sparse_vector_accumulator, and writes the accumulator's sum to an output
// memory location selected by (row, vector index).
//
// Sequencing (see the state machine below):
//   IDLE -> GEN_VECTOR -> WAIT_VECTOR_VALID -> { READ_ROW_START -> READ_ROW_WAIT
//   -> ACCUMULATE -> NEXT_ROW }* -> NEXT_VECTOR -> (GEN_VECTOR | DONE)
//
// Loop bounds as implemented:
//   * rows 0 .. rows_size_reading are processed (rows_size_reading is the index
//     of the last row, i.e. inclusive);
//   * vectors 0 .. max(cols_size_reading >> 2, MIN_THRESHOLD) are generated
//     (also inclusive).
// valid_result is raised once the last vector has been processed and stays
// high in DONE until reset.
// -----------------------------------------------------------------------------
module top_module_rpas #(
parameter ROWS_READING = 16,//16
parameter COLS_READING = 12,//12
parameter ROWS_WRITING = 16,//16
parameter COLS_WRITING = 32,//32
parameter DATA_WIDTH = 16,
parameter MIN_THRESHOLD = 3,
parameter integer CTRL_WIDTH = 2 // {-1, 0, 1} as 2's complement
)(
input clk,
input rst, // asynchronous, active-high
input enable_rpas, // starts the sequence when sampled high in IDLE
input read_valid, // memory read data on data_input_rows is valid
input [$clog2(ROWS_READING)-1:0] rows_start_add_reading, // offset for reading from memory
input [$clog2(COLS_READING)-1:0] cols_start_add_reading,
input [$clog2(ROWS_WRITING)-1:0] rows_start_add_writing, // offset for writing to memory
input [$clog2(COLS_WRITING)-1:0] cols_start_add_writing,
input [$clog2(ROWS_READING)-1:0] rows_size_reading, // index of the last row to read (inclusive)
input [$clog2(COLS_READING)-1:0] cols_size_reading, // index of the last column (inclusive); also sets the vector count
input [$clog2(ROWS_WRITING)-1:0] rows_size_writing, // not used by this module
input [$clog2(COLS_WRITING)-1:0] cols_size_writing, // not used: derived from the input size
// One full row: DATA_WIDTH-bit elements packed LSB-first. Sized to hold
// max(ROWS_READING, COLS_READING) elements; the previous declaration
// [DATA_WIDTH*(MAX-1):0] was one element (minus one bit) too narrow.
input [DATA_WIDTH*((ROWS_READING>COLS_READING)?ROWS_READING:COLS_READING)-1:0] data_input_rows,
output reg [DATA_WIDTH - 1:0] data_output,
output reg write_enable,
output reg [$clog2(ROWS_READING)-1:0] row_addr_out_read, //manipulate between reading and writing
output reg [$clog2(COLS_READING)-1:0] col_addr_out_read,
output reg [$clog2(ROWS_WRITING)-1:0] row_addr_out_write, //manipulate between reading and writing
output reg [$clog2(COLS_WRITING)-1:0] col_addr_out_write,
output reg read_enable_full_row,
output reg valid_result
);
// Number of elements carried by the flat row / vector buses.
localparam MAX_DIM = (ROWS_READING > COLS_READING) ? ROWS_READING : COLS_READING;

// === FSM States ===
// localparam: encodings are internal and must not be overridable per instance.
localparam [3:0] IDLE              = 4'd0,
                 GEN_VECTOR        = 4'd1,
                 WAIT_VECTOR_VALID = 4'd2,
                 READ_ROW_START    = 4'd3,
                 READ_ROW_WAIT     = 4'd4,
                 ACCUMULATE        = 4'd5,
                 NEXT_VECTOR       = 4'd6,
                 NEXT_ROW          = 4'd7,
                 DONE              = 4'd8;

reg [3:0] state, next_state;
reg [$clog2(ROWS_READING)-1:0] current_row;
reg [$clog2(COLS_WRITING)-1:0] current_vec;
reg [$clog2(COLS_READING)-1:0] current_col; // held at 0 by this FSM
wire signed [DATA_WIDTH-1:0] sum_out;
reg valid_in_acc, valid_in_vector_gen;
wire valid_out_vector_gen, valid_out_acc;
// One control element of CTRL_WIDTH bits per row element (same fix as above:
// previously [CTRL_WIDTH*(MAX-1):0], too narrow by one element minus one bit).
wire [CTRL_WIDTH*MAX_DIM-1:0] rand_vec_flat;
integer i;

// Index of the last vector to generate: max(cols_size_reading >> 2, MIN_THRESHOLD).
// Shared by the next-state logic and the counter update so they cannot diverge.
wire [31:0] last_vec_idx = ((cols_size_reading >> 2) > MIN_THRESHOLD) ?
                           (cols_size_reading >> 2) : MIN_THRESHOLD;
wire more_vectors = (current_vec < last_vec_idx);

// Instantiate random vector generator.
// Its element width follows CTRL_WIDTH so it always matches rand_vec_flat.
random_vector_generator #(
.S(2),
.NO_ROWS(ROWS_READING),
.NO_COLS(COLS_READING),
.DATA_WIDTH(CTRL_WIDTH)
) rand_gen_inst (
.clk(clk),
.rst(rst),
.enable(valid_in_vector_gen),
.cols_size_reading(cols_size_reading),
.rand_vector_flat(rand_vec_flat),
.valid(valid_out_vector_gen)
);

// Instantiate accumulator.
// Widths follow this module's parameters instead of being hard-coded.
sparse_vector_accumulator #(
.NO_ROWS(ROWS_READING),
.NO_COLS(COLS_READING),
.DATA_WIDTH(DATA_WIDTH),
.CTRL_WIDTH(CTRL_WIDTH)
) adder_inst (
.clk(clk),
.rst(rst),
.valid_in(valid_in_acc),
.data_row_flat(data_input_rows),
.cols_size_reading(cols_size_reading),
.rand_vector_flat(rand_vec_flat),
.sum_out(sum_out),
.valid_out(valid_out_acc)
);

// === FSM State Transition ===
always @(*) begin
next_state = state;
case (state)
IDLE: begin
if (enable_rpas) next_state = GEN_VECTOR;
end
GEN_VECTOR: next_state = WAIT_VECTOR_VALID;
WAIT_VECTOR_VALID: begin
if (valid_out_vector_gen) next_state = READ_ROW_START;
end
READ_ROW_START: next_state = READ_ROW_WAIT;
READ_ROW_WAIT: begin
if (read_valid) next_state = ACCUMULATE;
end
ACCUMULATE: begin
// Stay until the accumulator reports a result; then either move to the
// next row or, after the last row (inclusive bound), to the next vector.
if (valid_out_acc) begin
if (current_row == rows_size_reading)
next_state = NEXT_VECTOR;
else
next_state = NEXT_ROW;
end
end
NEXT_ROW: next_state = READ_ROW_START;
NEXT_VECTOR: begin
if (more_vectors)
next_state = GEN_VECTOR;
else
next_state = DONE;
end
DONE: next_state = DONE;
// Unused encodings of the 4-bit state register recover to IDLE.
default: next_state = IDLE;
endcase
end

// === FSM Outputs and Registers ===
always @(posedge clk or posedge rst) begin
if (rst) begin
state <= IDLE;
current_row <= 0;
current_col <= 0;
current_vec <= 0;
valid_result <= 0;
read_enable_full_row <= 0;
valid_in_vector_gen <= 0;
valid_in_acc <= 0;
row_addr_out_read <= 0;
col_addr_out_read <= 0;
// Write-side outputs were previously left unreset (X until first write).
write_enable <= 0;
data_output <= 0;
row_addr_out_write <= 0;
col_addr_out_write <= 0;
end else begin
state <= next_state;
case (state)
IDLE: begin
current_row <= 0;
current_vec <= 0;
current_col <= 0;
valid_result <= 0;
end
GEN_VECTOR: begin
// Request a vector; held until the generator answers.
valid_in_vector_gen <= 1;
end
WAIT_VECTOR_VALID: begin
if (valid_out_vector_gen)
valid_in_vector_gen <= 0;
end
READ_ROW_START: begin
// Issue a full-row read at the offset address of the current row.
read_enable_full_row <= 1;
row_addr_out_read <= current_row + rows_start_add_reading;
col_addr_out_read <= current_col + cols_start_add_reading;
end
READ_ROW_WAIT: ; // next-state logic waits for read_valid; nothing to register
ACCUMULATE: begin
valid_in_acc <= 1;
if (valid_out_acc) begin
valid_in_acc <= 0; // last non-blocking assignment wins: drop the request
`ifndef SYNTHESIS
// Simulation-only trace; the loop bound is not a constant, so keep it
// out of synthesis.
$display("==================START====================");
for (i = 0; i <= cols_size_reading; i = i + 1) begin
$display("col[%0d]: %0d | Vector[%0d]: %0d", i,
data_input_rows[i*DATA_WIDTH +: DATA_WIDTH],
i,
rand_vec_flat[i*CTRL_WIDTH +: CTRL_WIDTH]);
end
$display("---- Result (Row %0d, column %0d) ----", current_row, current_vec);
$display("Sum out: %0X", sum_out);
$display("===================END=====================");
`endif
// Result for (row, vector) goes to (row, vector) of the output memory.
row_addr_out_write <= current_row + rows_start_add_writing;
col_addr_out_write <= current_vec + cols_start_add_writing;
write_enable <= 1;
data_output <= sum_out[DATA_WIDTH - 1:0];
end
read_enable_full_row <= 0;
end
NEXT_ROW: begin
current_row <= current_row + 1;
valid_result <= 0;
write_enable <= 0;
end
NEXT_VECTOR: begin
write_enable <= 0;
if (more_vectors) begin
current_vec <= current_vec + 1;
current_row <= 0;
current_col <= 0;
valid_result <= 0;
end else begin
valid_result <= 1;
end
end
DONE: begin
write_enable <= 0;
valid_result <= 1;
end
default: ; // unreachable encodings; next-state logic returns to IDLE
endcase
end
end
endmodule