-
Notifications
You must be signed in to change notification settings - Fork 4
/
utils.py
105 lines (86 loc) · 2.34 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import sys
from typing import Optional
import molgrid
import torch
def print_args(
    args: argparse.Namespace, header: Optional[str] = None, stream=sys.stdout
):
    """
    Print command line arguments to stream.

    Parameters
    ----------
    args: argparse.Namespace
        Command line arguments
    header: str
        Optional header line printed before the arguments
    stream:
        Output stream
    """
    if header is not None:
        print(header, file=stream)
    for name, value in vars(args).items():
        # Floats are printed in scientific notation; everything else via repr()
        if isinstance(value, float):
            print(f"{name}: {value:.5E}", file=stream)
        else:
            print(f"{name} = {value!r}", file=stream)
    # Flush stream without emitting an extra newline
    print("", end="", file=stream, flush=True)
def log_print(
    metrics,
    title: Optional[str] = None,
    epoch: Optional[int] = None,
    stream=sys.stdout,
):
    """
    Print metrics to the console.

    Parameters
    ----------
    metrics:
        Dictionary of metrics (name to numeric value)
    title: str
        Title to print
    epoch: int
        Epoch number
    stream:
        Output stream
    """
    if title is not None and epoch is not None:
        print(f">>> {title} - Epoch[{epoch}] <<<", file=stream)
        indent = "  "
    else:
        indent = ""

    # TODO: Order metrics?
    loss: float = 0.0
    has_loss = False
    for name, value in metrics.items():
        print(f"{indent}{name}: {value:.5f}", file=stream)
        # Accumulate the total loss over all loss components
        if "loss" in name.lower():
            loss += value
            has_loss = True

    # Print the total loss whenever any loss component was present,
    # even if the components sum to zero (or a negative value)
    if has_loss:
        print(f"{indent}Loss: {loss:.5f}", file=stream)

    # Flush stream without emitting an extra newline
    print("", end="", file=stream, flush=True)
def set_device(device_name: str) -> torch.device:
    """
    Set the device to use.

    Parameters
    ----------
    device_name: str
        Name of the device to use (:code:`"cpu"`, :code:`"cuda"`, :code:`"cuda:0"`, ...)

    Returns
    -------
    torch.device
        PyTorch device

    Notes
    -----
    This function also sets the global device for :code:`molgrid` so that the
    :code:`molgrid.ExampleProvider` works on the correct device.

    https://github.com/gnina/libmolgrid/issues/43
    """
    # TODO: Set global PyTorch device?
    device = torch.device(device_name)
    if "cuda" in device_name:
        # torch.device parses the ordinal robustly (handles multi-digit
        # indices such as "cuda:10", unlike int(device_name[-1])).
        # device.index is None for a bare "cuda"; default to GPU 0.
        molgrid.set_gpu_device(device.index if device.index is not None else 0)
    return device