// Verilog source listing (256 lines, 9 KiB)
//-----------------------------------------------------------------------------
// matrix_multiplication_unit_new
//
// Control FSM wrapped around one `pe_array` instance. Per run it:
//   1. requests a full row from memory A   (LOAD_ROW_A / WAIT_VALID_ROW_A),
//   2. requests a full row from memory B   (LOAD_COL_B / WAIT_VALID_COL_B),
//   3. latches both rows into the PE array inputs (BEFORE_COMPUTE),
//   4. enables the PE array (COMPUTE) and either loops back to step 1 while
//      read_counter <= cols_size_reading_A, or moves on to WRITE once the
//      array reports acc_output_valid,
//   5. streams acc_outputs to the output memory port (WRITE),
//   6. raises `done` and parks in DONE.
//
// Reset is asynchronous, active high. All state is updated on posedge clk.
//
// NOTE(review): the full-row ports are declared as
//   [DATA_WIDTH*(max(MEM_ROWS,MEM_COLS)-1) : 0]
// which is DATA_WIDTH*(N-1)+1 bits wide, not DATA_WIDTH*N. The declaration is
// kept unchanged because it is part of the external interface; confirm
// against the memory that drives these ports.
//-----------------------------------------------------------------------------
module matrix_multiplication_unit_new #(
    parameter DATA_WIDTH     = 16,
    parameter MEM_ROWS       = 20, // 20 -> 5 bits  // 16
    parameter MEM_COLS       = 80, // 80 -> 7 bits  // 32SS
    parameter PE_ROWS        = 16,
    parameter PE_COLS        = 32,
    parameter COMMON_ROW_COL = 4,
    parameter OUTPUT_COL     = 5,
    parameter OUTPUT_ROW     = 4
)(
    input                                   clk,
    input                                   rst,    // asynchronous, active high
    input                                   enable, // leaves IDLE when high
    input      [1:0]                        mode,   // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary
    //input initialization;
    input      [DATA_WIDTH-1:0]             data_input_A, // not used by this module
    input      [DATA_WIDTH-1:0]             data_input_B, // not used by this module
    input      [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] full_row_A,
    input      [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] full_row_B,

    input                                   valid_mem_input_A, // full_row_A is valid
    input                                   valid_mem_input_B, // full_row_B is valid

    input      [$clog2(MEM_ROWS)-1:0]       rows_start_add_reading_A,
    input      [$clog2(MEM_COLS)-1:0]       cols_start_add_reading_A,
    input      [$clog2(MEM_ROWS)-1:0]       rows_start_add_reading_B,
    input      [$clog2(MEM_COLS)-1:0]       cols_start_add_reading_B,
    input      [$clog2(MEM_ROWS)-1:0]       rows_start_add_writing,
    input      [$clog2(MEM_COLS)-1:0]       cols_start_add_writing,

    input      [$clog2(MEM_ROWS)-1:0]       rows_size_reading_A,
    input      [$clog2(MEM_COLS)-1:0]       cols_size_reading_A,
    input      [$clog2(MEM_ROWS)-1:0]       rows_size_reading_B, // not used by this module
    input      [$clog2(MEM_COLS)-1:0]       cols_size_reading_B,

    output reg                              done,

    output reg                              read_full_row_A,
    output reg                              read_full_row_B,
    output reg                              write_full_row_out,

    output reg [$clog2(MEM_ROWS)-1:0]       row_addr_A,
    output reg [$clog2(MEM_COLS)-1:0]       col_addr_A,
    output reg [$clog2(MEM_ROWS)-1:0]       row_addr_B,
    output reg [$clog2(MEM_COLS)-1:0]       col_addr_B,
    output reg [$clog2(MEM_ROWS)-1:0]       row_addr_out,
    output reg [$clog2(MEM_COLS)-1:0]       col_addr_out,
    output reg                              read_enable_A,    // held at 0: no state drives it
    output reg                              read_enable_B,    // held at 0: no state drives it
    output reg                              write_enable_out, // only ever cleared; see NOTE in WRITE
    output reg [DATA_WIDTH*((MEM_ROWS>MEM_COLS)?MEM_ROWS-1:MEM_COLS-1):0] Full_row_out,
    output reg [DATA_WIDTH-1:0]             data_out          // held at 0: no state drives it
);

    // ------------------------------------------------------------------
    // FSM state encoding.
    // localparam: these values are internal and must not be overridden
    // from outside the module.
    // LOAD_ROW_A_COL_B and WAIT_VALID_ROW_A_COL_B have no case arm below;
    // if ever entered they fall through to the default arms.
    // ------------------------------------------------------------------
    localparam IDLE                   = 0,
               LOAD_ROW_A             = 1,
               WAIT_VALID_ROW_A       = 2,
               LOAD_COL_B             = 3,
               WAIT_VALID_COL_B       = 4,
               COMPUTE                = 5,
               WRITE                  = 6,
               DONE                   = 7,
               BEFORE_COMPUTE         = 8,
               LOAD_ROW_A_COL_B       = 9,
               WAIT_VALID_ROW_A_COL_B = 10;

    reg [3:0] current_state, next_state;

    // ------------------------------------------------------------------
    // PE array interface
    // ------------------------------------------------------------------
    reg                                  enable_pe_array;
    wire [OUTPUT_ROW*OUTPUT_COL-1:0]     valid_pe_array;
    reg  [DATA_WIDTH*OUTPUT_COL-1:0]     north_inputs;
    reg  [DATA_WIDTH*PE_ROWS-1:0]        west_inputs;
    wire [OUTPUT_ROW*DATA_WIDTH-1:0]     acc_outputs;

    reg  [9:0]                           compute_counter;            // cycles spent in COMPUTE
    reg  [9:0]                           write_counter, read_counter; // rows written / row pairs read

    reg                                  initialization; // only ever reset to 0 here
    reg                                  output_enable;
    wire                                 acc_output_valid_pe;

    // ------------------------------------------------------------------
    // PE array instance
    // ------------------------------------------------------------------
    pe_array #(
        .DATA_WIDTH     (DATA_WIDTH),
        .MEM_ROWS       (MEM_ROWS),      // 20 -> 5 bits  // 16
        .MEM_COLS       (MEM_COLS),      // 80 -> 7 bits  // 32SS
        .PE_ROWS        (PE_ROWS),
        .PE_COLS        (PE_COLS),
        .COMMON_ROW_COL (COMMON_ROW_COL),
        .OUTPUT_COL     (OUTPUT_COL),
        .OUTPUT_ROW     (OUTPUT_ROW)
    ) dut_pe_array (
        .clk              (clk),
        .rst              (rst),
        .mode             (mode),
        .initialization   (initialization),
        .north_inputs     (north_inputs),
        .west_inputs      (west_inputs),
        .output_enable    (output_enable),
        .enable           (enable_pe_array),
        .valid            (valid_pe_array),
        .rows_size_PE     (rows_size_reading_A),
        .cols_size_PE     (cols_size_reading_B),
        .acc_outputs      (acc_outputs),
        .acc_output_valid (acc_output_valid_pe)
    );

    // ------------------------------------------------------------------
    // FSM state register
    // ------------------------------------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst)
            current_state <= IDLE;
        else
            current_state <= next_state;
    end

    // ------------------------------------------------------------------
    // FSM next-state logic (combinational)
    // ------------------------------------------------------------------
    always @(*) begin
        case (current_state)
            IDLE:             next_state = enable ? LOAD_ROW_A : IDLE;

            LOAD_ROW_A:       next_state = WAIT_VALID_ROW_A;

            WAIT_VALID_ROW_A: next_state = valid_mem_input_A ? LOAD_COL_B
                                                             : WAIT_VALID_ROW_A;

            LOAD_COL_B:       next_state = WAIT_VALID_COL_B;

            WAIT_VALID_COL_B: next_state = valid_mem_input_B ? BEFORE_COMPUTE
                                                             : WAIT_VALID_COL_B; //COMPUTE

            BEFORE_COMPUTE:   next_state = COMPUTE;

            // Keep fetching row pairs while read_counter <= cols_size_reading_A;
            // afterwards stay in COMPUTE until the PE array flags valid output.
            // NOTE(review): with "<=" the loop fetches cols_size_reading_A + 1
            // row pairs (read_counter is already incremented when this is
            // evaluated) — confirm whether the size input is "count" or
            // "count - 1".
            COMPUTE:          next_state = (read_counter <= cols_size_reading_A) ? LOAD_ROW_A
                                         : (acc_output_valid_pe)                 ? WRITE
                                                                                 : COMPUTE; // PE_ROWS+PE_COLS+4

            // Leave WRITE once write_counter reaches min(rows_size_reading_A,
            // cols_size_reading_B).
            WRITE:            next_state = (write_counter >= ((cols_size_reading_B > rows_size_reading_A)
                                                              ? rows_size_reading_A
                                                              : cols_size_reading_B)) ? DONE : WRITE;

            // NOTE(review): DONE is terminal; only rst returns the FSM to IDLE,
            // so a second run requires a reset. Left as in the original —
            // confirm this is intended.
            DONE:             next_state = DONE;

            default:          next_state = IDLE;
        endcase
    end

    // ------------------------------------------------------------------
    // FSM registered outputs
    // ------------------------------------------------------------------
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            read_full_row_A    <= 0;
            read_full_row_B    <= 0;
            enable_pe_array    <= 0;
            compute_counter    <= 0;
            write_counter      <= 0;
            done               <= 0;
            write_enable_out   <= 0;
            read_counter       <= 0;
            initialization     <= 0;
            output_enable      <= 0;
            write_full_row_out <= 0;
            Full_row_out       <= 0;

            // The registers below had no reset value (and the last three are
            // never assigned by any state), so they powered up as X on the
            // memory interface and on the PE array inputs. Give them a
            // defined value.
            row_addr_A         <= 0;
            col_addr_A         <= 0;
            row_addr_B         <= 0;
            col_addr_B         <= 0;
            row_addr_out       <= 0;
            col_addr_out       <= 0;
            north_inputs       <= 0;
            west_inputs        <= 0;
            read_enable_A      <= 0;
            read_enable_B      <= 0;
            data_out           <= 0;
        end else begin
            case (current_state)
                IDLE: begin
                    read_full_row_A    <= 0;
                    read_full_row_B    <= 0;
                    enable_pe_array    <= 0;
                    compute_counter    <= 0;
                    write_counter      <= 0;
                    done               <= 0;
                    write_enable_out   <= 0;
                    read_counter       <= 0;
                    output_enable      <= 0; // cleared with the other control regs
                    write_full_row_out <= 0;
                    Full_row_out       <= 0;
                    $display("[IDLE] Waiting for enable...");
                end

                // Present the A address and request a full row; the PE array is
                // paused while new operands are fetched.
                LOAD_ROW_A: begin
                    row_addr_A      <= rows_start_add_reading_A;
                    col_addr_A      <= cols_start_add_reading_A + read_counter;
                    read_full_row_A <= 1;
                    enable_pe_array <= 0;
                    $display("[LOAD_ROW_A] Reading full row A.");
                end

                WAIT_VALID_ROW_A: begin
                    if (valid_mem_input_A) begin
                        read_full_row_A <= 0;
                        $display("[WAIT_VALID_ROW_A[x][%0d]] Row A received(read = %0h) = %0h", col_addr_A, read_full_row_A, full_row_A);
                    end
                end

                // Present the B address and request a full row.
                LOAD_COL_B: begin
                    row_addr_B      <= rows_start_add_reading_B;
                    col_addr_B      <= cols_start_add_reading_B + read_counter;
                    read_full_row_B <= 1;
                    $display("[LOAD_COL_B] Reading full row B (transpose col).");
                end

                WAIT_VALID_COL_B: begin
                    if (valid_mem_input_B) begin
                        read_full_row_B <= 0;
                        $display("[WAIT_VALID_COL_B] Row B received(read = %0h) = %0h", read_full_row_B, full_row_B);
                        read_counter    <= read_counter + 1; // one A/B row pair fetched
                    end
                end

                // Latch the fetched rows into the PE array inputs.
                BEFORE_COMPUTE: begin
                    north_inputs <= full_row_B[OUTPUT_COL*DATA_WIDTH-1:0];
                    // Implicit width adaptation: full_row_A is truncated (or
                    // zero-extended) to DATA_WIDTH*PE_ROWS bits.
                    west_inputs  <= full_row_A;
                    $display("[BEFORE_COMPUTE]north=%0h, west=%0h ", full_row_B[OUTPUT_COL*DATA_WIDTH-1:0], full_row_A );
                end

                COMPUTE: begin
                    enable_pe_array <= 1;
                    compute_counter <= compute_counter + 1;
                    $display("[COMPUTE] Cycle %0d / %0d", compute_counter, rows_size_reading_A+cols_size_reading_B+cols_size_reading_A+4);
                    // Raise output_enable once more than cols_size_reading_A
                    // COMPUTE cycles have elapsed; it stays set afterwards.
                    if (compute_counter >= cols_size_reading_A+1) begin
                        output_enable <= 1;
                        $display("output_enable = %h", output_enable);
                    end
                end

                // Stream acc_outputs to the output port. The address advances
                // along columns when rows_size_reading_A >= cols_size_reading_B,
                // otherwise along rows.
                // NOTE(review): only write_full_row_out is asserted here;
                // write_enable_out is never set to 1 anywhere in this module.
                WRITE: begin
                    enable_pe_array <= 1;

                    if ((rows_size_reading_A >= cols_size_reading_B)) begin
                        row_addr_out <= rows_start_add_writing;
                        col_addr_out <= (write_counter) + cols_start_add_writing;
                        $display("[WRITE] Writing output[%0d][%0d] = %0h | Valid = %b",
                                 (write_counter), 0, acc_outputs, valid_pe_array[write_counter]);
                    end else begin
                        row_addr_out <= (write_counter) + rows_start_add_writing; //original
                        col_addr_out <= cols_start_add_writing;
                        $display("[WRITE] Writing output[%0d][%0d] = %0h | Valid = %b",
                                 0, (write_counter), acc_outputs, valid_pe_array[write_counter]);
                    end

                    Full_row_out       <= acc_outputs; // zero-extended to the full-row width
                    write_full_row_out <= 1;
                    write_counter      <= write_counter + 1;
                end

                DONE: begin
                    done               <= 1;
                    enable_pe_array    <= 0;
                    compute_counter    <= 0;
                    write_full_row_out <= 0;
                    $display("[DONE] Matrix multiplication completed.");
                end

                default: begin
                    enable_pe_array  <= 0;
                    write_enable_out <= 0;
                end
            endcase
        end
    end

endmodule
// end of matrix_multiplication_unit_new