A PyTorch implementation of TQT.
Now available at https://pypi.org/project/tqt/0.1.2/!
Networks quantized with this package can be found at https://github.com/PannenetsF/QuantizationPool.
function is a re-implementation of torch.nn.modules. In addition to all the arguments of the original function, a quantized function takes two kinds of optional arguments: bit_width and retrain.
bit_width comes in two types: one for weights/biases and one for activations.
If retrain is True, the module is in Retrain Mode and log2_t is trainable. Otherwise it is in Static Mode, where log2_t is determined at initialization and is not trainable.
There are three ways to initialize the threshold: init_max, init_kl_j, and init_3sd.
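As a rough illustration of what two of these strategies could compute, the sketch below derives a log2 threshold from the maximum absolute value and from a three-standard-deviation bound. The formulas are assumptions based on the strategy names and the TQT paper, not a copy of the package's init_max or init_3sd, and the function names are hypothetical. The KL-J strategy is omitted.

```python
import math
import statistics


def log2_t_from_max(values):
    """log2 of the largest absolute value (MAX-style initialization).

    Assumes at least one value is non-zero.
    """
    return math.log2(max(abs(v) for v in values))


def log2_t_from_3sd(values):
    """log2 of |mean| + 3 * standard deviation (3SD-style initialization).

    Assumes the values are not all zero.
    """
    mean = statistics.mean(values)
    std = statistics.pstdev(values)  # population standard deviation
    return math.log2(abs(mean) + 3.0 * std)


# Hypothetical weight samples, used only for illustration.
weights = [-4.0, -2.0, 0.0, 2.0, 4.0]
print("max-based log2_t:", log2_t_from_max(weights))
print("3sd-based log2_t:", log2_t_from_3sd(weights))
```

In Static Mode a value like this would stay fixed after initialization, while in Retrain Mode it would only be the starting point for training.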
To initialize the weights and thresholds correctly, please follow the method described below to build a network with TQT.
xxxxx.xxx
.
As we know, an input in (k1-b-m1-p) format multiplied by weights in (k2-b-m2-p) format gives a product in (k1+k2)-b-(m1+m2)-p format. These intermediate outputs then have to be accumulated. With n intermediate outputs, the accumulator needs at least ceil(log2(n)) more bits so that the sum cannot overflow. For a typical network, a 3x3x64 kernel (n = 576) calls for 10 more bits, so the intermediate output is (k1+k2+10)-b-(m1+m2)-p. Since bit widths are normally powers of two, it is better to use 32 bits.
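The bit-width bookkeeping above can be checked with a few lines of integer arithmetic. This is a minimal sketch and not part of the tqt package. The helper names are hypothetical, and the 8-bit input and weight widths are an assumed example.

```python
def extra_accumulation_bits(n_terms: int) -> int:
    """Extra bits needed to sum n_terms products without overflow.

    For n_terms >= 1 this equals ceil(log2(n_terms)), computed with
    integers to avoid floating-point rounding.
    """
    if n_terms < 1:
        raise ValueError("n_terms must be positive")
    return (n_terms - 1).bit_length()


def next_power_of_two_width(bits: int) -> int:
    """Smallest power-of-two bit width that can hold `bits` bits."""
    return 1 << max(bits - 1, 0).bit_length()


# A 3x3x64 kernel accumulates 3 * 3 * 64 = 576 products per output value.
n = 3 * 3 * 64
extra = extra_accumulation_bits(n)

# Assumed example: 8-bit inputs (k1) and 8-bit weights (k2).
k1, k2 = 8, 8
inter_bits = k1 + k2 + extra
storage_bits = next_power_of_two_width(inter_bits)
print(n, extra, inter_bits, storage_bits)
```

With these assumed widths the accumulator needs 26 bits, which rounds up to a 32-bit intermediate width.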
So in the code, we will have:
inter = qsigned(inter, self.weight_log2_t + input_log2_t, self.inter_bit_width)
However, to keep the network interface unchanged, input_log2_t cannot be passed as an argument. Instead it is estimated from the input:
input_log2_t = math.ceil(math.log2(math.ceil(input.max())))
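To make the roles of log2_t and the bit width concrete, here is a minimal pure-Python sketch of signed power-of-two quantization in the spirit of the TQT paper. It is not the package's qsigned. The names estimate_log2_t and quantize_signed are hypothetical, and the fallback for non-positive maxima is an added assumption, since the one-line estimate fails when input.max() <= 0 (log2 is undefined there).

```python
import math


def estimate_log2_t(values):
    """Mirror ceil(log2(ceil(max))) with a guard for maxima <= 0."""
    peak = math.ceil(max(values))
    if peak <= 0:
        return 0  # assumed fallback; log2 is undefined for values <= 0
    return math.ceil(math.log2(peak))


def quantize_signed(x, log2_t, bit_width):
    """Quantize x to a signed bit_width-bit grid with threshold 2**log2_t."""
    scale = 2.0 ** (log2_t - (bit_width - 1))  # step size, a power of two
    q_min = -(2 ** (bit_width - 1))
    q_max = 2 ** (bit_width - 1) - 1
    q = round(x / scale)                       # Python rounds half to even
    q = max(q_min, min(q_max, q))              # saturate to the integer range
    return q * scale


inputs = [0.3, -1.7, 2.6, 5.2]
log2_t = estimate_log2_t(inputs)               # ceil(5.2) = 6, ceil(log2(6)) = 3
quantized = [quantize_signed(v, log2_t, 8) for v in inputs]
print(log2_t, quantized)
```

Values beyond the threshold saturate: with log2_t = 3 and 8 bits, the largest representable value is 127 / 16 = 7.9375.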
To get the output of each TQT module, the network should be flat, that is, no nn.Sequential and no nested nn.ModuleList.
It is best to use a single nn.ModuleList and append every operation to it. If some operations are themselves an nn.ModuleList of operations, you can use .extend to keep the network flat.
We often need to retrain a network, and lambda makes this quick. As you can see in the file lenet.py, by changing the wrapper a net can easily be converted into a quantized one.
Just 3 steps!
- Add hooks to store the outputs.
- Adjust the thresholds via tqt.threshold.
- Remove the hooks.
Suppose you have a pretrained model and it is hard to change all the keys in its state dictionary. More often than not, it contains many nn.Module instances rather than an nn.ModuleList. A dirty but useful way is to simply change import torch.nn as nn to import tqt.function as nn. You get a quantization-style network with all previous keys unchanged!
All you need to do is add a list self.proc to the network module.
Through tqt.threshold.add_hook_general, we can add hooks to any network, as long as it has a list containing all the operations used in forward.
Let's look at an example:
# noquant.py
import torch.nn as nn

class myNet(nn.Module):
    def __init__(self, args):
        super().__init__()  # required before assigning sub-modules
        self.op1 = ...
        self.op2 = ...
        if args:
            self.op_args = ...
        ...

    def forward(self, x):
        ...
and
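# quant.py
import tqt.function as nn

class myNet(nn.Module):
    def __init__(self, args):
        super().__init__()  # required before assigning sub-modules
        # Names of all operations used in forward, in order.
        self.proc = ['op1', 'op2']
        self.op1 = ...
        self.op2 = ...
        if args:
            self.op_args = ...
            self.proc.append('op_args')
        ...

    def forward(self, x):
        ...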
We can load and retrain by:
# main.py
import tqt
from noquant import myNet as oNet
from quant import myNet as qNet

handler = tqt.threshold.hook_handler
train(oNet)
...
tqt.threshold.add_hook(oNet, 'oNet', handler)
# strict=False tolerates keys that exist in only one of the two networks.
qNet.load_state_dict(oNet.state_dict(), strict=False)
for (netproc, qnetproc) in zip(funct_list, qfunct_list):
    tqt.threshold.init.init_network(netproc, qnetproc, show=True)
retrain(qNet)
With a network built by the method mentioned above, we may need both a quantized and a non-quantized version. So we implemented tqt.utils.make_net_quant_or_not to switch its mode easily.
Normally we disable the quantization of batchnorm modules. You can simply exclude them in tqt.utils.make_net_quant_or_not like:
tqt.utils.make_net_quant_or_not(net,
                                'net',
                                quant=True,
                                exclude=[torch.nn.BatchNorm2d],
                                show=True)
We always need to analyze the activations and weights to choose a proper way to quantize the network. We implemented some functions to do this. It is recommended to do this with TensorBoard.
tqt.threshold.get_hook returns all hook outputs collected during the forward pass, each paired with its module name as a tuple.
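import tqt
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()
net = QNet()
tqt.utils.make_net_quant_or_not(net, quant=True)
tqt.threshold.add_hook(net, 'net', tqt.threshold.hook_handler)
net.cuda()
# A single batch is enough to fill the hooks.
for i, (images, labels) in enumerate(data_test_loader):
    net(images.cuda())
    break
out = tqt.threshold.get_hook(net, 'net', show=True)
for i in out:
    # i[0] is the module name, i[1] is the stored output tensor.
    print(i[0], i[1].shape)
    writer.add_histogram(i[0], i[1].cpu().data.flatten().detach().numpy())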
Similarly, the weights can be obtained from net.named_parameters().
It would be great if you helped make this project better! There are several ways to contribute:
- To start with, issues and feature requests let the maintainers know what is wrong or what essential features are missing.
- If you use the package in your work/repo, just cite the repo and add a dependency note!
- You can add functions from torch.nn such as HardTanh, and feel free to open a pull request! The code style is simple, as shown here.
The initial version of tqt-torch was developed by Jinyu Bai.
The beta version was tested by Jinghan Xu, based on whose feedback a lot of bugs were fixed.
The original paper can be found on arXiv: Trained Quantization Thresholds for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.