mirror of
https://github.com/davrot/pytorch-sbs.git
synced 2025-07-04 23:00:03 +02:00
Add files via upload
This commit is contained in:
parent
3622b91591
commit
ec40834df1
31 changed files with 2707 additions and 0 deletions
40
network/multiplication_approximation_gpu_cpp/Makefile
Normal file
40
network/multiplication_approximation_gpu_cpp/Makefile
Normal file
|
@ -0,0 +1,40 @@
|
|||
include ../.env
|
||||
export
|
||||
|
||||
name = MultiplicationApproximation
|
||||
type = GPU
|
||||
|
||||
PYPOSTFIX := $(shell $(PYBIN)python3-config --extension-suffix)
|
||||
PYBIND11INCLUDE := $(shell $(PYBIN)python3 -m pybind11 --includes)
|
||||
PARAMETERS_O = $(PARAMETERS_O_GPU) $(PYBIND11INCLUDE)
|
||||
PARAMETERS_Linker = $(PARAMETERS_Linker_GPU)
|
||||
|
||||
so_file = Py$(name)$(type)$(PYPOSTFIX)
|
||||
pyi_file = Py$(name)$(type).pyi
|
||||
all: ../$(so_file)
|
||||
|
||||
$(O_DIRS)$(name)$(type).o: $(name)$(type).h \
|
||||
$(name)$(type).cu \
|
||||
gpu_error_term.cu \
|
||||
gpu_approximation_multiplication_function.cu
|
||||
mkdir -p $(O_DIRS)
|
||||
$(NVCC) $(PARAMETERS_O) -c $(name)$(type).cu -o $(O_DIRS)$(name)$(type).o
|
||||
|
||||
$(O_DIRS)Py$(name)$(type).o: $(name)$(type).h Py$(name)$(type).cpp
|
||||
mkdir -p $(O_DIRS)
|
||||
$(NVCC) $(PARAMETERS_O) -c Py$(name)$(type).cpp -o $(O_DIRS)Py$(name)$(type).o
|
||||
|
||||
../$(so_file): \
|
||||
$(O_DIRS)$(name)$(type).o \
|
||||
$(O_DIRS)Py$(name)$(type).o
|
||||
$(NVCC) $(PARAMETERS_Linker) -o ../$(so_file) \
|
||||
$(O_DIRS)$(name)$(type).o \
|
||||
$(O_DIRS)Py$(name)$(type).o
|
||||
|
||||
|
||||
#######################
|
||||
clean:
|
||||
rm -rf $(O_DIRS)
|
||||
rm -f ../$(so_file)
|
||||
rm -f ../$(pyi_file)
|
||||
|
|
@ -0,0 +1,182 @@
|
|||
#include "MultiplicationApproximationGPU.h"
|
||||
|
||||
#include <omp.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
|
||||
#include "gpu_approximation_multiplication_function.cu"
|
||||
|
||||
MultiplicationApproximationGPU::MultiplicationApproximationGPU()
|
||||
{
|
||||
|
||||
};
|
||||
|
||||
MultiplicationApproximationGPU::~MultiplicationApproximationGPU()
|
||||
{
|
||||
|
||||
};
|
||||
|
||||
void MultiplicationApproximationGPU::entrypoint(
|
||||
int64_t np_input_pointer_addr,
|
||||
int64_t np_weight_pointer_addr,
|
||||
int64_t np_output_pointer_addr,
|
||||
int64_t pattern_dim,
|
||||
int64_t feature_dim,
|
||||
int64_t x_dim,
|
||||
int64_t y_dim,
|
||||
int64_t input_channel_dim,
|
||||
int64_t number_of_processes,
|
||||
bool approximation_enable,
|
||||
int64_t number_of_trunc_bits,
|
||||
int64_t number_of_frac)
|
||||
{
|
||||
|
||||
// int64_t number_of_pattern = pattern_dim;
|
||||
|
||||
float* np_input_pointer = (float*)np_input_pointer_addr;
|
||||
float* np_weight_pointer = (float*)np_weight_pointer_addr;
|
||||
float* np_output_pointer = (float*)np_output_pointer_addr;
|
||||
|
||||
assert((np_input_pointer != nullptr));
|
||||
assert((np_output_pointer != nullptr));
|
||||
assert((np_weight_pointer != nullptr));
|
||||
|
||||
assert((pattern_dim > 0));
|
||||
assert((feature_dim > 0));
|
||||
assert((x_dim > 0));
|
||||
assert((y_dim > 0));
|
||||
assert((input_channel_dim > 0));
|
||||
|
||||
assert ((number_of_processes <= 0));
|
||||
|
||||
calculate_gpu(np_input_pointer, np_weight_pointer,
|
||||
np_output_pointer, pattern_dim, feature_dim, x_dim, y_dim,
|
||||
input_channel_dim, approximation_enable,
|
||||
number_of_trunc_bits, number_of_frac);
|
||||
|
||||
return;
|
||||
};
|
||||
|
||||
__global__ void kernel_approx_multiplication(
|
||||
float* __restrict__ input_pointer,
|
||||
float* __restrict__ weight_pointer,
|
||||
float* __restrict__ output_pointer,
|
||||
uint64_t pattern_dim,
|
||||
uint64_t feature_dim,
|
||||
uint64_t x_dim,
|
||||
uint64_t y_dim,
|
||||
uint64_t input_channel_dim,
|
||||
size_t max_threadable_tasks,
|
||||
uint64_t input_index_scale,
|
||||
uint64_t number_of_frac_bits,
|
||||
bool approximation_enable,
|
||||
uint64_t number_of_trunc_bits,
|
||||
uint32_t ap_mask)
|
||||
{
|
||||
int idx = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
|
||||
if (idx < max_threadable_tasks)
|
||||
{
|
||||
int pattern_id = idx / feature_dim;
|
||||
int feature_id = idx - (pattern_id * feature_dim);
|
||||
int x_id = blockIdx.y;
|
||||
int y_id = blockIdx.z;
|
||||
|
||||
float* weight_pointer_sub = weight_pointer + feature_id * input_channel_dim;
|
||||
float* input_pointer_sub = input_pointer + pattern_id * input_channel_dim * x_dim * y_dim + x_id * y_dim + y_id;
|
||||
float* output_pointer_sub = output_pointer +
|
||||
pattern_id * feature_dim * x_dim * y_dim +
|
||||
feature_id * x_dim * y_dim + x_id * y_dim + y_id;
|
||||
*output_pointer_sub = 0.0;
|
||||
|
||||
for (size_t counter = 0; counter < input_channel_dim; counter++)
|
||||
{
|
||||
*output_pointer_sub += gpu_approximation_multiplication_function(
|
||||
weight_pointer_sub[counter],
|
||||
input_pointer_sub[counter * input_index_scale],
|
||||
number_of_frac_bits, approximation_enable,
|
||||
number_of_trunc_bits, ap_mask);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
void MultiplicationApproximationGPU::calculate_gpu(
|
||||
float* np_input_pointer,
|
||||
float* np_weight_pointer,
|
||||
float* np_output_pointer,
|
||||
size_t pattern_dim,
|
||||
size_t feature_dim,
|
||||
size_t x_dim,
|
||||
size_t y_dim,
|
||||
size_t input_channel_dim,
|
||||
bool approximation_enable,
|
||||
size_t number_of_trunc_bits,
|
||||
size_t number_of_frac_bits)
|
||||
{
|
||||
|
||||
uint32_t ap_mask = static_cast<uint64_t>(pow(2, number_of_trunc_bits)) - 1;
|
||||
// std::cout << approximation_enable << std::endl;
|
||||
// std::cout << number_of_trunc_bits << std::endl;
|
||||
// std::cout << number_of_frac_bits << std::endl;
|
||||
|
||||
cudaError_t status;
|
||||
assert((x_dim < 65535));
|
||||
assert((y_dim < 65535));
|
||||
|
||||
// //////////////////////////////////////
|
||||
// Calculate the distribution on the GPU
|
||||
// //////////////////////////////////////
|
||||
|
||||
int min_grid_size;
|
||||
int block_size;
|
||||
int grid_size;
|
||||
|
||||
size_t dynamic_s_mem_size = 0;
|
||||
size_t max_threadable_tasks = pattern_dim * feature_dim * x_dim * y_dim;
|
||||
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html?highlight=blocksize#occupancy-calculator
|
||||
status = cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size,
|
||||
(void*)kernel_approx_multiplication,
|
||||
dynamic_s_mem_size, max_threadable_tasks);
|
||||
assert((status == cudaSuccess));
|
||||
|
||||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications
|
||||
// Maximum dimensionality of grid of thread blocks: 3
|
||||
// Maximum x -dimension of a grid of thread blocks: (2^31)-1
|
||||
// Maximum y- or z-dimension of a grid of thread blocks: 65535
|
||||
|
||||
// Round up according to array size
|
||||
grid_size = ((pattern_dim * feature_dim) + block_size - 1) / block_size;
|
||||
|
||||
// std::cout << min_grid_size << std::endl;
|
||||
// std::cout << grid_size << std::endl;
|
||||
// std::cout << block_size << std::endl;
|
||||
// std::cout << max_threadable_tasks << std::endl;
|
||||
|
||||
dim3 grid(grid_size, x_dim, y_dim);
|
||||
|
||||
kernel_approx_multiplication<<<grid, block_size>>>(np_input_pointer,
|
||||
np_weight_pointer,
|
||||
np_output_pointer,
|
||||
pattern_dim,
|
||||
feature_dim,
|
||||
x_dim,
|
||||
y_dim,
|
||||
input_channel_dim,
|
||||
(pattern_dim * feature_dim),
|
||||
(x_dim * y_dim),
|
||||
number_of_frac_bits,
|
||||
approximation_enable,
|
||||
number_of_trunc_bits,
|
||||
ap_mask);
|
||||
|
||||
status = cudaDeviceSynchronize();
|
||||
assert((status == cudaSuccess));
|
||||
return;
|
||||
};
|
|
@ -0,0 +1,44 @@
|
|||
#ifndef MULTIPLICATIONAPPROXIMATIONGPU
|
||||
#define MULTIPLICATIONAPPROXIMATIONGPU
|
||||
|
||||
#include <unistd.h>
|
||||
#include <cctype>
|
||||
#include <iostream>
|
||||
|
||||
class MultiplicationApproximationGPU
|
||||
{
|
||||
public:
|
||||
MultiplicationApproximationGPU();
|
||||
~MultiplicationApproximationGPU();
|
||||
|
||||
void entrypoint(
|
||||
int64_t np_input_pointer_addr,
|
||||
int64_t np_weight_pointer_addr,
|
||||
int64_t np_output_pointer_addr,
|
||||
int64_t pattern_dim,
|
||||
int64_t feature_dim,
|
||||
int64_t x_dim,
|
||||
int64_t y_dim,
|
||||
int64_t input_channel_dim,
|
||||
int64_t number_of_processes,
|
||||
bool approximation_enable,
|
||||
int64_t number_of_trunc_bits,
|
||||
int64_t number_of_frac);
|
||||
|
||||
private:
|
||||
void calculate_gpu(
|
||||
float* input_pointer,
|
||||
float* weight_pointer,
|
||||
float* output_pointer,
|
||||
size_t pattern_dim,
|
||||
size_t feature_dim,
|
||||
size_t x_dim,
|
||||
size_t y_dim,
|
||||
size_t input_channel_dim,
|
||||
bool approximation_enable,
|
||||
size_t number_of_trunc_bits,
|
||||
size_t number_of_frac);
|
||||
|
||||
};
|
||||
|
||||
#endif /* MULTIPLICATIONAPPROXIMATIONGPU */
|
|
@ -0,0 +1,14 @@
|
|||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "MultiplicationApproximationGPU.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
PYBIND11_MODULE(PyMultiplicationApproximationGPU, m) {
|
||||
m.doc() = "MultiplicationApproximationGPU Module";
|
||||
py::class_<MultiplicationApproximationGPU>(m, "MultiplicationApproximationGPU")
|
||||
.def(py::init<>())
|
||||
.def("update_entrypoint",
|
||||
&MultiplicationApproximationGPU::entrypoint);
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
|
||||
#include "gpu_error_term.cu"
|
||||
|
||||
__device__ float gpu_approximation_multiplication_function(
|
||||
float weight,
|
||||
float input,
|
||||
size_t number_of_frac_bits,
|
||||
bool approximation_enable,
|
||||
size_t number_of_trunc_bits,
|
||||
uint32_t ap_mask)
|
||||
{
|
||||
|
||||
float weight_copy = weight;
|
||||
float input_copy = input;
|
||||
|
||||
uint32_t *weight_pointer_mod = (uint32_t *)&weight_copy;
|
||||
uint32_t *input_pointer_mod = (uint32_t *)&input_copy;
|
||||
|
||||
// Calculate the new sign
|
||||
uint32_t sign_temp = (*weight_pointer_mod & 0x80000000) ^
|
||||
(*input_pointer_mod & 0x80000000);
|
||||
|
||||
// Extract the exponent
|
||||
uint32_t ap_input_exponent = (*input_pointer_mod << 1) >> 24;
|
||||
uint32_t ap_weight_exponent = (*weight_pointer_mod << 1) >> 24;
|
||||
|
||||
// Cast and "normalize"
|
||||
uint64_t shift_value = 32 - number_of_frac_bits;
|
||||
|
||||
uint32_t ap_input_mantissa =
|
||||
((*input_pointer_mod << 8) | 0x80000000) >> shift_value;
|
||||
|
||||
uint32_t ap_weight_mantissa =
|
||||
((*weight_pointer_mod << 8) | 0x80000000) >> shift_value;
|
||||
|
||||
// Make the zero -g-r-e-a-t- correct again
|
||||
if (input == 0)
|
||||
{
|
||||
ap_input_mantissa = 0;
|
||||
}
|
||||
|
||||
if (weight == 0)
|
||||
{
|
||||
ap_weight_mantissa = 0;
|
||||
}
|
||||
|
||||
// res = x*y
|
||||
uint64_t ap_result = static_cast<uint64_t>(ap_input_mantissa) * static_cast<uint64_t>(ap_weight_mantissa);
|
||||
|
||||
uint32_t temp;
|
||||
// --------------------------------------------
|
||||
// Approx
|
||||
// --------------------------------------------
|
||||
|
||||
if (approximation_enable == true)
|
||||
{
|
||||
// Go through the vector values
|
||||
temp = gpu_error_term(ap_weight_mantissa, ap_input_mantissa, ap_mask,
|
||||
number_of_trunc_bits);
|
||||
if (temp > ap_result)
|
||||
{
|
||||
ap_result = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
ap_result -= temp;
|
||||
}
|
||||
}
|
||||
|
||||
// Cast from int to float
|
||||
float output = static_cast<float>(ap_result);
|
||||
if (ap_result == 0)
|
||||
{
|
||||
output = 0.0;
|
||||
}
|
||||
else
|
||||
{
|
||||
uint32_t *output_pointer_mod = (uint32_t *)&output;
|
||||
|
||||
uint32_t ap_output_exponent = (*output_pointer_mod << 1) >> 24;
|
||||
ap_output_exponent -= 2 * number_of_frac_bits;
|
||||
temp = ap_input_exponent + ap_weight_exponent + ap_output_exponent;
|
||||
if (temp > 252)
|
||||
{
|
||||
ap_output_exponent = temp - 252;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Here I try to catch the case that the new exponent is too small
|
||||
ap_output_exponent = 0;
|
||||
}
|
||||
|
||||
// Remove the old exponent
|
||||
*output_pointer_mod = (*output_pointer_mod << 9) >> 9;
|
||||
|
||||
// Install the new exponent
|
||||
*output_pointer_mod += ap_output_exponent << 23;
|
||||
|
||||
// Add the sign back
|
||||
*output_pointer_mod += sign_temp;
|
||||
}
|
||||
return output;
|
||||
};
|
|
@ -0,0 +1,29 @@
|
|||
|
||||
__device__ uint32_t gpu_error_term(
|
||||
uint32_t ap_weight_mantissa,
|
||||
uint32_t ap_input_mantissa,
|
||||
uint32_t ap_mask,
|
||||
uint32_t number_of_trunc_bits)
|
||||
{
|
||||
uint32_t error_value = 0;
|
||||
|
||||
uint32_t temp_shift_a = ap_weight_mantissa;
|
||||
uint32_t temp_shift_b = ap_input_mantissa & ap_mask;
|
||||
|
||||
uint32_t counter_trunc;
|
||||
uint32_t temp;
|
||||
|
||||
// Go through the bits
|
||||
for (counter_trunc = 0; counter_trunc < number_of_trunc_bits; counter_trunc++)
|
||||
{
|
||||
temp = temp_shift_a & 1;
|
||||
if (temp == 1)
|
||||
{
|
||||
error_value += temp_shift_b & ap_mask;
|
||||
}
|
||||
temp_shift_a >>= 1;
|
||||
temp_shift_b <<= 1;
|
||||
}
|
||||
|
||||
return error_value;
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue