Support L2 regularization & cross validation for Classifier #135

Merged: 5 commits, Mar 19, 2023

Conversation

norabelrose
Member

This PR adds an l2_penalty parameter to Classifier.fit with a default value of 0.1. This is a change from the previous behavior, where there was no penalty by default.

I initially tried to exactly imitate the behavior of scikit-learn's C inverse regularization parameter, but I couldn't quite figure out how they compute the final loss. Based on their code, it looks like they sum the BCE loss over the samples rather than averaging it, which changes the scale of everything and makes the effective regularization strength depend on the number of samples. But accounting for that didn't give exactly the same results either, so I'm not sure what the remaining difference is. I gave up on imitating it exactly; the tests only check that the results are approximately the same when l2_penalty is set to 0.0.
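To make the scaling issue concrete, here is a minimal sketch, not the code in this PR. It assumes the objective that scikit-learn documents for L2-penalized logistic regression, `C * sum_i BCE_i(w) + 0.5 * ||w||^2`, and compares it with a mean-based loss plus a penalty written as `lam * w**2`. That penalty form is an assumption made for illustration; `l2_penalty` in this PR may be scaled differently. Under these assumptions the two objectives share a minimizer when `lam = 1 / (2 * n * C)`. The helper `fit_1d` is hypothetical, uses a single weight with no intercept, and uses Newton's method rather than LBFGS.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit_1d(xs, ys, loss_scale, l2_scale, iters=50):
    """Minimize loss_scale * sum_i BCE_i(w) + l2_scale * w**2 with Newton's method.

    Hypothetical helper: one weight, no intercept, purely illustrative.
    """
    w = 0.0
    for _ in range(iters):
        ps = [sigmoid(w * x) for x in xs]
        grad = loss_scale * sum((p - y) * x for p, x, y in zip(ps, xs, ys))
        grad += 2 * l2_scale * w
        hess = loss_scale * sum(p * (1 - p) * x * x for p, x in zip(ps, xs))
        hess += 2 * l2_scale
        w -= grad / hess
    return w


xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0, 0.3, -0.3]
ys = [0, 0, 1, 0, 1, 1, 1, 0]
n, C = len(xs), 1.0

# Sum-based form (scikit-learn's documented objective): C * sum BCE + 0.5 * w^2
w_sum = fit_1d(xs, ys, loss_scale=C, l2_scale=0.5)
# Mean-based form with the matching penalty: mean BCE + w^2 / (2 * n * C)
w_mean = fit_1d(xs, ys, loss_scale=1.0 / n, l2_scale=1.0 / (2 * n * C))

# Duplicate every sample. The empirical distribution is unchanged, but the
# sum-based objective now weights the data term twice as heavily.
xs2, ys2 = xs * 2, ys * 2
w_sum_dup = fit_1d(xs2, ys2, loss_scale=C, l2_scale=0.5)
w_mean_dup = fit_1d(xs2, ys2, loss_scale=1.0 / len(xs2), l2_scale=1.0 / (2 * n * C))

print(w_sum, w_mean, w_sum_dup, w_mean_dup)
```

With a fixed `C`, the sum-based solution moves when the dataset is duplicated, while the mean-based solution with a fixed penalty does not. This only illustrates the sample-count dependence described above; it is not necessarily the explanation for the remaining mismatch with scikit-learn.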

This PR also adds a relatively well-optimized fit_cv method that uses warm-starting to get at least a 2x speedup over a naive approach that starts from a zero initialization every time. I initially wanted to parallelize this code over the folds, but that looked like a real pain that would complicate the code substantially, and I'm not sure it would give a significant speed boost in the end (at least not without rewriting PyTorch's LBFGS optimizer, which I don't want to do right now).
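The warm-starting idea can be sketched as follows. This is a toy illustration under stated assumptions, not the fit_cv implementation in this PR: it assumes cross-validation sweeps a grid of L2 penalties within each fold and initializes each fit from the previous penalty's solution. The names `fit`, `mean_bce`, and `fit_cv_sketch` are hypothetical, the model has a single weight and no intercept, the penalty is assumed to enter as `l2 * w**2`, and Newton's method stands in for LBFGS.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def fit(xs, ys, l2, w0=0.0, tol=1e-10, max_iter=100):
    """Minimize mean BCE + l2 * w**2 starting from w0; return (w, Newton steps)."""
    n = len(xs)
    w = w0
    for step_count in range(1, max_iter + 1):
        ps = [sigmoid(w * x) for x in xs]
        grad = sum((p - y) * x for p, x, y in zip(ps, xs, ys)) / n + 2 * l2 * w
        hess = sum(p * (1 - p) * x * x for p, x in zip(ps, xs)) / n + 2 * l2
        step = grad / hess
        w -= step
        if abs(step) < tol:
            break
    return w, step_count


def mean_bce(w, xs, ys):
    """Unpenalized mean binary cross-entropy, used as the validation loss."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = sigmoid(w * x)
        total -= math.log(p) if y == 1 else math.log(1.0 - p)
    return total / len(xs)


def fit_cv_sketch(xs, ys, penalties, k=4, warm_start=True):
    """Return (best penalty, mean validation loss per penalty, total Newton steps)."""
    n = len(xs)
    val_losses = [0.0] * len(penalties)
    total_steps = 0
    for fold in range(k):
        val_idx = set(range(fold, n, k))  # interleaved folds, no shuffling
        train_x = [xs[i] for i in range(n) if i not in val_idx]
        train_y = [ys[i] for i in range(n) if i not in val_idx]
        val_x = [xs[i] for i in sorted(val_idx)]
        val_y = [ys[i] for i in sorted(val_idx)]
        w = 0.0
        for j, l2 in enumerate(penalties):
            # Warm start: reuse the previous penalty's solution as the initial point.
            w, steps = fit(train_x, train_y, l2, w0=w if warm_start else 0.0)
            total_steps += steps
            val_losses[j] += mean_bce(w, val_x, val_y) / k
    best = min(range(len(penalties)), key=val_losses.__getitem__)
    return penalties[best], val_losses, total_steps


xs = [-2.0, -1.5, -1.0, -0.5, -0.3, 0.3, 0.5, 1.0, 1.5, 2.0, -0.8, 0.8]
ys = [0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0]
penalties = [0.001, 0.01, 0.1, 1.0]

best_warm, losses_warm, steps_warm = fit_cv_sketch(xs, ys, penalties, warm_start=True)
best_cold, losses_cold, steps_cold = fit_cv_sketch(xs, ys, penalties, warm_start=False)
print(best_warm, best_cold, steps_warm, steps_cold)
```

Because each penalized objective here is strictly convex, warm and cold starts reach the same solutions, so the validation losses agree and only the number of optimizer steps differs. How much warm-starting saves depends on the optimizer and on how close neighboring solutions are; the 2x figure above refers to the actual implementation, not to this toy.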

There is a question of whether we want to use fit_cv by default in train.py. I think we probably should, but it does put us back in the territory where Classifier takes up more compute than the actual VINC algorithm. At the moment the code always uses fit_cv and doesn't give you an option to turn it off, which should probably be changed.

Collaborator

@AlexTMallen AlexTMallen left a comment

LGTM

3 participants