// Source: AYAKA_Transformer/rtl/PE_array_tb.v (Verilog; repository listing reported 301 lines, 12 KiB)

// Testbench for the `pe_array` processing-element array.
//
// Flow:
//   1. Reset the DUT and load two small operand matrices
//      (A: OUTPUT_ROW x COMMON_ROW_COL, B: COMMON_ROW_COL x OUTPUT_COL).
//   2. Compute the golden product C = A * B inside the testbench.
//   3. Stream A (west side) and B (north side) into the DUT using the
//      output-stationary feed schedule.
//   4. Print the DUT result and the golden matrix, then compare them
//      element by element and report PASS/FAIL.
//
// The weight-stationary and input-stationary feed schedules written by the
// original author are kept below, commented out, for reference.
module pe_array_tb;
    parameter DATA_WIDTH     = 8;
    parameter PE_ROWS        = 4;  // original note: 5 for the input-stationary setup
    parameter PE_COLS        = 5;  // original note: 4 for the input-stationary setup;
                                   // 5 columns here because a 4x5 output is expected
    parameter COMMON_ROW_COL = 4;  // shared (inner) dimension of A and B
    parameter OUTPUT_ROW     = 4;
    parameter OUTPUT_COL     = 5;

    // Dataflow mode encodings, as annotated by the original author next to
    // the `mode` assignment. NOTE(review): confirm against the pe_array RTL.
    localparam [1:0] MODE_OUTPUT_STATIONARY = 2'b00;
    localparam [1:0] MODE_INPUT_STATIONARY  = 2'b01;
    localparam [1:0] MODE_WEIGHT_STATIONARY = 2'b10;

    reg clk;
    reg rst;
    reg  [PE_COLS*DATA_WIDTH-1:0] north_inputs;  // one DATA_WIDTH lane per PE column
    reg  [PE_ROWS*DATA_WIDTH-1:0] west_inputs;   // one DATA_WIDTH lane per PE row
    reg                           enable;        // original note: enable signal for PE 0,0
    reg                           initialization;
    reg                           output_enable;
    reg  [1:0]                    mode;
    wire [OUTPUT_COL*OUTPUT_ROW*2*DATA_WIDTH-1:0] acc_outputs;  // accumulated outputs, 2*DATA_WIDTH each
    wire [OUTPUT_COL*OUTPUT_ROW-1:0]              valid;        // one valid bit per output element

    integer i, j, p;
    integer cycle_count;
    integer error_count;
    reg [2*DATA_WIDTH-1:0] dut_value;

    // Operand matrices and the golden result.
    reg [DATA_WIDTH-1:0]   matrix_A   [0:OUTPUT_ROW-1][0:COMMON_ROW_COL-1];
    reg [DATA_WIDTH-1:0]   matrix_B   [0:COMMON_ROW_COL-1][0:OUTPUT_COL-1];
    // Golden entries are sums of products, so they need the same
    // 2*DATA_WIDTH width as the DUT accumulators; a DATA_WIDTH-wide element
    // silently truncates anything above 2**DATA_WIDTH - 1.
    reg [2*DATA_WIDTH-1:0] expected_C [0:OUTPUT_ROW-1][0:OUTPUT_COL-1];

    pe_array #(
        .DATA_WIDTH(DATA_WIDTH),
        .PE_ROWS(PE_ROWS),
        .PE_COLS(PE_COLS),
        .COMMON_ROW_COL(COMMON_ROW_COL),
        .OUTPUT_COL(OUTPUT_COL),
        .OUTPUT_ROW(OUTPUT_ROW)
    ) dut (
        .clk(clk),
        .rst(rst),
        .north_inputs(north_inputs),
        .west_inputs(west_inputs),
        .mode(mode),
        .initialization(initialization),
        .enable(enable),
        .output_enable(output_enable),
        .valid(valid),
        .acc_outputs(acc_outputs)
    );

    // Clock generation: 10 time-unit period (100 MHz if the timescale is 1 ns).
    initial begin
        clk = 0;
        forever #5 clk = ~clk;
    end

    initial begin
        // Drive every DUT input to a known value while reset is asserted.
        // `mode` is set here as well so the DUT never sees X on it.
        rst            = 1;
        north_inputs   = 0;
        west_inputs    = 0;
        enable         = 0;
        initialization = 0;
        output_enable  = 0;
        cycle_count    = 0;
        error_count    = 0;
        mode           = MODE_OUTPUT_STATIONARY;
        // mode = MODE_INPUT_STATIONARY;
        // mode = MODE_WEIGHT_STATIONARY;
        #20;
        rst = 0;

        // Alternate data set (sequential values), kept for reference:
        // matrix_A[0][0] = 8'd1; matrix_A[0][1] = 8'd2; matrix_A[0][2] = 8'd3; matrix_A[0][3] = 8'd4;
        // matrix_A[1][0] = 8'd5; matrix_A[1][1] = 8'd6; matrix_A[1][2] = 8'd7; matrix_A[1][3] = 8'd8;
        // matrix_A[2][0] = 8'd9; matrix_A[2][1] = 8'd10; matrix_A[2][2] = 8'd11; matrix_A[2][3] = 8'd12;
        // matrix_A[3][0] = 8'd13; matrix_A[3][1] = 8'd14; matrix_A[3][2] = 8'd15; matrix_A[3][3] = 8'd16;
        // matrix_B[0][0] = 8'd1; matrix_B[0][1] = 8'd2; matrix_B[0][2] = 8'd3; matrix_B[0][3] = 8'd4; matrix_B[0][4] = 8'd5;
        // matrix_B[1][0] = 8'd6; matrix_B[1][1] = 8'd7; matrix_B[1][2] = 8'd8; matrix_B[1][3] = 8'd9; matrix_B[1][4] = 8'd10;
        // matrix_B[2][0] = 8'd11; matrix_B[2][1] = 8'd12; matrix_B[2][2] = 8'd13; matrix_B[2][3] = 8'd14; matrix_B[2][4] = 8'd15;
        // matrix_B[3][0] = 8'd16; matrix_B[3][1] = 8'd17; matrix_B[3][2] = 8'd18; matrix_B[3][3] = 8'd19; matrix_B[3][4] = 8'd20;

        // Active data set: matrix A (4x4) and matrix B (4x5).
        matrix_A[0][0] = 8'd1; matrix_A[0][1] = 8'd2; matrix_A[0][2] = 8'd3; matrix_A[0][3] = 8'd4;
        matrix_A[1][0] = 8'd4; matrix_A[1][1] = 8'd3; matrix_A[1][2] = 8'd2; matrix_A[1][3] = 8'd1;
        matrix_A[2][0] = 8'd1; matrix_A[2][1] = 8'd2; matrix_A[2][2] = 8'd3; matrix_A[2][3] = 8'd4;
        matrix_A[3][0] = 8'd4; matrix_A[3][1] = 8'd3; matrix_A[3][2] = 8'd2; matrix_A[3][3] = 8'd1;
        matrix_B[0][0] = 8'd1; matrix_B[0][1] = 8'd2; matrix_B[0][2] = 8'd3; matrix_B[0][3] = 8'd4; matrix_B[0][4] = 8'd5;
        matrix_B[1][0] = 8'd5; matrix_B[1][1] = 8'd4; matrix_B[1][2] = 8'd3; matrix_B[1][3] = 8'd2; matrix_B[1][4] = 8'd1;
        matrix_B[2][0] = 8'd1; matrix_B[2][1] = 8'd2; matrix_B[2][2] = 8'd3; matrix_B[2][3] = 8'd4; matrix_B[2][4] = 8'd5;
        matrix_B[3][0] = 8'd5; matrix_B[3][1] = 8'd4; matrix_B[3][2] = 8'd3; matrix_B[3][3] = 8'd2; matrix_B[3][4] = 8'd1;

        #10;

        // Golden model: C = A * B. This is plain testbench arithmetic, so it
        // consumes no simulation time. The inner index walks the shared
        // dimension COMMON_ROW_COL (not OUTPUT_ROW).
        for (i = 0; i < OUTPUT_ROW; i = i + 1) begin
            for (j = 0; j < OUTPUT_COL; j = j + 1) begin
                expected_C[i][j] = 0;
                for (p = 0; p < COMMON_ROW_COL; p = p + 1) begin
                    expected_C[i][j] = expected_C[i][j] + matrix_A[i][p] * matrix_B[p][j];
                end
            end
        end

        // Idle for a few clock periods after reset before feeding operands.
        // The total is a multiple of the clock period, so stimulus keeps
        // changing on falling clock edges.
        #100;

        /*
        // enable = 1;/////////////////////////
        ///weight_stationary working!!!!
        for (cycle_count = 0; cycle_count < (PE_ROWS + PE_COLS +2); cycle_count = cycle_count + 1) begin//-1
            north_inputs = 0;
            west_inputs = 0;
            $display("\n================== Cycle %0d feeding start ===================", cycle_count);
            for (j = 0; j < PE_COLS; j = j + 1) begin// PE_COLS
                if (cycle_count < PE_ROWS ) begin
                    // Valid region to fetch B matrix for north inputs
                    north_inputs[(j+1)*DATA_WIDTH-1 -: DATA_WIDTH] = matrix_B[(PE_ROWS-cycle_count-1)][j];
                    // north_inputs[(j+1)*DATA_WIDTH-1 -: DATA_WIDTH] = 8'b1;
                    initialization = 1;
                    $display(" NORTH: PE(0,%0d): B[%0d][%0d] = %0h", j, (PE_ROWS-cycle_count-1), j, matrix_B[(PE_ROWS-cycle_count-1)][j]);
                end else begin
                    initialization = 0;
                end
            end
            for (i = 0; i < PE_ROWS; i = i + 1) begin
                if ((cycle_count >= PE_COLS) && (cycle_count <= ( PE_COLS + PE_ROWS - 1)) ) begin
                    // Valid region to fetch A matrix for west inputs
                    west_inputs[(i+1)*DATA_WIDTH-1 -: DATA_WIDTH] = matrix_A[cycle_count - PE_COLS][i];
                    // west_inputs[(i+1)*DATA_WIDTH-1 -: DATA_WIDTH] = 8'b1;
                    enable = 1;
                    initialization = 0;
                    $display(" WEST: PE(%0d,0): A[%0d][%0d] = %0h", i, (cycle_count - PE_COLS), i, matrix_A[cycle_count - PE_COLS][i]);
                end
            end
            $display("west_inputs = %h", west_inputs);
            $display("north_inputs = %h", north_inputs);
            $display("================== Cycle %0d feeding end ===================\n", cycle_count);
            #10; // adjust based on PE array latency
        end
        */

        /*
        // enable = 1;/////////////////////////
        ///input_stationary working!!!!
        for (cycle_count = 0; cycle_count < (PE_ROWS + PE_COLS +2); cycle_count = cycle_count + 1) begin//-1
            north_inputs = 0;
            west_inputs = 0;
            $display("\n================== Cycle %0d feeding start ===================", cycle_count);
            for (i = 0; i < PE_ROWS; i = i + 1) begin
                if (cycle_count < COMMON_ROW_COL ) begin
                    // Valid region to fetch A matrix for west inputs
                    west_inputs[(i+1)*DATA_WIDTH-1 -: DATA_WIDTH] = matrix_A[i][(COMMON_ROW_COL-cycle_count-1)];
                    initialization = 1;
                    // west_inputs[(i+1)*DATA_WIDTH-1 -: DATA_WIDTH] = 8'b1;
                    $display(" WEST: PE(%0d,0): A[%0d][%0d] = %0h", i, i, (COMMON_ROW_COL-cycle_count-1), matrix_A[i][(COMMON_ROW_COL-cycle_count-1)]);
                end else begin
                    initialization = 0;
                end
            end
            for (j = 0; j < PE_COLS; j = j + 1) begin// PE_COLS
                if ((cycle_count >= COMMON_ROW_COL) && (cycle_count <= ( COMMON_ROW_COL + OUTPUT_COL - 1)) ) begin
                    enable = 1;
                    initialization = 0;
                    // Valid region to fetch B matrix for north inputs
                    north_inputs[(j+1)*DATA_WIDTH-1 -: DATA_WIDTH] = matrix_B[j][cycle_count- COMMON_ROW_COL];
                    // north_inputs[(j+1)*DATA_WIDTH-1 -: DATA_WIDTH] = 8'b1;
                    $display(" NORTH: PE(0,%0d): B[%0d][%0d] = %0h", j, j, cycle_count -COMMON_ROW_COL, matrix_B[j][cycle_count-COMMON_ROW_COL]);
                end
            end
            $display("west_inputs = %h", west_inputs);
            $display("north_inputs = %h", north_inputs);
            $display("================== Cycle %0d feeding end ===================\n", cycle_count);
            #10; // adjust based on PE array latency
        end
        */

        // Output-stationary feed (active schedule).
        // For the first COMMON_ROW_COL cycles, column `cycle_count` of A is
        // driven on the west lanes and row `cycle_count` of B on the north
        // lanes; afterwards the inputs are held at zero and output_enable is
        // raised. One iteration lasts one clock period.
        for (cycle_count = 0; cycle_count < (PE_ROWS + PE_COLS + 4); cycle_count = cycle_count + 1) begin
            enable       = 1;
            north_inputs = 0;
            west_inputs  = 0;
            $display("\n================== Cycle %0d feeding start ===================", cycle_count);
            for (i = 0; i < PE_ROWS; i = i + 1) begin
                if (cycle_count < COMMON_ROW_COL) begin
                    // Valid region to fetch A matrix for west inputs
                    west_inputs[(i+1)*DATA_WIDTH-1 -: DATA_WIDTH] = matrix_A[i][cycle_count];
                    $display(" WEST: PE(%0d,0): A[%0d][%0d] = %0h", i, i, cycle_count, matrix_A[i][cycle_count]);
                end
            end
            for (j = 0; j < PE_COLS; j = j + 1) begin
                if (cycle_count < COMMON_ROW_COL) begin
                    // Valid region to fetch B matrix for north inputs
                    north_inputs[(j+1)*DATA_WIDTH-1 -: DATA_WIDTH] = matrix_B[cycle_count][j];
                    $display(" NORTH: PE(0,%0d): B[%0d][%0d] = %0h", j, cycle_count, j, matrix_B[cycle_count][j]);
                end
            end
            if (cycle_count >= COMMON_ROW_COL) begin
                // Blocking assignment, consistent with the rest of the
                // stimulus, so the $display below reports the value that was
                // just driven rather than the previous one.
                output_enable = 1;
                $display("output_enable = %h", output_enable);
            end
            $display("west_inputs = %h", west_inputs);
            $display("north_inputs = %h", north_inputs);
            $display("================== Cycle %0d feeding end ===================\n", cycle_count);
            #10; // one clock period; adjust based on PE array latency
        end

        // Wait for operations to complete
        #100;

        // Display the DUT result; element (i, j) is read from slot
        // i*OUTPUT_COL + j of acc_outputs.
        $display("Result matrix:");
        for (i = 0; i < OUTPUT_ROW; i = i + 1) begin
            for (j = 0; j < OUTPUT_COL; j = j + 1) begin
                $write("%0d ", acc_outputs[((i*OUTPUT_COL + j) + 1)*2*DATA_WIDTH - 1 -: 2*DATA_WIDTH]);
            end
            $display("");
        end

        // Display expected result matrix
        $display("\nExpected matrix:");
        for (i = 0; i < OUTPUT_ROW; i = i + 1) begin
            for (j = 0; j < OUTPUT_COL; j = j + 1) begin
                $write("%0d ", expected_C[i][j]);
            end
            $display("");
        end

        // Self-check: compare every DUT output with the golden value.
        // `!==` is used so that X/Z bits in the DUT output count as failures.
        for (i = 0; i < OUTPUT_ROW; i = i + 1) begin
            for (j = 0; j < OUTPUT_COL; j = j + 1) begin
                dut_value = acc_outputs[((i*OUTPUT_COL + j) + 1)*2*DATA_WIDTH - 1 -: 2*DATA_WIDTH];
                if (dut_value !== expected_C[i][j]) begin
                    error_count = error_count + 1;
                    $display("MISMATCH at C[%0d][%0d]: got %0d, expected %0d",
                             i, j, dut_value, expected_C[i][j]);
                end
            end
        end
        if (error_count == 0)
            $display("\nTEST PASSED: all %0d outputs match", OUTPUT_ROW*OUTPUT_COL);
        else
            $display("\nTEST FAILED: %0d of %0d outputs mismatch", error_count, OUTPUT_ROW*OUTPUT_COL);

        $stop;
    end
endmodule
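// Reference results, C = A x B (4x5).
//
// For the commented-out sequential data set (A = 1..16, B = 1..20, row-major):
// [110, 120, 130, 140, 150]
// [246, 272, 298, 324, 350]
// [382, 424, 466, 508, 550]
// [518, 576, 634, 692, 750]
//
// For the active data set (rows of A alternate 1 2 3 4 / 4 3 2 1,
// rows of B alternate 1 2 3 4 5 / 5 4 3 2 1):
// [34, 32, 30, 28, 26]
// [26, 28, 30, 32, 34]
// [34, 32, 30, 28, 26]
// [26, 28, 30, 32, 34]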