diff --git a/network/Conv2dApproximation.py b/network/Conv2dApproximation.py index 9276d4f..c37bdb7 100644 --- a/network/Conv2dApproximation.py +++ b/network/Conv2dApproximation.py @@ -1,11 +1,14 @@ import torch import math -from network.CPP.PyMultiApp import MultiApp +from network.PyMultiplicationApproximationCPU import MultiplicationApproximationCPU +from network.PyMultiplicationApproximationGPU import MultiplicationApproximationGPU global_multiapp_gpu_setting: list[torch.Tensor] = [] global_multiapp_size: list[torch.Tensor] = [] -global_multiapp_cpp: list[MultiApp] = [] +global_multiapp_cpp: list[ + MultiplicationApproximationCPU | MultiplicationApproximationGPU +] = [] class Conv2dApproximation(torch.nn.Module): @@ -77,7 +80,11 @@ class Conv2dApproximation(torch.nn.Module): global_multiapp_gpu_setting.append(torch.tensor([0])) global_multiapp_size.append(torch.tensor([0, 0, 0, 0])) - global_multiapp_cpp.append(MultiApp()) + if device == torch.device("cpu"): + global_multiapp_cpp.append(MultiplicationApproximationCPU()) + else: + global_multiapp_cpp.append(MultiplicationApproximationGPU()) + self.multiapp_gpu_setting_position = len(global_multiapp_gpu_setting) - 1 self.multiapp_cpp_position = len(global_multiapp_cpp) - 1