From 4173e3306ef6714b3d59d2c435da71dafcc10a09 Mon Sep 17 00:00:00 2001 From: David Rotermund <54365609+davrot@users.noreply.github.com> Date: Mon, 26 Feb 2024 13:01:09 +0100 Subject: [PATCH] Add files via upload --- new_pipeline/functions/get_torch_device.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 new_pipeline/functions/get_torch_device.py diff --git a/new_pipeline/functions/get_torch_device.py b/new_pipeline/functions/get_torch_device.py new file mode 100644 index 0000000..9eec5e9 --- /dev/null +++ b/new_pipeline/functions/get_torch_device.py @@ -0,0 +1,17 @@ +import torch +import logging + + +def get_torch_device(mylogger: logging.Logger, force_to_cpu: bool) -> torch.device: + + if torch.cuda.is_available(): + device_name: str = "cuda:0" + else: + device_name = "cpu" + + if force_to_cpu: + device_name = "cpu" + + mylogger.info(f"Using device: {device_name}") + device: torch.device = torch.device(device_name) + return device