Neural Network Pruning by Gradient Descent

This repository will contain the PyTorch implementation of:

Neural Network Pruning by Gradient Descent
Zhang Zhang, Ruyi Tao, Jiang Zhang*

(*: Corresponding author)
Download PDF

Abstract:

The rapid increase in the parameters of deep learning models has led to significant costs, challenging computational efficiency and model interpretability. In this paper, we introduce a novel and straightforward neural network pruning framework that incorporates the Gumbel-Softmax technique. This framework enables the simultaneous optimization of a network's weights and topology in an end-to-end process using stochastic gradient descent. Empirical results demonstrate its exceptional compression capability, maintaining high accuracy on the MNIST dataset with only 0.15% of the original network parameters. Moreover, our framework enhances neural network interpretability, not only by allowing easy extraction of feature importance directly from the pruned network but also by enabling visualization of feature symmetry and of the pathways along which information propagates from features to outcomes. Although the pruning strategy is learned through deep learning, it is surprisingly intuitive and understandable, focusing on selecting key representative features and exploiting data patterns to achieve extremely sparse pruning. We believe our method opens a promising new avenue for deep learning pruning and for the creation of interpretable machine learning systems.
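The abstract describes pruning with the Gumbel-Softmax technique so that weights and topology are optimized together by gradient descent. Below is a minimal PyTorch sketch of that general idea under our own assumptions, not the code in this repository: every weight gets a pair of keep/drop logits, a hard Gumbel-Softmax sample gates the weight during training, and an added sparsity penalty pushes connections toward "drop". The names `GumbelGatedLinear`, `gate_logits` and `sparsity_penalty`, the penalty weight `lam`, the temperature and the toy data are all hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelGatedLinear(nn.Module):
    """Hypothetical linear layer whose every weight has a learnable on/off gate."""

    def __init__(self, in_features: int, out_features: int, tau: float = 1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Gate logits of shape (out, in, 2): index 0 = keep, index 1 = drop.
        self.gate_logits = nn.Parameter(torch.zeros(out_features, in_features, 2))
        self.tau = tau

    def mask(self) -> torch.Tensor:
        if self.training:
            # Hard one-hot sample forward, straight-through soft gradient backward,
            # so both the weights and the gate logits receive gradients.
            sample = F.gumbel_softmax(self.gate_logits, tau=self.tau, hard=True, dim=-1)
            return sample[..., 0]
        # Deterministic mask at evaluation time: keep if the "keep" logit wins.
        return (self.gate_logits[..., 0] > self.gate_logits[..., 1]).float()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.linear.weight * self.mask(), self.linear.bias)

    def density(self) -> float:
        # Fraction of connections currently kept.
        return self.mask().mean().item()


def sparsity_penalty(model: nn.Module) -> torch.Tensor:
    # Mean probability of keeping a connection, averaged over all gated layers.
    probs = [F.softmax(m.gate_logits, dim=-1)[..., 0].mean()
             for m in model.modules() if isinstance(m, GumbelGatedLinear)]
    return torch.stack(probs).mean()


torch.manual_seed(0)
model = nn.Sequential(GumbelGatedLinear(10, 16), nn.ReLU(), GumbelGatedLinear(16, 2))
opt = torch.optim.Adam(model.parameters(), lr=1e-2)

x = torch.randn(256, 10)
y = (x[:, 0] > 0).long()  # toy labels that depend only on feature 0
lam = 0.5                 # hypothetical weight of the sparsity term

for step in range(300):
    loss = F.cross_entropy(model(x), y) + lam * sparsity_penalty(model)
    opt.zero_grad()
    loss.backward()
    opt.step()

model.eval()
print([round(m.density(), 3) for m in model.modules() if isinstance(m, GumbelGatedLinear)])
```

After training, the evaluation-mode mask is deterministic, so the printed densities are the fraction of connections that survive in each layer. The bias is left ungated here for simplicity.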

Requirements

  • Python 3.7.0
  • PyTorch 2.0.1

To understand how our model works, please see this tutorial:

Tutorial Here

From this tutorial, you will see how to prune the network so that the relationship between features and labels can be read directly from it, as the figure below shows.
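As a rough illustration of reading feature-label relationships off a pruned network, the sketch below assumes each layer's pruning result is a binary mask in the `nn.Linear` shape `(out_features, in_features)`. Multiplying the masks counts the surviving paths from every input feature to every output. The example masks and the helpers `feature_to_label_paths` and `feature_importance` are hypothetical, and defining importance as the number of surviving paths is our assumption, not necessarily the definition used in the paper.

```python
import torch

# Hypothetical binary masks of a pruned 2-layer MLP (4 inputs -> 3 hidden -> 2 outputs),
# in nn.Linear weight layout: (out_features, in_features).
mask1 = torch.tensor([[1, 0, 0, 0],
                      [0, 0, 1, 0],
                      [0, 0, 0, 0]], dtype=torch.float32)
mask2 = torch.tensor([[1, 0, 0],
                      [0, 1, 1]], dtype=torch.float32)


def feature_to_label_paths(masks):
    """Count surviving paths from each input feature to each output unit."""
    paths = masks[0]
    for m in masks[1:]:
        paths = m @ paths  # path counts from the inputs to the next layer's units
    return paths.T  # shape: (in_features, out_features)


def feature_importance(masks):
    # A feature whose paths were all pruned cannot influence any output.
    return feature_to_label_paths(masks).sum(dim=1)


paths = feature_to_label_paths([mask1, mask2])
print(paths)
print(feature_importance([mask1, mask2]))
```

In this toy example, feature 0 reaches only output 0 and feature 2 reaches only output 1, while features 1 and 3 have been pruned away entirely.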

Cite

If you use this code in your own work, please cite our paper:

Zhang, Zhang, Ruyi Tao, and Jiang Zhang. "Neural Network Pruning by Gradient Descent." arXiv preprint arXiv:2311.12526 (2023).