Delete network/unused_code directory
parent 3b70ae789a
commit b5f95bd0db
1 changed file with 0 additions and 186 deletions
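The deleted file implemented LinearApproximation, a drop-in variant of torch.nn.Linear that routes its multiply-accumulate through the project's C++ MultiApp extension, apparently so the multiplications can run on an approximate fixed-point path (judging by the number_of_trunc_bits and number_of_frac parameters). A hypothetical instantiation, inferred from the deleted constructor signature below; the values 8, 16, and 4 are placeholders, not settings taken from the repository:

    layer = LinearApproximation(
        in_features=128,
        out_features=64,
        bias=True,
        approximation_enable=True,
        number_of_trunc_bits=8,
        number_of_frac=16,
        number_of_processes=4,
        device=torch.device("cpu"),
        dtype=torch.float32,
    )
    output = layer(torch.randn(32, 128, dtype=torch.float32))  # -> shape (32, 64)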
@@ -1,186 +0,0 @@
-import torch
-import math
-
-from network.CPP.PyMultiApp import MultiApp
-
-
-class LinearApproximation(torch.nn.Module):
-
-    in_features: int | None = None
-    out_features: int | None = None
-    use_bias: bool = False
-
-    approximation_enable: bool = False
-    number_of_trunc_bits: int = -1
-    number_of_frac: int = -1
-
-    number_of_processes: int = 1
-
-    weights: torch.nn.parameter.Parameter
-    bias: torch.nn.parameter.Parameter | None
-
-    device: torch.device
-    dtype: torch.dtype
-
-    def __init__(
-        self,
-        in_features: int,
-        out_features: int,
-        bias: bool = True,
-        approximation_enable: bool = False,
-        number_of_trunc_bits: int = -1,
-        number_of_frac: int = -1,
-        number_of_processes: int = 1,
-        device: torch.device | None = None,
-        dtype: torch.dtype | None = None,
-    ) -> None:
-        super().__init__()
-
-        assert device is not None
-        self.device = device
-
-        assert dtype is not None
-        self.dtype = dtype
-
-        self.in_features = in_features
-        self.out_features = out_features
-        self.use_bias = bias
-
-        self.approximation_enable = approximation_enable
-        self.number_of_trunc_bits = number_of_trunc_bits
-        self.number_of_frac = number_of_frac
-
-        self.number_of_processes = number_of_processes
-
-        if self.use_bias is True:
-            self.bias: torch.nn.parameter.Parameter | None = (
-                torch.nn.parameter.Parameter(
-                    torch.empty(
-                        (out_features,),
-                        dtype=self.dtype,
-                        device=self.device,
-                    )
-                )
-            )
-        else:
-            self.bias = None
-
-        self.weights: torch.nn.parameter.Parameter = torch.nn.parameter.Parameter(
-            torch.empty(
-                (out_features, in_features),
-                dtype=self.dtype,
-                device=self.device,
-            )
-        )
-
-        self.functional_multi = FunctionalMultiLinear.apply
-
-        self.reset_parameters()
-
-    def reset_parameters(self) -> None:
-        # Stolen from the original torch conv2d code
-        torch.nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
-        if self.bias is not None:
-            fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weights)
-            if fan_in != 0:
-                bound = 1 / math.sqrt(fan_in)
-                torch.nn.init.uniform_(self.bias, -bound, bound)
-
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-
-        assert input.dim() == 2
-
-        parameter_list = torch.tensor(
-            [
-                int(self.approximation_enable),  # 0
-                int(self.number_of_trunc_bits),  # 1
-                int(self.number_of_frac),  # 2
-                int(self.number_of_processes),  # 3
-            ],
-            dtype=torch.int64,
-        )
-
-        output = self.functional_multi(
-            input.unsqueeze(-1).unsqueeze(-1), self.weights, parameter_list
-        )
-
-        output = output.squeeze(-1).squeeze(-1)
-
-        if self.bias is not None:
-            output += self.bias.unsqueeze(0)
-
-        return output
-
-
-class FunctionalMultiLinear(torch.autograd.Function):
-    @staticmethod
-    def forward(  # type: ignore
-        ctx,
-        input: torch.Tensor,
-        weights: torch.Tensor,
-        parameter_list: torch.Tensor,
-    ) -> torch.Tensor:
-
-        assert input.ndim == 4
-        assert input.dtype is torch.float32
-        assert input.is_contiguous() is True
-
-        assert weights.ndim == 2
-        assert weights.dtype is torch.float32
-        assert weights.is_contiguous() is True
-
-        assert input.shape[1] == weights.shape[1]
-
-        approximation_enable = bool(parameter_list[0])
-        number_of_trunc_bits = int(parameter_list[1])
-        number_of_frac = int(parameter_list[2])
-        number_of_processes = int(parameter_list[3])
-
-        assert input.device == weights.device
-
-        output = torch.zeros(
-            (input.shape[0], weights.shape[0], input.shape[2], input.shape[3]),
-            dtype=weights.dtype,
-            device=weights.device,
-            requires_grad=True,
-        )
-        assert output.is_contiguous() is True
-
-        multiplier: MultiApp = MultiApp()
-
-        multiplier.update_with_init_vector_multi_pattern(
-            input.data_ptr(),
-            weights.data_ptr(),
-            output.data_ptr(),
-            int(output.shape[0]),  # pattern
-            int(output.shape[1]),  # feature channel
-            int(output.shape[2]),  # x
-            int(output.shape[3]),  # y
-            int(input.shape[1]),  # input channel
-            int(number_of_processes),
-            bool(approximation_enable),
-            int(number_of_trunc_bits),
-            int(number_of_frac),
-        )
-
-        ctx.save_for_backward(
-            input.detach(),
-            weights.detach(),
-        )
-
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-
-        (input, weights) = ctx.saved_tensors
-
-        grad_input = (
-            grad_output.unsqueeze(2) * weights.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-        ).sum(1)
-        grad_weights = (
-            (grad_output.unsqueeze(2) * input.unsqueeze(1)).sum(0).sum(-1).sum(-1)
-        )
-        grad_parameter_list = None
-
-        return (grad_input, grad_weights, grad_parameter_list)
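For reference, FunctionalMultiLinear delegated the forward product to MultiApp but wrote the backward pass by hand. The sketch below is a stand-in under stated assumptions: it cannot call the C++ extension, so it substitutes torch.nn.functional.linear, which should match the exact (non-approximated, bias-free) path, and replays the deleted gradient formulas to check them against autograd:

    # Hypothetical check; replaces the unavailable network.CPP.PyMultiApp
    # extension with an exact matmul via torch.nn.functional.linear.
    import torch

    torch.manual_seed(0)
    n, c, o = 4, 3, 5  # batch, input features, output features

    x = torch.randn(n, c, requires_grad=True)
    w = torch.randn(o, c, requires_grad=True)

    # Reference gradients via autograd through an exact linear layer.
    grad_out = torch.randn(n, o)
    torch.nn.functional.linear(x, w).backward(grad_out)

    # Replay the deleted backward() on the (N, C, 1, 1) shapes that the
    # module's forward() produced via unsqueeze(-1).unsqueeze(-1).
    g = grad_out.unsqueeze(-1).unsqueeze(-1)      # (N, O, 1, 1)
    inp = x.detach().unsqueeze(-1).unsqueeze(-1)  # (N, C, 1, 1)
    wt = w.detach()
    grad_input = (
        g.unsqueeze(2) * wt.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    ).sum(1)
    grad_weights = (g.unsqueeze(2) * inp.unsqueeze(1)).sum(0).sum(-1).sum(-1)

    assert torch.allclose(grad_input.squeeze(-1).squeeze(-1), x.grad)
    assert torch.allclose(grad_weights, w.grad)

If the assertions pass, the hand-written broadcast-and-sum expressions are exactly the standard linear-layer gradients, restricted to the (N, C, 1, 1) shapes used here; the bias term never enters this function because the module added it outside the custom autograd path.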