pytorch-sbs/network/CPP_Cuda/approximation_multiplication_function.cpp
2023-01-05 13:23:58 +01:00

138 lines
4.1 KiB
C++

#include <unistd.h>
#include <bitset>
#include <cassert>
#include <cctype>
#include "error_term.cpp"
// Best way to plot the bits
// std::cout << std::bitset<32>(ap_y_ptr[1]) << "\n";
// The result needs to be written back into h_pointer (which contains h)
// Don't write to w_pointer.
void approximation_multiplication_function(
float *h_pointer, float *w_pointer, int64_t pattern_length,
uint64_t number_of_trunc_bits, uint64_t number_of_frac_bits,
uint32_t *ap_x_ptr, uint32_t *ap_y_ptr, uint32_t *ap_x_exponent_ptr,
uint32_t *ap_y_exponent_ptr, uint32_t *ap_h_exponent_ptr, uint32_t ap_mask,
uint64_t *ap_res_ptr, uint32_t *sign_temp_ptr, bool approximation_enable) {
uint64_t counter;
uint32_t *w_pointer_mod = (uint32_t *)w_pointer;
uint32_t *h_pointer_mod = (uint32_t *)h_pointer;
// Calculate the new sign
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
sign_temp_ptr[counter] = (w_pointer_mod[counter] & 0x80000000) ^
(h_pointer_mod[counter] & 0x80000000);
}
// Extract the exponent
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_x_exponent_ptr[counter] = (h_pointer_mod[counter] << 1) >> 24;
}
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_y_exponent_ptr[counter] = (w_pointer_mod[counter] << 1) >> 24;
}
// Cast and "normalize"
uint64_t shift_value = 32 - number_of_frac_bits;
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_x_ptr[counter] =
((h_pointer_mod[counter] << 8) | 0x80000000) >> shift_value;
}
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_y_ptr[counter] =
((w_pointer_mod[counter] << 8) | 0x80000000) >> shift_value;
}
// Make the zero -g-r-e-a-t- correct again
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
if (h_pointer[counter] == 0) {
ap_x_ptr[counter] = 0;
}
}
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
if (w_pointer[counter] == 0) {
ap_y_ptr[counter] = 0;
}
}
// res = x*y
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_res_ptr[counter] = static_cast<uint64_t>(ap_x_ptr[counter]) * static_cast<uint64_t>(ap_y_ptr[counter]);
}
uint32_t temp;
if (approximation_enable == true){
// Go through the vector values
for (counter = 0; counter < pattern_length; counter++) {
temp = error_term(ap_y_ptr[counter], ap_x_ptr[counter], ap_mask,
number_of_trunc_bits);
if (temp > ap_res_ptr[counter]) {
ap_res_ptr[counter] = 0;
} else {
ap_res_ptr[counter] -= temp;
}
}
}
// Cast from int to float
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
h_pointer[counter] = static_cast<float>(ap_res_ptr[counter]);
}
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_h_exponent_ptr[counter] = (h_pointer_mod[counter] << 1) >> 24;
}
// devide by the 2^number_of_frac_bits
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
ap_h_exponent_ptr[counter] -= 2 * number_of_frac_bits;
}
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
temp = ap_x_exponent_ptr[counter] + ap_y_exponent_ptr[counter] +
ap_h_exponent_ptr[counter];
if (temp > 252) {
ap_h_exponent_ptr[counter] = temp - 252;
} else {
// Here I try to catch the case that the new exponent is too small
ap_h_exponent_ptr[counter] = 0;
}
}
// Remove the old exponent
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
h_pointer_mod[counter] = (h_pointer_mod[counter] << 9) >> 9;
}
// Install the new exponent
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
h_pointer_mod[counter] += ap_h_exponent_ptr[counter] << 23;
}
// Add the sign back
#pragma omp simd
for (counter = 0; counter < pattern_length; counter++) {
h_pointer_mod[counter] += sign_temp_ptr[counter];
}
return;
}