// AYAKA_Transformer/rtl/single_PE_module.v
//
// Listing metadata (not part of the design): 195 lines, 8.1 KiB, Verilog.

// processing_element
// One processing element (PE) of a systolic array. The 2-bit `mode` input
// selects the dataflow:
//   2'b00 Output-Stationary : a running sum is kept in `acc`; the north (A)
//                             and west (B) operands are forwarded south/east,
//                             and after COMMON_ROW_COL steps `acc` is drained
//                             through the south or east data port.
//   2'b01 Input-Stationary  : a west operand is latched during `initialization`;
//                             partial sums travel west -> east.
//   2'b10 Weight-Stationary : a north operand is latched during `initialization`;
//                             partial sums travel north -> south.
//   2'b11                   : not a defined mode; every register holds.
// `rst` is an asynchronous, active-high reset of all registers.
// `clear_acc` synchronously clears acc, the counters, valid, the latched
// operands and the enable/output-enable flags. It does not clear
// data_out_south, data_out_east or acc_out.
//
// NOTE(review): `acc` is only DATA_WIDTH bits, so products and sums wrap to
// DATA_WIDTH bits. acc_out is 2*DATA_WIDTH bits but is only ever loaded with 0
// in this module. Confirm this is intended.
module processing_element #(
  parameter DATA_WIDTH     = 16,
  parameter MEM_ROWS       = 20, // 20 -> 5-bit rows_size_PE (original note: 16)
  parameter MEM_COLS       = 80, // 80 -> 7-bit cols_size_PE (original note: 32SS)
  parameter COMMON_ROW_COL = 4,  // number of MAC steps in output-stationary mode
  parameter OUTPUT_COL     = 3,  // sizes pe_col_postion
  parameter OUTPUT_ROW     = 5,  // sizes pe_row_postion
  parameter PE_ROWS        = 3,  // not used inside this module
  parameter PE_COLS        = 5   // not used inside this module
)(
  input clk,
  input rst,
  input enable,
  input [1:0] mode, // 00: Output-Stationary, 01: Input-Stationary, 10: Weight-Stationary
  input initialization, // modes 01/10: latch the stationary operand this cycle
  input output_enable,  // mode 00: drive acc onto the drain port
  input [$clog2(OUTPUT_ROW+1)-1:0]pe_row_postion, // row index; shifts the output-enable timing
  input [$clog2(OUTPUT_COL+1)-1:0]pe_col_postion, // column index; shifts the output-enable timing
  // Inputs from the north and west
  input signed [DATA_WIDTH-1:0] data_in_north, // A matrix element
  input signed [DATA_WIDTH-1:0] data_in_west,  // B matrix element
  input [$clog2(MEM_ROWS)-1:0] rows_size_PE,   // A
  input [$clog2(MEM_COLS)-1:0] cols_size_PE,   // B
  // Outputs to the south and east
  output reg signed [DATA_WIDTH-1:0] data_out_south, // Forwarded A / partial sum / result
  output reg signed [DATA_WIDTH-1:0] data_out_east,  // Forwarded B / partial sum / result
  output reg enable_south,
  output reg enable_east,
  output reg output_enable_south,
  output reg output_enable_east,
  // Accumulated output
  output reg [2*DATA_WIDTH-1:0] acc_out,
  output reg valid,
  // Clear accumulator
  input clear_acc
);
  // Running dot product for output-stationary mode (DATA_WIDTH bits; wraps).
  reg signed [DATA_WIDTH-1:0] acc;
  // Per-mode step counters.
  integer count_acc; // output-stationary
  integer count_col; // input-stationary
  integer count_row; // weight-stationary
  // Stationary operands captured while `initialization` is high.
  reg signed [DATA_WIDTH-1:0] data_in_west_reg, data_in_north_reg;

  always @(posedge clk or posedge rst) begin
    if (rst) begin
      acc                 <= 0;
      data_out_south      <= 0;
      data_out_east       <= 0;
      acc_out             <= 0;
      valid               <= 0;
      count_acc           <= 0;
      count_col           <= 0;
      count_row           <= 0;
      data_in_west_reg    <= 0;
      data_in_north_reg   <= 0;
      enable_south        <= 0;
      enable_east         <= 0;
      output_enable_south <= 0;
      output_enable_east  <= 0;
    end else if (clear_acc) begin
      acc                 <= 0;
      valid               <= 0;
      count_acc           <= 0;
      count_col           <= 0;
      count_row           <= 0;
      data_in_west_reg    <= 0;
      data_in_north_reg   <= 0;
      enable_south        <= 0;
      enable_east         <= 0;
      output_enable_south <= 0;
      output_enable_east  <= 0;
    end else begin
      case (mode)
        2'b00: begin // Output-Stationary
          // Accumulate phase: up to COMMON_ROW_COL multiply-accumulate steps.
          if ((enable == 1) && (count_acc < COMMON_ROW_COL) && (valid == 0)) begin
            acc            <= acc + data_in_north * data_in_west;
            count_acc      <= count_acc + 1;
            // Forward the operands to the neighbouring PEs.
            data_out_south <= data_in_north;
            data_out_east  <= data_in_west;
            enable_south   <= enable;
            enable_east    <= enable;
          end else begin
            enable_south   <= enable;
            enable_east    <= enable;
          end
          // Drain phase. It already starts at count_acc == COMMON_ROW_COL-1.
          // There, the data_out_* assignments below override the forwarding
          // above, because the later non-blocking assignment wins.
          if ((enable == 1) && (count_acc >= COMMON_ROW_COL-1)) begin
            if (cols_size_PE > rows_size_PE) begin
              // Results leave through the south port.
              data_out_south <= (output_enable == 1) ? acc : data_in_north;
              if (count_acc == (COMMON_ROW_COL+pe_row_postion+1)) begin
                output_enable_south <= output_enable;
              end
              if ((pe_row_postion == 0) && (count_acc == (COMMON_ROW_COL+pe_row_postion+1))) begin
                output_enable_east <= output_enable;
              end else if (count_acc == (COMMON_ROW_COL+pe_row_postion)) begin
                output_enable_east <= output_enable;
              end
            end else begin
              // Results leave through the east port.
              data_out_east <= (output_enable == 1) ? acc : data_in_west;
              if (count_acc == (COMMON_ROW_COL+pe_col_postion+1)) begin
                output_enable_east <= output_enable;
              end
              if ((pe_col_postion == 0) && (count_acc == (COMMON_ROW_COL+pe_col_postion+1))) begin
                output_enable_south <= output_enable;
              end else if (count_acc == (COMMON_ROW_COL+pe_col_postion)) begin
                output_enable_south <= output_enable;
              end
            end
            if (count_acc == COMMON_ROW_COL) begin
              valid <= 1;
            end
            // The counter keeps running while draining. When both phases fire
            // in the same cycle, this is the same +1 as in the accumulate phase.
            count_acc <= count_acc + 1;
          end
        end

        2'b01: begin // Input-Stationary
          if (initialization == 1) begin
            // Latch the stationary operand (non-blocking, like every other
            // assignment to this register) and pass it on east.
            data_in_west_reg <= data_in_west;
            data_out_south   <= 0;
            data_out_east    <= data_in_west;
          end else if ((enable == 1) && (count_col <= (cols_size_PE+1)) && (initialization == 0)) begin
            // The incoming west value plus this PE's product is sent east.
            data_out_east  <= data_in_west + data_in_north * data_in_west_reg;
            count_col      <= count_col + 1;
            // Forward the north operand.
            data_out_south <= data_in_north;
            enable_south   <= enable;
            enable_east    <= enable;
          end
          // TODO(review): the original author notes "logic needs to be updated".
          // Unlike weight-stationary mode, this compares against COMMON_ROW_COL
          // rather than the runtime size input.
          if (count_col == COMMON_ROW_COL) begin
            acc_out   <= 0;
            valid     <= 1;
            count_col <= count_col + 1;
          end
        end

        2'b10: begin // Weight-Stationary
          if (initialization == 1) begin
            // Latch the stationary operand (non-blocking, like every other
            // assignment to this register) and pass it on south.
            data_in_north_reg <= data_in_north;
            data_out_south    <= data_in_north;
            data_out_east     <= 0;
          end else if ((enable == 1) && (count_row <= (rows_size_PE+1)) && (valid == 0) && (initialization == 0)) begin
            // The incoming north value plus this PE's product is sent south.
            data_out_south <= data_in_north + data_in_west * data_in_north_reg;
            count_row      <= count_row + 1;
            // Forward the west operand.
            data_out_east  <= data_in_west;
            enable_south   <= enable;
            enable_east    <= enable;
          end
          // NOTE(review): the original comment suggests COMMON_ROW_COL might be
          // the intended bound here.
          if (count_row == (rows_size_PE+1)) begin
            acc_out   <= 0;
            valid     <= 1;
            count_row <= count_row + 1;
          end
        end

        // 2'b11 is not a defined mode: every register holds its value.
        default: ;
      endcase
    end
  end
endmodule