diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 96183bb7a172..9e5f8c1e2311 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -258,7 +258,7 @@ def _smooth_distribution(p, eps=0.0001): # pylint: disable=line-too-long -def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): +def _get_optimal_threshold(arr, quantized_dtype, num_bins=8001, num_quantized_bins=255): """Given a dataset, find the optimal threshold for quantizing it. The reference distribution is `q`, and the candidate distribution is `p`. `q` is a truncated version of the original distribution. @@ -285,6 +285,10 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): max_val = np.max(arr) th = max(abs(min_val), abs(max_val)) + if min_val >= 0 and quantized_dtype in ['auto', 'uint8']: + # We need to move negative bins to positive bins to fit uint8 range. + num_quantized_bins = num_quantized_bins * 2 + 1 + hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th)) zero_bin_idx = num_bins // 2 num_half_quantized_bins = num_quantized_bins // 2 @@ -348,7 +352,7 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255): # pylint: enable=line-too-long -def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logger=None): +def _get_optimal_thresholds(nd_dict, quantized_dtype, num_bins=8001, num_quantized_bins=255, logger=None): """Given a ndarray dict, find the optimal threshold for quantizing each value of the key.""" if stats is None: raise ImportError('scipy.stats is required for running entropy mode of calculating' @@ -364,7 +368,7 @@ def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logg for name in layer_names: assert name in nd_dict min_val, max_val, min_divergence, opt_th = \ - _get_optimal_threshold(nd_dict[name], num_bins=num_bins, + _get_optimal_threshold(nd_dict[name], quantized_dtype, num_bins=num_bins, num_quantized_bins=num_quantized_bins) del nd_dict[name] # release the memory of ndarray if min_val < 0: @@ -521,7 +525,7 @@ def quantize_model(sym, arg_params, aux_params, logger=logger) logger.info('Collected layer outputs from FP32 model using %d examples' % num_examples) logger.info('Calculating optimal thresholds for quantization') - th_dict = _get_optimal_thresholds(nd_dict, logger=logger) + th_dict = _get_optimal_thresholds(nd_dict, quantized_dtype, logger=logger) elif calib_mode == 'naive': th_dict, num_examples = _collect_layer_output_min_max( mod, calib_data, include_layer=calib_layer, max_num_examples=num_calib_examples, diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index e4cc277d0ae3..e2457c7a4d50 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -713,10 +713,11 @@ def test_optimal_threshold_adversarial_case(): # The worst case for the optimal_threshold function is when the values are concentrated # at one edge: [0, 0, ..., 1000]. (histogram) # We want to make sure that the optimal threshold in this case is the max. - arr = np.array([2]*1000) - res = mx.contrib.quant._get_optimal_threshold(arr, num_quantized_bins=5) - # The threshold should be 2. - assert res[3] - 2 < 1e-5 + arr = np.array([2] * 1000) + for dtype in ['uint8', 'int8', 'auto']: + res = mx.contrib.quant._get_optimal_threshold(arr, dtype, num_quantized_bins=5) + # The threshold should be 2. + assert res[3] - 2 < 1e-5 @with_seed() @@ -728,11 +729,12 @@ def get_threshold(nd): max_nd = mx.nd.max(nd) return mx.nd.maximum(mx.nd.abs(min_nd), mx.nd.abs(max_nd)).asnumpy() - nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64)} - expected_threshold = get_threshold(nd_dict['layer1']) - th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict) - assert 'layer1' in th_dict - assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) + for dtype in ['uint8', 'int8', 'auto']: + nd_dict = {'layer1': mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64)} + expected_threshold = get_threshold(nd_dict['layer1']) + th_dict = mx.contrib.quant._get_optimal_thresholds(nd_dict, dtype) + assert 'layer1' in th_dict + assert_almost_equal(np.array([th_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) if __name__ == "__main__":