Add files via upload
This commit is contained in:
parent
9fa12c258e
commit
4173e3306e
1 changed files with 17 additions and 0 deletions
17
new_pipeline/functions/get_torch_device.py
Normal file
17
new_pipeline/functions/get_torch_device.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
import torch
|
||||
import logging
|
||||
|
||||
|
||||
def get_torch_device(mylogger: logging.Logger, force_to_cpu: bool) -> torch.device:
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device_name: str = "cuda:0"
|
||||
else:
|
||||
device_name = "cpu"
|
||||
|
||||
if force_to_cpu:
|
||||
device_name = "cpu"
|
||||
|
||||
mylogger.info(f"Using device: {device_name}")
|
||||
device: torch.device = torch.device(device_name)
|
||||
return device
|
Loading…
Reference in a new issue