Are GridSearch using the update! method? #82
Closed · leonardtschora opened this issue Sep 23, 2020 · 3 comments

@leonardtschora commented:
Hi everyone,

While benchmarking some toy grid searches, I obtained odd results, and it seems to me that performing a grid search with a TunedModel is slower than it should be.

The idea is to run a grid search over a model that implements the update method and to avoid re-fitting models from scratch for each sampled hyper-parameter set, more precisely by arranging the grid search so that it only changes one hyper-parameter per iteration.

Here is some sample code on a toy problem, using EnsembleModel and DecisionTreeRegressor. The idea is to vary the number of trees in the ensemble (the n field of the EnsembleModel) and find the optimal value. A naïve approach would restart training from scratch for each new number of trees; a smarter approach would start at the lowest number and add one tree at each iteration. The updating cost of the ensemble is then very low (only one new tree to fit), and we expect the grid search to be much faster.

using MLJ, BenchmarkTools, MLJModels

X = MLJ.table(rand(100, 10));
y = 2X.x1 - X.x2 + 0.05*rand(100);
tree_model = @load  DecisionTreeRegressor
RNG = 90125
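
Before the benchmarks, here is a minimal sketch (reusing the setup above; the names ens and mach0 are only for this sketch) of the warm-restart behaviour I am relying on: mutating n on an already-fitted machine and calling fit! again should dispatch to update and only fit the additional trees.

ens = EnsembleModel(atom=tree_model, n=10, rng=RNG)
mach0 = machine(ens, X, y)
fit!(mach0, verbosity=0)   # cold start: all 10 trees are fitted
ens.n += 1
fit!(mach0, verbosity=0)   # warm restart via update: only the 1 extra tree is fitted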

Solving this problem using the MLJ interface:

# Tuned Model
forest_model = EnsembleModel(atom=tree_model, rng=RNG)
r = range(forest_model, :n; values=[i for i in 4:103]);
all_rows = collect(1:100)

self_tuning_forest_model = TunedModel(model=forest_model,
                                      tuning=Grid(shuffle=false),
                                      resampling=[(all_rows, all_rows)],
                                      range=r,
                                      measure=rms);

self_tuning_forest = machine(self_tuning_forest_model, X, y);
fit!(self_tuning_forest, verbosity=1)
m1 = self_tuning_forest.report.best_history_entry.measurement[1]
n1 = self_tuning_forest.report.best_history_entry.model.n

@btime begin
    self_tuning_forest = machine(self_tuning_forest_model, X, y);
    fit!(self_tuning_forest, verbosity=0)
end

Then, I implemented two manual grid searches. The first one is not intelligent and restarts from scratch at each iteration; the second one only mutates the n field of the EnsembleModel and updates the associated machine.

# Get the values of n tried during tuning
values = self_tuning_forest.report.plotting.parameter_values

# Dumb Grid
results = Vector{Float64}(undef, 100)
forest_model = EnsembleModel(atom=tree_model, rng=RNG)
for i in values
    forest_model.n = i
    mach = machine(forest_model, X, y)        
    fit!(mach, verbosity=0)
    results[i-3] = rms(predict(mach, X), y)
end

m3, ind = findmin(results)
n3 = values[ind]

@btime begin
    for i in values        
        forest_model.n = i
        mach = machine(forest_model, X, y)        
        fit!(mach, verbosity=0)
        rms(predict(mach, X), y)
    end
end

# Smart retraining
results = Vector{Float64}(undef, 100)
forest_model = EnsembleModel(atom=tree_model, rng=RNG)
mach = machine(forest_model, X, y)
for i in values
    forest_model.n = i
    fit!(mach, verbosity=0)
    results[i-3] = rms(predict(mach, X), y)
end

m2, ind = findmin(results)
n2 = values[ind]

@btime begin
    mach = machine(forest_model, X, y)
    for i in values
        forest_model.n = i
        fit!(mach, verbosity=0)
        rms(predict(mach, X), y)
    end
end

The obtained results are the following:

Measure       | Tuned Model | Dumb Grid | Smart Grid
--------------|-------------|-----------|-----------
Fitting time  | 737 ms      | 737 ms    | 79 ms
Metric (rms)  | 0.097       | 0.097     | 0.099
Optimal n     | 11          | 11        | 4

Given these results, it seems that the grid search performed by a TunedModel is a naïve search that retrains every new model from scratch instead of updating it. We can also see that the grid search can be made roughly ten times faster on this toy example.

I started delving into the implementation details and found that the problem does not come from the Grid implementation. Grid creates the list of models to train by cloning and mutating the original, but if we mutate the model field of a machine and set it to one of these clones, the machine should still update itself, as in this example:

### Cloning model, keeping the machine
forest_model = EnsembleModel(atom=tree_model, rng=RNG)
mach = machine(forest_model, X, y)        
fit!(mach)

forest_model_2 = deepcopy(forest_model)
forest_model_2.n += 1
mach.model = forest_model_2
fit!(mach)

Then I started looking at the TunedModel code, but things become much more complicated there and I'm afraid I won't be able to understand it alone.

As always, thanks for the time and support you provide me.

ablaom (Member) commented Oct 13, 2020:

Sorry, I guess this one slipped under the radar.

I've just skimmed your comment, but here's a quick reply which hopefully addresses your point:

In general, because one is resampling to get performance estimates for each model (i.e. for each set of hyper-parameters), you can't make this "intelligent", except in the special case that resampling isa Holdout with no randomisation, e.g. resampling=Holdout(), or, as in your case, that resampling consists of a single train/test pair. However, update for the Resampler model wrapper (a private object) is only overloaded for Holdout, not for your special case, and hence is slow. One could overload it for your case as well, but this is probably not a big use-case. PR welcome.

Try your benchmarks with resampling=Holdout() and see if you get an improvement.

Does that make sense?

leonardtschora (Author) commented:

Hi, thanks for your reply.

I have tried using Holdout resampling and it yields the expected result: about 80 ms, showing that it performs intelligent refitting.
Here is the code:

self_tuning_forest_model_holdout = TunedModel(model=forest_model,
                                      tuning=Grid(shuffle=false),
                                      resampling=Holdout(; fraction_train=0.99, shuffle=false),
                                      range=r,
                                      measure=rms);
self_tuning_forest_holdout = machine(self_tuning_forest_model_holdout, X, y);
fit!(self_tuning_forest_holdout, verbosity=1)

@btime begin
    self_tuning_forest_holdout = machine(self_tuning_forest_model_holdout, X, y);
    fit!(self_tuning_forest_holdout, verbosity=0)
end

I think the strategy here would be to iterate first over the different (train, test) splits and then over the hyper-parameters (this is an example from one of my own grid searches):

    # my_sampler, MyModel, n_cv and configurations are placeholders for my real setup
    pairs = train_test_pairs(my_sampler, X, y)     # one (train, val) pair per fold
    errors = Matrix{Float64}(undef, n_cv, length(configurations))

    Threads.@threads for k in 1:n_cv
        train_indices, val_indices = pairs[k]
        train_set = selectrows(X, train_indices)
        train_labels = selectrows(y, train_indices)

        val_set = selectrows(X, val_indices)
        val_labels = selectrows(y, val_indices)

        model = MyModel()
        mach = machine(model, train_set, train_labels)

        for (j, config) in enumerate(configurations)
            model.n = config               # mutate the hyper-parameter being tuned
            fit!(mach, verbosity=0)        # triggers update, not a cold restart
            errors[k, j] = rms(predict(mach, val_set), val_labels)
        end
    end

Then, all you have to do is compute the average error across all folds. It worked well for my use case.
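For completeness, the averaging step could look something like this (a sketch, assuming errors is the n_cv × length(configurations) matrix filled in the loop above):

using Statistics

mean_errors = vec(mean(errors, dims=1))        # average each configuration over the folds
best_error, best_index = findmin(mean_errors)  # best configuration overall
best_config = configurations[best_index]
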
Let me know if you have any updates on this subject; I will try to spare some time and make a proper implementation of this.

ablaom (Member) commented May 26, 2021:

As explained above, the best we can expect here is for user-specified holdout train/test pairs to work, in addition to the Holdout resampling strategy. This PR resolves this (also in the case that Holdout includes shuffling), so closing.
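
For reference, these are the two resampling specifications discussed in this thread (the built-in Holdout strategy and an explicit, user-specified single train/test pair); the row splits below are only illustrative:

# Built-in holdout strategy:
TunedModel(model=forest_model, tuning=Grid(shuffle=false),
           resampling=Holdout(fraction_train=0.7), range=r, measure=rms)

# Explicit, user-specified single train/test pair:
train_rows, test_rows = 1:70, 71:100
TunedModel(model=forest_model, tuning=Grid(shuffle=false),
           resampling=[(train_rows, test_rows)], range=r, measure=rms)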

ablaom closed this as completed on Sep 25, 2023.