// AYAKA_Transformer/rtl/hdpe_unit_new.v
//
// 256 lines
// 9 KiB
// Verilog

// -----------------------------------------------------------------------------
// matrix_multiplication_unit_new
//
// Control FSM around a `pe_array` instance:
//   - fetches one full row of A and one full row of B from external memory
//     (B's row is used as a transposed column),
//   - presents those rows to the PE array edge inputs,
//   - runs the array,
//   - streams the accumulated outputs back out through
//     Full_row_out / row_addr_out / col_addr_out.
//
// State sequence (see the next-state logic):
//   1. IDLE --enable--> LOAD_ROW_A -> WAIT_VALID_ROW_A.
//   2. WAIT_VALID_ROW_A --valid_mem_input_A--> LOAD_COL_B -> WAIT_VALID_COL_B.
//   3. WAIT_VALID_COL_B --valid_mem_input_B--> BEFORE_COMPUTE -> COMPUTE.
//   4. COMPUTE returns to LOAD_ROW_A while read_counter <= cols_size_reading_A.
//      Otherwise it waits for acc_output_valid_pe and then moves to WRITE.
//   5. WRITE repeats until write_counter reaches
//      min(rows_size_reading_A, cols_size_reading_B), then goes to DONE.
//   6. DONE is terminal: only rst returns the FSM to IDLE.
//
// NOTE(review): full_row_A, full_row_B and Full_row_out are declared as
//   [DATA_WIDTH*(max(MEM_ROWS,MEM_COLS)-1):0].
//   That is DATA_WIDTH*(max-1)+1 bits, not DATA_WIDTH*max bits. The widths
//   are kept unchanged here because connected modules presumably use the same
//   expression; confirm before widening.
// NOTE(review): data_input_A and data_input_B are not read by this module.
// -----------------------------------------------------------------------------
module matrix_multiplication_unit_new #(
    parameter DATA_WIDTH     = 16,
    parameter MEM_ROWS       = 20,  // 20 -> 5 bits //16
    parameter MEM_COLS       = 80,  // 80 -> 7 bits //32SS
    parameter PE_ROWS        = 16,
    parameter PE_COLS        = 32,
    parameter COMMON_ROW_COL = 4,
    parameter OUTPUT_COL     = 5,
    parameter OUTPUT_ROW     = 4
)(
    input  clk,
    input  rst,     // asynchronous, active-high reset
    input  enable,  // leaves IDLE when high
    input  [1:0] mode, // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary (forwarded to pe_array)
    input  [DATA_WIDTH-1:0] data_input_A,
    input  [DATA_WIDTH-1:0] data_input_B,
    input  [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] full_row_A, // row data returned for a read_full_row_A request
    input  [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] full_row_B, // row data returned for a read_full_row_B request
    input  valid_mem_input_A, // full_row_A is valid
    input  valid_mem_input_B, // full_row_B is valid
    input  [$clog2(MEM_ROWS)-1:0] rows_start_add_reading_A,
    input  [$clog2(MEM_COLS)-1:0] cols_start_add_reading_A,
    input  [$clog2(MEM_ROWS)-1:0] rows_start_add_reading_B,
    input  [$clog2(MEM_COLS)-1:0] cols_start_add_reading_B,
    input  [$clog2(MEM_ROWS)-1:0] rows_start_add_writing,
    input  [$clog2(MEM_COLS)-1:0] cols_start_add_writing,
    input  [$clog2(MEM_ROWS)-1:0] rows_size_reading_A,
    input  [$clog2(MEM_COLS)-1:0] cols_size_reading_A,
    input  [$clog2(MEM_ROWS)-1:0] rows_size_reading_B,
    input  [$clog2(MEM_COLS)-1:0] cols_size_reading_B,
    output reg done,               // set in DONE and held until rst
    output reg read_full_row_A,    // row-read request for A (raised in LOAD_ROW_A)
    output reg read_full_row_B,    // row-read request for B (raised in LOAD_COL_B)
    output reg write_full_row_out, // Full_row_out strobe (high during WRITE)
    output reg [$clog2(MEM_ROWS)-1:0] row_addr_A,
    output reg [$clog2(MEM_COLS)-1:0] col_addr_A,
    output reg [$clog2(MEM_ROWS)-1:0] row_addr_B,
    output reg [$clog2(MEM_COLS)-1:0] col_addr_B,
    output reg [$clog2(MEM_ROWS)-1:0] row_addr_out,
    output reg [$clog2(MEM_COLS)-1:0] col_addr_out,
    output reg read_enable_A,    // NOTE(review): not driven by the FSM; held at 0 after reset
    output reg read_enable_B,    // NOTE(review): not driven by the FSM; held at 0 after reset
    output reg write_enable_out, // only ever driven to 0
    output reg [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] Full_row_out,
    output reg [DATA_WIDTH-1:0] data_out // NOTE(review): not driven by the FSM; held at 0 after reset
);

    // FSM state encoding.
    // LOAD_ROW_A_COL_B and WAIT_VALID_ROW_A_COL_B are reserved: the next-state
    // logic never selects them.
    localparam [3:0] IDLE                   = 4'd0,
                     LOAD_ROW_A             = 4'd1,
                     WAIT_VALID_ROW_A       = 4'd2,
                     LOAD_COL_B             = 4'd3,
                     WAIT_VALID_COL_B       = 4'd4,
                     COMPUTE                = 4'd5,
                     WRITE                  = 4'd6,
                     DONE                   = 4'd7,
                     BEFORE_COMPUTE         = 4'd8,
                     LOAD_ROW_A_COL_B       = 4'd9,
                     WAIT_VALID_ROW_A_COL_B = 4'd10;

    reg [3:0] current_state, next_state;

    // PE array interface
    reg                              enable_pe_array;
    wire [OUTPUT_ROW*OUTPUT_COL-1:0] valid_pe_array;
    reg  [DATA_WIDTH*OUTPUT_COL-1:0] north_inputs;   // low OUTPUT_COL words of the B row
    reg  [DATA_WIDTH*PE_ROWS-1:0]    west_inputs;    // low PE_ROWS words of the A row (truncating assignment)
    wire [OUTPUT_ROW*DATA_WIDTH-1:0] acc_outputs;
    wire                             acc_output_valid_pe;
    reg                              initialization; // never set by this FSM; stays 0
    reg                              output_enable;  // raised in COMPUTE, cleared by reset / IDLE

    // Counters
    reg [9:0] compute_counter; // cycles spent in COMPUTE (not cleared between passes)
    reg [9:0] read_counter;    // A/B row pairs fetched; offsets the read column address
    reg [9:0] write_counter;   // WRITE cycles performed; offsets the write address

    pe_array #(
        .DATA_WIDTH(DATA_WIDTH),
        .MEM_ROWS(MEM_ROWS),
        .MEM_COLS(MEM_COLS),
        .PE_ROWS(PE_ROWS),
        .PE_COLS(PE_COLS),
        .COMMON_ROW_COL(COMMON_ROW_COL),
        .OUTPUT_COL(OUTPUT_COL),
        .OUTPUT_ROW(OUTPUT_ROW)
    ) dut_pe_array (
        .clk(clk),
        .rst(rst),
        .mode(mode),
        .initialization(initialization),
        .north_inputs(north_inputs),
        .west_inputs(west_inputs),
        .output_enable(output_enable),
        .enable(enable_pe_array),
        .valid(valid_pe_array),
        .rows_size_PE(rows_size_reading_A),
        .cols_size_PE(cols_size_reading_B),
        .acc_outputs(acc_outputs),
        .acc_output_valid(acc_output_valid_pe)
    );

    // State register
    always @(posedge clk or posedge rst) begin
        if (rst)
            current_state <= IDLE;
        else
            current_state <= next_state;
    end

    // Next-state logic
    always @(*) begin
        case (current_state)
            IDLE:             next_state = enable            ? LOAD_ROW_A     : IDLE;
            LOAD_ROW_A:       next_state = WAIT_VALID_ROW_A;
            WAIT_VALID_ROW_A: next_state = valid_mem_input_A ? LOAD_COL_B     : WAIT_VALID_ROW_A;
            LOAD_COL_B:       next_state = WAIT_VALID_COL_B;
            WAIT_VALID_COL_B: next_state = valid_mem_input_B ? BEFORE_COMPUTE : WAIT_VALID_COL_B;
            BEFORE_COMPUTE:   next_state = COMPUTE;
            COMPUTE: begin
                // Fetch another A/B row pair until read_counter exceeds
                // cols_size_reading_A, then wait for the array's result.
                if (read_counter <= cols_size_reading_A)
                    next_state = LOAD_ROW_A;
                else if (acc_output_valid_pe)
                    next_state = WRITE;
                else
                    next_state = COMPUTE;
            end
            // Write min(rows_size_reading_A, cols_size_reading_B) outputs.
            WRITE:   next_state = (write_counter >= ((cols_size_reading_B>rows_size_reading_A)?rows_size_reading_A:cols_size_reading_B)) ? DONE : WRITE;
            DONE:    next_state = DONE;
            default: next_state = IDLE;
        endcase
    end

    // Registered outputs and datapath
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            read_full_row_A    <= 0;
            read_full_row_B    <= 0;
            enable_pe_array    <= 0;
            compute_counter    <= 0;
            write_counter      <= 0;
            done               <= 0;
            write_enable_out   <= 0;
            read_counter       <= 0;
            initialization     <= 0;
            output_enable      <= 0;
            write_full_row_out <= 0;
            Full_row_out       <= 0;
            // Previously left un-reset (X until first write):
            row_addr_A         <= 0;
            col_addr_A         <= 0;
            row_addr_B         <= 0;
            col_addr_B         <= 0;
            row_addr_out       <= 0;
            col_addr_out       <= 0;
            north_inputs       <= 0;
            west_inputs        <= 0;
            read_enable_A      <= 0;
            read_enable_B      <= 0;
            data_out           <= 0;
        end else begin
            case (current_state)
                IDLE: begin
                    read_full_row_A    <= 0;
                    read_full_row_B    <= 0;
                    enable_pe_array    <= 0;
                    compute_counter    <= 0;
                    write_counter      <= 0;
                    done               <= 0;
                    write_enable_out   <= 0;
                    read_counter       <= 0;
                    write_full_row_out <= 0;
                    Full_row_out       <= 0;
                    output_enable      <= 0;
                    $display("[IDLE] Waiting for enable...");
                end

                LOAD_ROW_A: begin
                    // Request a full A row.
                    // The column offset advances with each completed fetch.
                    row_addr_A      <= rows_start_add_reading_A;
                    col_addr_A      <= cols_start_add_reading_A + read_counter;
                    read_full_row_A <= 1;
                    enable_pe_array <= 0;
                    $display("[LOAD_ROW_A] Reading full row A.");
                end

                WAIT_VALID_ROW_A: begin
                    if (valid_mem_input_A) begin
                        read_full_row_A <= 0; // drop the request once data is valid
                        $display("[WAIT_VALID_ROW_A[x][%0d]] Row A received(read = %0h) = %0h", col_addr_A, read_full_row_A, full_row_A);
                    end
                end

                LOAD_COL_B: begin
                    row_addr_B      <= rows_start_add_reading_B;
                    col_addr_B      <= cols_start_add_reading_B + read_counter;
                    read_full_row_B <= 1;
                    $display("[LOAD_COL_B] Reading full row B (transpose col).");
                end

                WAIT_VALID_COL_B: begin
                    if (valid_mem_input_B) begin
                        read_full_row_B <= 0;
                        $display("[WAIT_VALID_COL_B] Row B received(read = %0h) = %0h", read_full_row_B, full_row_B);
                        // One A/B row pair is now complete.
                        read_counter <= read_counter + 1;
                    end
                end

                BEFORE_COMPUTE: begin
                    // Latch the fetched rows onto the PE array edges.
                    north_inputs <= full_row_B[OUTPUT_COL*DATA_WIDTH-1:0];
                    west_inputs  <= full_row_A;
                    $display("[BEFORE_COMPUTE]north=%0h, west=%0h ", full_row_B[OUTPUT_COL*DATA_WIDTH-1:0], full_row_A );
                end

                COMPUTE: begin
                    enable_pe_array <= 1;
                    compute_counter <= compute_counter + 1;
                    $display("[COMPUTE] Cycle %0d / %0d", compute_counter, rows_size_reading_A+cols_size_reading_B+cols_size_reading_A+4);
                    if (compute_counter >= cols_size_reading_A+1) begin
                        output_enable <= 1;
                        // Nonblocking update: this prints the value from before this edge.
                        $display("output_enable = %h", output_enable);
                    end
                end

                WRITE: begin
                    enable_pe_array <= 1;
                    // Step along whichever output dimension is the smaller one;
                    // the other address stays at its start value.
                    if (rows_size_reading_A >= cols_size_reading_B) begin
                        row_addr_out <= rows_start_add_writing;
                        col_addr_out <= (write_counter ) + cols_start_add_writing;
                        $display("[WRITE] Writing output[%0d][%0d] = %0h | Valid = %b",
                                 (write_counter), 0, acc_outputs, valid_pe_array[write_counter]);
                    end else begin
                        row_addr_out <= (write_counter ) + rows_start_add_writing;
                        col_addr_out <= cols_start_add_writing;
                        $display("[WRITE] Writing output[%0d][%0d] = %0h | Valid = %b",
                                 0, (write_counter), acc_outputs, valid_pe_array[write_counter]);
                    end
                    Full_row_out       <= acc_outputs; // zero-extended into the wider bus
                    write_full_row_out <= 1;
                    write_counter      <= write_counter + 1;
                end

                DONE: begin
                    done               <= 1;
                    enable_pe_array    <= 0;
                    compute_counter    <= 0;
                    write_full_row_out <= 0;
                    $display("[DONE] Matrix multiplication completed.");
                end

                default: begin
                    enable_pe_array  <= 0;
                    write_enable_out <= 0;
                end
            endcase
        end
    end

endmodule