Quant-aware training: Quantize bias to 32 bits (Hard-coded for now)
guyjacob committed Jan 23, 2019
1 parent 0dac3e0 commit c98df54
Showing 2 changed files with 17 additions and 6 deletions.
16 changes: 11 additions & 5 deletions distiller/quantization/quantizer.py
@@ -176,17 +176,23 @@ def _prepare_model_impl(self):
 
             curr_parameters = dict(module.named_parameters())
             for param_name, param in curr_parameters.items():
-                if param_name.endswith('bias') and not self.quantize_bias:
-                    continue
+                # Bias is usually quantized according to the accumulator's number of bits
+                # Temporary hack: Assume that number is 32 bits and hard-code it here
+                # TODO: Handle # of bits for bias quantization as "first-class" citizen, similarly to weights
+                n_bits = qbits.wts
+                if param_name.endswith('bias'):
+                    if not self.quantize_bias:
+                        continue
+                    n_bits = 32
                 fp_attr_name = param_name
                 if self.train_with_fp_copy:
-                    hack_float_backup_parameter(module, param_name, qbits.wts)
+                    hack_float_backup_parameter(module, param_name, n_bits)
                     fp_attr_name = FP_BKP_PREFIX + param_name
-                self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, qbits.wts))
+                self.params_to_quantize.append(_ParamToQuant(module, module_name, fp_attr_name, param_name, n_bits))
 
                 param_full_name = '.'.join([module_name, param_name])
                 msglogger.info(
-                    "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, qbits.wts))
+                    "Parameter '{0}' will be quantized to {1} bits".format(param_full_name, n_bits))
 
         # If an optimizer was passed, assume we need to update it
         if self.optimizer:
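The in-code comment above explains the reasoning: a bias is usually quantized to the accumulator's bit-width, because the low-precision weight-activation products are summed into a wide (typically 32-bit) accumulator and the bias is added at that precision, using the combined weight and activation scales. The sketch below only illustrates that idea; it is not part of the commit, and the names (quantize_bias_for_accumulator, w_scale, x_scale) are hypothetical.

import torch

def quantize_bias_for_accumulator(bias_fp, w_scale, x_scale, accum_bits=32):
    # The bias shares the accumulator's scale: scale_bias = scale_weights * scale_activations
    bias_scale = w_scale * x_scale
    q_max = 2 ** (accum_bits - 1) - 1
    # Round to the nearest integer level and clamp to the signed accumulator range
    bias_q = torch.clamp(torch.round(bias_fp / bias_scale), -q_max - 1, q_max)
    return bias_q, bias_scale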
7 changes: 6 additions & 1 deletion tests/test_quantizer.py
@@ -265,7 +265,9 @@ def test_model_prep(model, optimizer, qbits, bits_overrides, explicit_expected_o
         assert ptq.fp_attr_name in named_params
 
         # Check number of bits registered correctly
-        assert ptq.num_bits == expected_qbits[ptq.module_name].wts
+        # Bias number of bits is hard-coded to 32 for now...
+        expected_n_bits = 32 if ptq.q_attr_name == 'bias' else expected_qbits[ptq.module_name].wts
+        assert ptq.num_bits == expected_n_bits
 
     q_named_modules = dict(model.named_modules())
     orig_named_modules = dict(m_orig.named_modules())
@@ -321,6 +323,9 @@ def test_param_quantization(model, optimizer, qbits, bits_overrides, explicit_ex
         quantizable = num_bits is not None
         if param_name.endswith('bias'):
            quantizable = quantizable and quantize_bias
+            # Bias number of bits is hard-coded to 32 for now...
+            if quantizable:
+                num_bits = 32
 
         if quantizable and train_with_fp_copy:
             # "param_name" and "pre_quant_param" refer to the float copy
