// AYAKA_Transformer/rtl/PE_array_module.v
//
// 262 lines
// 14 KiB
// Verilog

// pe_array: a PE_ROWS x PE_COLS grid of processing_element instances, plus the
// skew logic that feeds operands in from the north/west edges and the de-skew
// logic that collects results from the bottom row / last active column.
//
// The dataflow is chosen at run time by `mode`:
//   2'b00 Output-Stationary, 2'b01 Input-Stationary, 2'b10 Weight-Stationary.
//
// rows_size_PE / cols_size_PE are used as the INDEX of the last active PE row /
// column (e.g. south_array[rows_size_PE][*] is read as the bottom edge), so an
// R x C active region presumably means rows_size_PE = R-1, cols_size_PE = C-1.
module pe_array #(
    parameter DATA_WIDTH     = 8,
    parameter MEM_ROWS       = 20, // 20 -> 5-bit rows_size_PE port
    parameter MEM_COLS       = 80, // 80 -> 7-bit cols_size_PE port
    parameter PE_ROWS        = 4,
    parameter PE_COLS        = 5,
    parameter COMMON_ROW_COL = 4,
    parameter OUTPUT_COL     = 5,
    parameter OUTPUT_ROW     = 4
)(
    input  wire                             clk,
    input  wire                             rst,            // asynchronous, active-high
    input  wire [1:0]                       mode,           // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary
    input  wire [PE_COLS*DATA_WIDTH-1:0]    north_inputs,   // one DATA_WIDTH lane per PE column
    input  wire [PE_ROWS*DATA_WIDTH-1:0]    west_inputs,    // one DATA_WIDTH lane per PE row
    input  wire                             enable,         // drives PE(0,0) only; others chain
    input  wire                             output_enable,  // drives PE(0,0) only; others chain
    input  wire                             initialization, // edge PEs take the raw inputs, bypassing the skew pipes
    input       [$clog2(MEM_ROWS)-1:0]      rows_size_PE,   // last active PE row index (A)
    input       [$clog2(MEM_COLS)-1:0]      cols_size_PE,   // last active PE column index (B)
    output wire [OUTPUT_COL*OUTPUT_ROW-1:0] valid,          // one flag per PE, index i*PE_COLS + j
    output reg  [OUTPUT_ROW*DATA_WIDTH-1:0] acc_outputs,    // OUTPUT_ROW lanes of DATA_WIDTH bits
    output reg                              acc_output_valid
);

    // Inter-PE data links: south_array[i][j] / east_array[i][j] are the
    // south / east data outputs of PE(i,j).
    wire [DATA_WIDTH-1:0]   south_array [0:PE_ROWS-1][0:PE_COLS-1];
    wire [DATA_WIDTH-1:0]   east_array  [0:PE_ROWS-1][0:PE_COLS-1];
    // Debug taps of the bottom row / rightmost column (not used by the logic).
    wire [DATA_WIDTH-1:0]   south_array_end [0:PE_COLS-1];
    wire [DATA_WIDTH-1:0]   east_array_end  [0:PE_ROWS-1];
    // acc_out of every PE; not read elsewhere in this module.
    wire [2*DATA_WIDTH-1:0] output_array [0:PE_ROWS-1][0:PE_COLS-1];

    // Input skew pipes: lane m is shifted through stages 0..m, and stage m
    // (north_pipe[j][j] / west_pipe[i][i]) feeds the corresponding edge PE.
    reg [DATA_WIDTH-1:0] north_pipe [0:PE_COLS-1][0:PE_COLS-1]; // [column][stage]
    reg [DATA_WIDTH-1:0] west_pipe  [0:PE_ROWS-1][0:PE_ROWS-1]; // [row][stage]

    // Enable chains. PE(i,j) drives enable_var[i+1][j] (south) and
    // enable_var[i][j+1] (east). The extra row/column (index PE_ROWS / PE_COLS)
    // only absorbs the outputs of the edge PEs so every port connection below
    // uses an in-range index; those entries are never read.
    // NOTE(review): for i>0 and j>0, enable_var[i][j] is driven by both
    // PE(i-1,j) and PE(i,j-1); this assumes both drive the same value -- confirm
    // in processing_element.
    wire enable_var        [0:PE_ROWS][0:PE_COLS];
    wire output_enable_var [0:PE_ROWS][0:PE_COLS];

    // Output de-skew pipes, loaded from the bottom active row (south) and the
    // last active column (east).
    reg [DATA_WIDTH-1:0] delayed_south [0:OUTPUT_COL-1][0:OUTPUT_COL-1]; // [column][stage]
    reg [DATA_WIDTH-1:0] delayed_east  [0:OUTPUT_ROW-1][0:OUTPUT_ROW-1]; // [row][stage]

    // Output-stationary staging: acc_outputs_delayed is rebuilt (blocking) each
    // readout cycle; acc_outputs_delayed1 is its registered copy.
    reg [OUTPUT_ROW*DATA_WIDTH-1:0] acc_outputs_delayed, acc_outputs_delayed1;

    integer r, d, m, n, x;  // loop indices
    integer delay_count;    // OS: cycles spent collecting; read in the same cycle it is bumped
    integer cycle;          // index of the next output beat

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            // Clear every stage using constant bounds, so the reset does not
            // depend on the run-time size inputs (and stays synthesizable).
            for (m = 0; m < PE_COLS; m = m + 1)
                for (n = 0; n < PE_COLS; n = n + 1)
                    north_pipe[m][n] <= {DATA_WIDTH{1'b0}};
            for (m = 0; m < PE_ROWS; m = m + 1)
                for (n = 0; n < PE_ROWS; n = n + 1)
                    west_pipe[m][n] <= {DATA_WIDTH{1'b0}};
            for (r = 0; r < OUTPUT_COL; r = r + 1)
                for (d = 0; d < OUTPUT_COL; d = d + 1)
                    delayed_south[r][d] <= {DATA_WIDTH{1'b0}};
            for (r = 0; r < OUTPUT_ROW; r = r + 1)
                for (d = 0; d < OUTPUT_ROW; d = d + 1)
                    delayed_east[r][d] <= {DATA_WIDTH{1'b0}};
            // `valid` is a net driven by the PE instances; it is not assigned here.
            acc_outputs          <= {(OUTPUT_ROW*DATA_WIDTH){1'b0}};
            acc_output_valid     <= 1'b0;
            acc_outputs_delayed1 <= {(OUTPUT_ROW*DATA_WIDTH){1'b0}};
            acc_outputs_delayed   = {(OUTPUT_ROW*DATA_WIDTH){1'b0}};
            cycle                 = 0;
            delay_count           = 0;
        end else begin
            // ---------------- Input skew ----------------
            for (m = 0; m <= cols_size_PE; m = m + 1) begin
                north_pipe[m][0] <= north_inputs[(m+1)*DATA_WIDTH-1 -: DATA_WIDTH];
                for (n = 1; n <= m; n = n + 1)
                    north_pipe[m][n] <= north_pipe[m][n-1];
            end
            for (m = 0; m <= rows_size_PE; m = m + 1) begin
                west_pipe[m][0] <= west_inputs[(m+1)*DATA_WIDTH-1 -: DATA_WIDTH];
                for (n = 1; n <= m; n = n + 1)
                    west_pipe[m][n] <= west_pipe[m][n-1];
            end

            // ---------------- Output de-skew ----------------
            // NOTE(review): the stage loops can index past the declared pipe depth
            // (d up to cols_size_PE+1 / rows_size_PE+1); such writes are no-ops.
            if ((valid[0] == 1'b1) && ((mode == 2'b01) || (mode == 2'b10))) begin
                for (r = 0; r <= cols_size_PE; r = r + 1) begin
                    delayed_south[r][0] <= south_array[rows_size_PE][r];
                    for (d = 1; d <= (cols_size_PE+1) - r; d = d + 1)
                        delayed_south[r][d] <= delayed_south[r][d-1];
                end
                // Include the last active row (r == rows_size_PE): input-stationary
                // readout takes that row straight from delay stage 0.
                for (r = 0; r <= rows_size_PE; r = r + 1) begin
                    delayed_east[r][0] <= east_array[r][cols_size_PE];
                    for (d = 1; d <= (rows_size_PE+1) - r; d = d + 1)
                        delayed_east[r][d] <= delayed_east[r][d-1];
                end
            end else if ((valid[(cols_size_PE > rows_size_PE) ? rows_size_PE : cols_size_PE] == 1'b1)
                         && (mode == 2'b00)) begin
                // Blocking on purpose: the output-stationary readout below compares
                // against the incremented value in this same clock cycle.
                delay_count = delay_count + 1;
                for (r = 0; r <= cols_size_PE; r = r + 1) begin
                    delayed_south[r][0] <= south_array[rows_size_PE][r];
                    for (d = 1; d <= cols_size_PE + 1; d = d + 1)
                        delayed_south[r][d] <= delayed_south[r][d-1];
                end
                for (r = 0; r <= rows_size_PE; r = r + 1) begin
                    delayed_east[r][0] <= east_array[r][cols_size_PE];
                    for (d = 1; d <= rows_size_PE + 1; d = d + 1)
                        delayed_east[r][d] <= delayed_east[r][d-1];
                end
            end

            // ---------------- Result readout ----------------
            // NOTE(review): acc_outputs / acc_outputs_delayed hold only OUTPUT_ROW
            // lanes, but the readouts below address up to lane
            // OUTPUT_ROW*OUTPUT_COL-1 (WS/IS) or cols_size_PE (OS). Lanes past
            // the register width are silently dropped -- confirm the intended width.
            if (mode == 2'b10 && valid[(OUTPUT_COL*2)-1] == 1'b1 && (cycle < OUTPUT_ROW)) begin
                // Weight-stationary: fill lanes cycle*OUTPUT_COL + x from the
                // bottom-edge de-skew pipe; triggered by valid[2*OUTPUT_COL-1]
                // rather than waiting for every PE.
                for (x = 0; x < PE_COLS; x = x + 1) begin
                    acc_outputs[((x + cycle*OUTPUT_COL) + 1)*DATA_WIDTH-1 -: DATA_WIDTH] <= delayed_south[x][PE_COLS-x-1];
                    $display("Cycle: %0d, Col: %0d, Index: %0d to %0d, delayed_south(%0d,%0d):= %0d",
                             cycle,
                             x,
                             ((x + cycle * OUTPUT_COL) + 1) * DATA_WIDTH - 1,
                             ((x + cycle * OUTPUT_COL) + 1) * DATA_WIDTH - DATA_WIDTH,
                             x,
                             (PE_COLS - x - 1),
                             delayed_south[x][PE_COLS - x - 1]
                    );
                end
                cycle = cycle + 1;
            end else if (mode == 2'b01 && valid[(OUTPUT_COL*2)] == 1'b1 && (cycle < OUTPUT_COL)) begin
                // Input-stationary: fill lanes x*OUTPUT_COL + cycle from the
                // east-edge de-skew pipe.
                for (x = 0; x < PE_ROWS; x = x + 1) begin
                    acc_outputs[((x*OUTPUT_COL + cycle) + 1)*DATA_WIDTH-1 -: DATA_WIDTH] <= delayed_east[x][PE_ROWS-x-1];
                    /* $display("Cycle: %0d, Col: %0d, Index: %0d to %0d, delayed_east(%0d,%0d):= %0d",
                                cycle,
                                x,
                                ((x*OUTPUT_COL + cycle) + 1) * DATA_WIDTH - 1,
                                ((x*OUTPUT_COL + cycle) + 1) * DATA_WIDTH - DATA_WIDTH,
                                x,
                                (PE_ROWS - x - 1),
                                delayed_east[x][PE_ROWS - x - 1]
                       );
                    */
                end
                cycle = cycle + 1;
            end else if (mode == 2'b00
                         && (delay_count >= ((cols_size_PE > rows_size_PE) ? (cols_size_PE + 2) : (rows_size_PE + 2)))
                         && (cycle <= ((cols_size_PE > rows_size_PE) ? (rows_size_PE + 1) : (cols_size_PE + 1)))) begin
                // Output-stationary: gather one de-skewed beat, taking it from the
                // longer edge (south if there are more columns than rows, else east).
                if (cols_size_PE > rows_size_PE) begin
                    for (x = 0; x <= cols_size_PE; x = x + 1) begin
                        acc_outputs_delayed[(x+1)*DATA_WIDTH-1 -: DATA_WIDTH] = delayed_south[x][cols_size_PE-x];
                        /*
                        $display("o/p_Cycle: %0d, Col: %0d, Index: %0d to %0d, delayed_south(%0d,%0d):= %0h; acc_output_valid =%0h; delay_count = %0h",
                                 cycle, x, (x+1)*DATA_WIDTH-1, (x+1)*DATA_WIDTH-DATA_WIDTH,
                                 x, (cols_size_PE - x), delayed_south[x][cols_size_PE - x],
                                 acc_output_valid, delay_count);
                        */
                    end
                end else begin
                    for (x = 0; x <= rows_size_PE; x = x + 1) begin
                        acc_outputs_delayed[(x+1)*DATA_WIDTH-1 -: DATA_WIDTH] = delayed_east[x][rows_size_PE-x];
                        /*
                        $display("o/p_Cycle: %0d, Row: %0d, Index: %0d to %0d, delayed_east(%0d,%0d):= %0h ; acc_output_valid =%0h; delay_count = %0h",
                                 cycle, x, (x+1)*DATA_WIDTH-1, (x+1)*DATA_WIDTH-DATA_WIDTH,
                                 x, (rows_size_PE - x), delayed_east[x][rows_size_PE - x],
                                 acc_output_valid, delay_count);
                        */
                    end
                end
                // Two-stage output pipeline; both branches share it. Nonblocking
                // assignments keep the visible outputs race-free for other blocks.
                // NOTE(review): acc_output_valid is never cleared except by rst.
                acc_outputs_delayed1 <= acc_outputs_delayed;
                acc_outputs          <= acc_outputs_delayed1;
                acc_output_valid     <= 1'b1;
                cycle = cycle + 1;
            end
        end
    end

    genvar i, j;
    generate
        for (i = 0; i < PE_ROWS; i = i + 1) begin : row_gen
            for (j = 0; j < PE_COLS; j = j + 1) begin : col_gen
                processing_element #(
                    .DATA_WIDTH(DATA_WIDTH),
                    .MEM_ROWS(MEM_ROWS),
                    .MEM_COLS(MEM_COLS),
                    .COMMON_ROW_COL(COMMON_ROW_COL),
                    .OUTPUT_COL(OUTPUT_COL),
                    .OUTPUT_ROW(OUTPUT_ROW),
                    .PE_ROWS(PE_ROWS),
                    .PE_COLS(PE_COLS)
                ) pe_inst (
                    .clk(clk),
                    .rst(rst),
                    .initialization(initialization),
                    // Only PE(0,0) sees the external enables; the rest use the chain.
                    .enable((j==0) && (i==0) ? enable : enable_var[i][j]),
                    .output_enable((j==0) && (i==0) ? output_enable : output_enable_var[i][j]),
                    .clear_acc(1'b0), // No accumulator clearing
                    // Edge PEs take skewed external data (raw data during
                    // initialization); inner PEs take their neighbour's output.
                    .data_in_north((i == 0) ? (initialization ? north_inputs[(j+1)*DATA_WIDTH-1 -: DATA_WIDTH] : north_pipe[j][j])
                                            : south_array[i-1][j]),
                    .data_in_west((j == 0) ? (initialization ? west_inputs[(i+1)*DATA_WIDTH-1 -: DATA_WIDTH] : west_pipe[i][i])
                                           : east_array[i][j-1]),
                    .pe_row_postion(i[$clog2(OUTPUT_ROW+1)-1:0]),
                    .pe_col_postion(j[$clog2(OUTPUT_COL+1)-1:0]),
                    .output_enable_south(output_enable_var[i+1][j]),
                    .output_enable_east(output_enable_var[i][j+1]),
                    .enable_south(enable_var[i+1][j]),
                    .enable_east(enable_var[i][j+1]),
                    .data_out_south(south_array[i][j]),
                    .data_out_east(east_array[i][j]),
                    .cols_size_PE(cols_size_PE),
                    .rows_size_PE(rows_size_PE),
                    .mode(mode),
                    .acc_out(output_array[i][j]),
                    .valid(valid[(i*PE_COLS) + j])
                );
            end
        end
    endgenerate

    // Redundant taps kept only for waveform debugging.
    genvar k;
    generate
        for (k = 0; k < PE_COLS; k = k + 1) begin
            assign south_array_end[k] = south_array[PE_ROWS-1][k]; // bottom row
        end
        for (k = 0; k < PE_ROWS; k = k + 1) begin
            assign east_array_end[k] = east_array[k][PE_COLS-1];   // rightmost column
        end
    endgenerate

endmodule