From d51738832cc586a642a36dab0e0d590540c8beb4 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Thu, 2 Feb 2023 19:23:09 +0100 Subject: [PATCH] Add files via upload --- network/Conv2dApproximation.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/network/Conv2dApproximation.py b/network/Conv2dApproximation.py index 9276d4f..c37bdb7 100644 --- a/network/Conv2dApproximation.py +++ b/network/Conv2dApproximation.py @@ -1,11 +1,14 @@ import torch import math -from network.CPP.PyMultiApp import MultiApp +from network.PyMultiplicationApproximationCPU import MultiplicationApproximationCPU +from network.PyMultiplicationApproximationGPU import MultiplicationApproximationGPU global_multiapp_gpu_setting: list[torch.Tensor] = [] global_multiapp_size: list[torch.Tensor] = [] -global_multiapp_cpp: list[MultiApp] = [] +global_multiapp_cpp: list[ + MultiplicationApproximationCPU | MultiplicationApproximationGPU +] = [] class Conv2dApproximation(torch.nn.Module): @@ -77,7 +80,11 @@ class Conv2dApproximation(torch.nn.Module): global_multiapp_gpu_setting.append(torch.tensor([0])) global_multiapp_size.append(torch.tensor([0, 0, 0, 0])) - global_multiapp_cpp.append(MultiApp()) + if device == torch.device("cpu"): + global_multiapp_cpp.append(MultiplicationApproximationCPU()) + else: + global_multiapp_cpp.append(MultiplicationApproximationGPU()) + self.multiapp_gpu_setting_position = len(global_multiapp_gpu_setting) - 1 self.multiapp_cpp_position = len(global_multiapp_cpp) - 1