forked from VITA-Group/TEGNAS

"Understanding and Accelerating Neural Architecture Search with Training-Free and Theory-Grounded Metrics" by Wuyang Chen, Xinyu Gong, Yunchao Wei, Humphrey Shi, Zhicheng Yan, Yi Yang, and Zhangyang Wang


Understanding and Accelerating Neural Architecture Search with Training-Free and Theory-Grounded Metrics [PDF]

MIT licensed

Wuyang Chen*, Xinyu Gong*, Yunchao Wei, Humphrey Shi, Zhicheng Yan, Yi Yang, and Zhangyang Wang

Note

  1. This repo is still under development. Scripts are executable, but some CUDA errors may occur.
  2. Due to IP issues, we can only release the code for NAS via reinforcement learning and evolution, but not FP-NAS.

Overview

We present TEG-NAS, a generalized training-free neural architecture search method that significantly reduces the time cost of popular search methods (no gradient descent at all!) while still finding high-quality architectures.
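To make the "no gradient descent" idea concrete, here is a toy sketch of an evolutionary search loop that ranks candidates with a cheap score instead of training them. It uses aging (regularized) evolution on a cell encoding with 6 edges and 5 candidate operations per edge, which matches the size of a NAS-Bench-201 cell. Everything else is an assumption for illustration: `toy_score` is a hypothetical placeholder rather than the TEG-NAS metric, and `evolve` and `mutate` are not functions from this repository, whose evolution code may use a different variant.

```python
import collections
import random

NUM_EDGES = 6  # edges in a NAS-Bench-201 cell
NUM_OPS = 5    # candidate operations per edge


def toy_score(arch):
    """Hypothetical stand-in for a training-free score (NOT the TEG-NAS metric)."""
    return sum(arch)


def mutate(arch, rng):
    """Change the operation on one randomly chosen edge."""
    i = rng.randrange(len(arch))
    new_op = rng.choice([op for op in range(NUM_OPS) if op != arch[i]])
    return arch[:i] + (new_op,) + arch[i + 1:]


def evolve(score_fn, cycles=200, population_size=20, sample_size=5, seed=0):
    """Aging evolution: mutate the best of a random sample, drop the oldest."""
    rng = random.Random(seed)
    population = collections.deque()
    for _ in range(population_size):
        arch = tuple(rng.randrange(NUM_OPS) for _ in range(NUM_EDGES))
        population.append((score_fn(arch), arch))
    best = max(population)
    for _ in range(cycles):
        # Tournament selection: the best-scoring member of a random sample.
        parent = max(rng.sample(list(population), sample_size))
        child = mutate(parent[1], rng)
        entry = (score_fn(child), child)  # scoring only, no training
        population.append(entry)
        population.popleft()  # remove the oldest individual
        best = max(best, entry)
    return best


if __name__ == "__main__":
    score, arch = evolve(toy_score)
    print("best architecture:", arch, "score:", score)
```

The only expensive step in a conventional loop of this shape is evaluating each child; swapping that evaluation for a training-free score is what removes gradient descent from the search.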

Highlights:

  • Training-free NAS: we integrate TEG-NAS into three popular NAS methods (Reinforcement Learning, Evolution, Differentiable) and achieve extremely fast neural architecture search without a single gradient descent step.
  • Bridging the theory-application gap: we identify three training-free indicators for ranking the quality of deep networks: the condition number of their neural tangent kernels (NTKs) ("Trainability"), the number of linear regions in their input space ("Expressivity"), and the error of NTK kernel regression ("Generalization").
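As a rough illustration of the first two indicators, the sketch below computes them for a tiny one-hidden-layer ReLU network using only the Python standard library. It is a minimal sketch under stated assumptions, not the code used in this repository: the network, the helper names (`init_mlp`, `ntk_condition_number`, `count_linear_regions`), and the restriction to a 2x2 NTK over two inputs are all chosen for brevity.

```python
import math
import random


def init_mlp(in_dim, hidden, seed=0):
    """Random one-hidden-layer ReLU network: f(x) = v . relu(W x + b)."""
    rng = random.Random(seed)
    W = [[rng.gauss(0, 1) for _ in range(in_dim)] for _ in range(hidden)]
    b = [rng.gauss(0, 1) for _ in range(hidden)]
    v = [rng.gauss(0, 1) / math.sqrt(hidden) for _ in range(hidden)]
    return W, b, v


def pre_activations(params, x):
    W, b, _ = params
    return [sum(w * xj for w, xj in zip(row, x)) + bi for row, bi in zip(W, b)]


def param_gradient(params, x):
    """Gradient of the scalar output f(x) w.r.t. all parameters, flattened."""
    _, _, v = params
    grad = []
    for vi, zi in zip(v, pre_activations(params, x)):
        active = 1.0 if zi > 0 else 0.0
        grad.extend(vi * active * xj for xj in x)  # d f / d W[i][j]
        grad.append(vi * active)                   # d f / d b[i]
        grad.append(max(zi, 0.0))                  # d f / d v[i]
    return grad


def ntk_condition_number(params, x1, x2):
    """lambda_max / lambda_min of the 2x2 empirical NTK on inputs x1 and x2."""
    g1, g2 = param_gradient(params, x1), param_gradient(params, x2)
    a = sum(p * q for p, q in zip(g1, g1))
    b = sum(p * q for p, q in zip(g1, g2))
    d = sum(p * q for p, q in zip(g2, g2))
    # Closed-form eigenvalues of the symmetric matrix [[a, b], [b, d]].
    mean, radius = (a + d) / 2, math.hypot((a - d) / 2, b)
    lam_max, lam_min = mean + radius, mean - radius
    return lam_max / lam_min if lam_min > 0 else math.inf


def count_linear_regions(params, samples):
    """Distinct ReLU activation patterns hit by the samples (a lower bound)."""
    return len({tuple(z > 0 for z in pre_activations(params, x)) for x in samples})


if __name__ == "__main__":
    net = init_mlp(in_dim=2, hidden=4)
    grid = [(i / 5, j / 5) for i in range(-15, 16) for j in range(-15, 16)]
    print("NTK condition number:", ntk_condition_number(net, (1.0, 0.0), (0.0, 1.0)))
    print("linear regions found:", count_linear_regions(net, grid))
```

A smaller condition number suggests a network that is easier to train, and a larger region count suggests a more expressive one. Real networks need a batch-sized NTK and a general eigenvalue solver rather than the 2x2 closed form used here.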

Prerequisites

  • Ubuntu 16.04
  • Python 3.6.9
  • CUDA 11.0 (lower versions may work but were not tested)
  • NVIDIA GPU + CuDNN v7.6

This repository has been tested on GTX 1080Ti. Configurations may need to be changed on different platforms.

Installation

  • Clone this repo:
git clone https://github.com/chenwydj/TEGNAS.git
cd TEGNAS
  • Install dependencies:
pip install -r requirements.txt

Usage

0. Prepare the dataset

  • Please follow the guideline here to prepare the CIFAR-10/100 and ImageNet datasets, as well as the NAS-Bench-201 database.
  • Remember to set TORCH_HOME and data_paths correctly in prune_launch.py.

1. Search

NAS-Bench-201 space

Reinforcement Learning
python reinforce_launch.py --space nas-bench-201 --dataset cifar10 --gpu 0
python reinforce_launch.py --space nas-bench-201 --dataset cifar100 --gpu 0
python reinforce_launch.py --space nas-bench-201 --dataset ImageNet16-120 --gpu 0
Evolution
python evolution_launch.py --space nas-bench-201 --dataset cifar10 --gpu 0
python evolution_launch.py --space nas-bench-201 --dataset cifar100 --gpu 0
python evolution_launch.py --space nas-bench-201 --dataset ImageNet16-120 --gpu 0

DARTS space

Reinforcement Learning
python reinforce_launch.py --space darts --dataset cifar10 --gpu 0
python reinforce_launch.py --space darts --dataset imagenet-1k --gpu 0
Evolution
python evolution_launch.py --space darts --dataset cifar10 --gpu 0
python evolution_launch.py --space darts --dataset imagenet-1k --gpu 0
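
The search commands follow one pattern: a launcher script plus the --space, --dataset, and --gpu flags. For batch runs, a small helper along these lines could generate the supported combinations. The script names, flags, and dataset names are taken from the commands in this section; the helper itself (`build_command`, `all_commands`) is a hypothetical convenience and is not part of this repository.

```python
import shlex

LAUNCHERS = {
    "reinforce": "reinforce_launch.py",
    "evolution": "evolution_launch.py",
}
# Dataset names accepted per search space, as listed in this section.
DATASETS = {
    "nas-bench-201": ("cifar10", "cifar100", "ImageNet16-120"),
    "darts": ("cifar10", "imagenet-1k"),
}


def build_command(method, space, dataset, gpu=0):
    """Return the argv list for one search run, rejecting unknown combinations."""
    if method not in LAUNCHERS:
        raise ValueError("unknown method: %s" % method)
    if dataset not in DATASETS.get(space, ()):
        raise ValueError("unsupported space/dataset: %s/%s" % (space, dataset))
    return ["python", LAUNCHERS[method],
            "--space", space, "--dataset", dataset, "--gpu", str(gpu)]


def all_commands(gpu=0):
    """Every method/space/dataset combination shown in this section."""
    return [build_command(method, space, dataset, gpu)
            for space, datasets in DATASETS.items()
            for method in LAUNCHERS
            for dataset in datasets]


if __name__ == "__main__":
    # Print the commands; pass each list to subprocess.run to execute it.
    for cmd in all_commands():
        print(" ".join(shlex.quote(part) for part in cmd))
```

Building each command as a list rather than a single string avoids shell quoting problems when the list is handed to subprocess.run.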

2. Evaluation

  • For architectures searched on nas-bench-201, the accuracies are immediately available at the end of search (from the console output).
  • For architectures searched on darts, please use DARTS_evaluation to train the searched architecture from scratch and evaluate it. Genotypes of our searched architectures are listed in genotypes.py.

Citation

@inproceedings{chen2021tegnas,
  title={Understanding and Accelerating Neural Architecture Search with Training-Free and Theory-Grounded Metrics},
  author={Chen, Wuyang and Gong, Xinyu and Wei, Yunchao and Shi, Humphrey and Yan, Zhicheng and Yang, Yi and Wang, Zhangyang},
  year={2021}
}

Acknowledgement
