torch-ort_test_api.py
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
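"""Tests for the public torch-ort API: seeding, ORTModule debug options (ONNX model saving and
log levels), and the load-balancing distributed data sampler utilities."""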
import copy
import os
import pytest
import tempfile
import torch
import _test_helpers
from torch_ort import ORTModule, DebugOptions, LogLevel, set_seed
from torch_ort.utils.data import LoadBalancingDistributedSampler, LoadBalancingDistributedBatchSampler

class NeuralNetSinglePositionalArgument(torch.nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNetSinglePositionalArgument, self).__init__()

        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(hidden_size, num_classes)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, input1):
        out = self.fc1(input1)
        out = self.relu(out)
        out = self.fc2(out)
        return self.dropout(out)

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, samples):
        self.samples = samples

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)

def test_set_seed():
    N, D_in, H, D_out = 64, 784, 500, 10
    input = torch.randn(N, D_in)
    orig_model = NeuralNetSinglePositionalArgument(D_in, H, D_out)

    predictions = []
    for _ in range(10):
        set_seed(1)
        model = ORTModule(copy.deepcopy(orig_model))
        prediction = model(input)
        predictions.append(prediction)

    # All predictions must match
    for pred in predictions:
        _test_helpers.assert_values_are_close(predictions[0], pred, rtol=1e-9, atol=0.0)

@pytest.mark.parametrize("mode", ['training', 'inference'])
def test_debug_options_save_onnx_models_os_environment(mode):
    N, D_in, H, D_out = 64, 784, 500, 10

    # Create a temporary directory for the onnx_models
    with tempfile.TemporaryDirectory() as temporary_dir:
        os.environ['ORTMODULE_SAVE_ONNX_PATH'] = temporary_dir
        model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
        ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_model'))
        if mode == 'inference':
            ort_model.eval()
        x = torch.randn(N, D_in)
        _ = ort_model(x)

        # assert that the onnx models have been saved
        assert os.path.exists(os.path.join(temporary_dir, f"my_model_torch_exported_{mode}.onnx"))
        assert os.path.exists(os.path.join(temporary_dir, f"my_model_optimized_{mode}.onnx"))

        del os.environ['ORTMODULE_SAVE_ONNX_PATH']

@pytest.mark.parametrize("mode", ['training', 'inference'])
def test_debug_options_save_onnx_models_cwd(mode):
    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_cwd_model'))
    if mode == 'inference':
        ort_model.eval()
    x = torch.randn(N, D_in)
    _ = ort_model(x)

    # assert that the onnx models have been saved
    if mode == 'training':
        assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_optimized_pre_grad_{mode}.onnx"))
        os.remove(os.path.join(os.getcwd(), f"my_cwd_model_optimized_pre_grad_{mode}.onnx"))
    assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx"))
    assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx"))
    assert os.path.exists(os.path.join(os.getcwd(), f"my_cwd_model_execution_model_{mode}.onnx"))

    os.remove(os.path.join(os.getcwd(), f"my_cwd_model_torch_exported_{mode}.onnx"))
    os.remove(os.path.join(os.getcwd(), f"my_cwd_model_optimized_{mode}.onnx"))
    os.remove(os.path.join(os.getcwd(), f"my_cwd_model_execution_model_{mode}.onnx"))

def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir():
    non_existent_directory = None
    with tempfile.TemporaryDirectory() as temporary_dir:
        non_existent_directory = temporary_dir

    # The temporary directory is removed on exiting the with block, so the saved path no longer
    # exists and cannot be written to.
    os.environ['ORTMODULE_SAVE_ONNX_PATH'] = non_existent_directory
    with pytest.raises(Exception) as ex_info:
        _ = DebugOptions(save_onnx=True, onnx_prefix='my_model')
    assert f"Directory {non_existent_directory} is not writable." in str(ex_info.value)
    del os.environ['ORTMODULE_SAVE_ONNX_PATH']

def test_debug_options_save_onnx_models_validate_fail_on_non_str_prefix():
    prefix = 23
    with pytest.raises(Exception) as ex_info:
        _ = DebugOptions(save_onnx=True, onnx_prefix=prefix)
    assert f"Expected name prefix of type str, got {type(prefix)}." in str(ex_info.value)

def test_debug_options_save_onnx_models_validate_fail_on_no_prefix():
    with pytest.raises(Exception) as ex_info:
        _ = DebugOptions(save_onnx=True)
    assert "onnx_prefix must be provided when save_onnx is set." in str(ex_info.value)

def test_debug_options_log_level():
    # NOTE: This test will output verbose logging
    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(model, DebugOptions(log_level=LogLevel.VERBOSE))
    x = torch.randn(N, D_in)
    _ = ort_model(x)

    # assert that the logging is done in verbose mode
    assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.VERBOSE

def test_debug_options_log_level_os_environment():
    # NOTE: This test will output info logging
    os.environ['ORTMODULE_LOG_LEVEL'] = 'INFO'
    N, D_in, H, D_out = 64, 784, 500, 10
    model = NeuralNetSinglePositionalArgument(D_in, H, D_out)
    ort_model = ORTModule(model)
    x = torch.randn(N, D_in)
    _ = ort_model(x)

    # assert that the logging is done in info mode
    assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.INFO
    del os.environ['ORTMODULE_LOG_LEVEL']

def test_debug_options_log_level_validation_fails_on_type_mismatch():
    log_level = 'some_string'
    with pytest.raises(Exception) as ex_info:
        _ = DebugOptions(log_level=log_level)
    assert f"Expected log_level of type LogLevel, got {type(log_level)}." in str(ex_info.value)
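
# A minimal sketch of the DebugOptions configuration the tests above exercise, assuming a
# torch.nn.Module instance named `model` and that the keyword arguments shown in these tests
# may be combined in a single DebugOptions:
#
#     debug_options = DebugOptions(save_onnx=True, onnx_prefix='my_model', log_level=LogLevel.INFO)
#     ort_model = ORTModule(model, debug_options)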

def test_load_balancing_data_sampler_balances_import():
    # Each sample is paired with a random "complexity" score that the sampler uses to balance
    # work across ranks.
    samples_and_complexities = [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)]
    dataset = MyDataset(samples_and_complexities)

    def complexity_fn(sample):
        return sample[1]

    data_sampler = LoadBalancingDistributedSampler(
        dataset, complexity_fn=complexity_fn, world_size=2, rank=0, shuffle=False
    )

    batch_size = 12

    def batch_fn(indices):
        nonlocal batch_size
        batches = []
        for batch_index_begin in range(0, len(indices), batch_size):
            batch_index_end = min(batch_index_begin + batch_size, len(indices))
            batches.append(indices[batch_index_begin:batch_index_end])
        return batches

    # The test passes if the sampler and batch sampler can be constructed without raising.
    batch_sampler = LoadBalancingDistributedBatchSampler(data_sampler, batch_fn)
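
# A minimal usage sketch for the samplers constructed above, assuming
# LoadBalancingDistributedBatchSampler follows the standard torch.utils.data batch-sampler
# protocol (yielding lists of dataset indices); `make_balanced_loader` is a hypothetical helper,
# not part of torch-ort:
#
#     def make_balanced_loader(dataset, complexity_fn, batch_fn, world_size, rank):
#         sampler = LoadBalancingDistributedSampler(
#             dataset, complexity_fn=complexity_fn, world_size=world_size, rank=rank, shuffle=False
#         )
#         batch_sampler = LoadBalancingDistributedBatchSampler(sampler, batch_fn)
#         # DataLoader accepts any batch sampler that yields lists of dataset indices.
#         return torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)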