
Machine Learning package allowing faster cross-validation


PiotrekGa/pruned-cv


pruned-cv

Introduction

The package implements the Pruned Cross-Validation technique, which verifies whether all folds are worth calculating. Its components may be used as standalone methods or as part of hyperparameter optimization frameworks such as Hyperopt or Optuna.

In benchmarks it proved to be roughly three times faster than Scikit-Learn's GridSearchCV and RandomizedSearchCV while yielding the same results (see the Benchmarks section).

[Image: gs_vs_pgs]
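A back-of-the-envelope way to see where such a speedup can come from: suppose N candidates are evaluated with K folds each, every fold costs the same to fit, and N_p of the candidates are pruned after k_0 folds. The symbols and numbers below are hypothetical and not taken from the benchmarks, and the cost of the pruning check itself is ignored.

```latex
\text{speedup} = \frac{N K}{(N - N_p)\,K + N_p\,k_0},
\qquad
\text{e.g. } N = 100,\; K = 8,\; N_p = 80,\; k_0 = 2:\quad
\frac{100 \cdot 8}{20 \cdot 8 + 80 \cdot 2} = \frac{800}{320} = 2.5
```

As nearly all candidates get pruned (N_p approaching N), the speedup approaches K / k_0, which is 4 in this example.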

You can find a broader overview of the motivation and methodology in this directory or, alternatively, on Medium.

Motivation

The idea was to improve the speed of hyperparameter optimization. All methods based on cross-validation require a large number of folds (8 is an absolute minimum) to ensure that the surrogate model (whether GridSearch, RandomSearch, or a Bayesian model) does not overfit to the training set.

On the other hand, Optuna proposes a mechanism of pruned learning for artificial neural networks and gradient boosting algorithms. It speeds up the search process greatly, but one issue with the method is that it prunes trials based on a single validation sample. With relatively small datasets, the variance of the model's quality may be high and lead to suboptimal hyperparameter choices. In addition, it can only help to optimize an estimator, not the whole ML pipeline.

Pruned-cv is a compromise between brute-force methods like GridSearch and more elaborate but vulnerable ones like Optuna's pruning.

How does it work?

The figure below shows an example of the correlations between the cumulative scores on successive folds and the final score:

[Image: correlations]

You may find the whole study notebook here.
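To reproduce this kind of analysis on your own trials, one option is to compute, across many trials, the Pearson correlation between the mean score over the first k folds and the mean over all folds. The sketch below does this in plain Python. It assumes that "cumulative score" means the running mean of fold scores; `cumulative_correlations` and the toy scores are hypothetical and are not taken from the study notebook.

```python
from statistics import mean
from typing import List


def pearson(xs: List[float], ys: List[float]) -> float:
    """Pearson correlation coefficient of two equally long sequences."""
    mx, my = mean(xs), mean(ys)
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = sum((x - mx) ** 2 for x in xs) ** 0.5
    sy = sum((y - my) ** 2 for y in ys) ** 0.5
    return cov / (sx * sy)


def cumulative_correlations(fold_scores: List[List[float]]) -> List[float]:
    """Hypothetical helper: fold_scores[t][i] is the score of trial t on fold i.

    Returns, for k = 1..K, the correlation between each trial's mean over its
    first k folds and its mean over all K folds.
    """
    n_folds = len(fold_scores[0])
    final = [mean(trial) for trial in fold_scores]
    return [
        pearson([mean(trial[:k]) for trial in fold_scores], final)
        for k in range(1, n_folds + 1)
    ]


# Toy per-fold scores for four trials (illustrative only).
toy_scores = [
    [0.80, 0.82, 0.79, 0.81],
    [0.70, 0.75, 0.72, 0.71],
    [0.90, 0.86, 0.88, 0.89],
    [0.60, 0.66, 0.61, 0.63],
]
corrs = cumulative_correlations(toy_scores)
```

By construction the last value is 1, since after all K folds the cumulative mean is the final score; the interesting part is how close the early values already are.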

The package exploits the fact that cumulative scores are highly correlated with the final score. In most cases, after calculating 2 folds it is possible to predict the final score very accurately. If the partial score is very poor, the cross-validation is stopped (pruned) and the final score is predicted based on the best scores obtained so far. If the partial score falls within a tolerance limit, the next folds are evaluated.
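A minimal sketch of the pruning rule described above, in plain Python. It is not the package's implementation: `pruned_cv`, its parameters, the absolute `tolerance` comparison, and the rule for predicting a pruned trial's final score (assume the gap to the best trial after k folds persists) are all assumptions made for illustration. Higher scores are taken to be better.

```python
from statistics import mean
from typing import Callable, List, Optional, Tuple


def pruned_cv(
    score_fold: Callable[[int], float],
    n_folds: int,
    best_fold_scores: Optional[List[float]] = None,
    splits_to_start_pruning: int = 2,
    tolerance: float = 0.05,
) -> Tuple[float, bool, List[float]]:
    """Hypothetical pruned cross-validation (higher score = better).

    score_fold(i) fits and scores the model on fold i.
    best_fold_scores holds the per-fold scores of the best fully evaluated trial.
    Returns (final or predicted score, whether the trial was pruned, fold scores).
    """
    scores: List[float] = []
    for i in range(n_folds):
        scores.append(score_fold(i))
        k = len(scores)
        # Skip the check for the first trial, before enough folds, and after the last fold.
        if best_fold_scores is None or k < splits_to_start_pruning or k == n_folds:
            continue
        gap = mean(scores) - mean(best_fold_scores[:k])
        if gap < -tolerance:
            # Prune: predict the final score from the best trial, assuming the gap persists.
            return mean(best_fold_scores) + gap, True, scores
    return mean(scores), False, scores


# Usage with toy per-fold scores (illustrative numbers, not benchmark results).
best = [0.90, 0.88, 0.91, 0.89]
weak = [0.70, 0.72, 0.71, 0.69]
predicted, pruned, seen = pruned_cv(lambda i: weak[i], 4, best)
```

Here the weak trial is stopped after two folds. An absolute rather than relative tolerance keeps the comparison meaningful for negative scores such as scikit-learn's `neg_mean_squared_error`.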

Installation

The package works with Python 3. To install it, clone the repository:

git clone git@github.com:PiotrekGa/pruned-cv.git

and run:

pip install -e pruned-cv

Examples

You can find example notebooks in the examples section of the repository.

Usage with Optuna

https://github.com/PiotrekGa/pruned-cv/blob/master/examples/Usage_with_Optuna.ipynb

Usage with Hyperopt

https://github.com/PiotrekGa/pruned-cv/blob/master/examples/Usage_with_Hyperopt.ipynb
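The notebooks linked above show the actual Optuna and Hyperopt integrations. The framework-agnostic sketch below only illustrates the pattern: a tracker keeps the per-fold scores of the best fully evaluated trial and prunes later trials against them, while a plain loop stands in for the sampler (in Optuna or Hyperopt the `evaluate` call would sit inside the objective function). `PrunedSearch`, `evaluate`, `fold_score`, and the tolerance rule are hypothetical and are not part of the pruned-cv API.

```python
from statistics import mean
from typing import Callable, Dict, List, Optional


class PrunedSearch:
    """Hypothetical tracker for pruned CV inside any search loop (higher = better)."""

    def __init__(self, n_folds: int, splits_to_start_pruning: int = 2, tolerance: float = 0.05):
        self.n_folds = n_folds
        self.start = splits_to_start_pruning
        self.tolerance = tolerance
        self.best: Optional[List[float]] = None  # per-fold scores of the best full trial
        self.n_fits = 0  # number of fold fits actually performed

    def evaluate(self, score_fold: Callable[[int], float]) -> float:
        scores: List[float] = []
        for i in range(self.n_folds):
            scores.append(score_fold(i))
            self.n_fits += 1
            k = len(scores)
            if self.best is not None and self.start <= k < self.n_folds:
                gap = mean(scores) - mean(self.best[:k])
                if gap < -self.tolerance:
                    return mean(self.best) + gap  # pruned: predicted final score
        # Only fully evaluated trials can become the new reference.
        if self.best is None or mean(scores) > mean(self.best):
            self.best = scores
        return mean(scores)


def fold_score(x: float, i: int) -> float:
    """Toy objective: peaks at x = 0.3, with a small deterministic per-fold wobble."""
    return 1.0 - (x - 0.3) ** 2 + 0.01 * ((i % 3) - 1)


# A plain loop stands in for the Optuna/Hyperopt sampler here.
search = PrunedSearch(n_folds=8)
results: Dict[float, float] = {}
for x in [0.3, 0.9, 0.0, 0.35, 1.0]:
    results[x] = search.evaluate(lambda i, x=x: fold_score(x, i))
best_x = max(results, key=results.get)
```

With this toy objective, the three candidates far from the optimum are stopped after two folds each, so 22 fold fits are performed instead of 5 × 8 = 40, and x = 0.3 is still selected.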

Benchmarks

You can find the benchmarks in the examples section of the repository.

Grid Search CV

https://github.com/PiotrekGa/pruned-cv/blob/master/examples/GridSearchCV_Benchmark.ipynb

Randomized Search CV

https://github.com/PiotrekGa/pruned-cv/blob/master/examples/RandomizedSearchCV_Benchmark.ipynb
