Add a calibration error statistic #126

Merged: 21 commits merged into main from calibration on Mar 16, 2023
Conversation

norabelrose (Member) commented Mar 14, 2023
Created a CalibrationError class for computing the expected calibration error (ECE), based on https://arxiv.org/abs/2012.08668. We use it to compute and log the probe's ECE in train.py.

Depends on #124
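For context, a minimal sketch of what an equal-mass-binning ECE statistic in the spirit of that paper can look like. The class name matches the PR, but the internals below are illustrative assumptions, not the PR's actual code:

```python
import torch
from torch import Tensor


class CalibrationError:
    """Accumulates labels and confidences, then reports the expected
    calibration error (ECE) using equal-mass binning."""

    def __init__(self, num_bins: int = 15):
        self.num_bins = num_bins
        self.labels: list[Tensor] = []
        self.pred_probs: list[Tensor] = []

    def update(self, labels: Tensor, pred_probs: Tensor) -> None:
        # Flatten so every (label, confidence) pair counts once.
        self.labels.append(labels.flatten())
        self.pred_probs.append(pred_probs.flatten())

    def compute(self) -> float:
        labels = torch.cat(self.labels).float()
        probs = torch.cat(self.pred_probs)

        # Equal-mass binning: sort by confidence, then split into bins
        # holding (roughly) the same number of examples each.
        order = probs.argsort()
        probs, labels = probs[order], labels[order]

        n = probs.numel()
        ece = 0.0
        for p, y in zip(probs.tensor_split(self.num_bins),
                        labels.tensor_split(self.num_bins)):
            if p.numel() == 0:
                continue
            # |mean confidence - empirical accuracy|, weighted by bin mass.
            ece += (p.numel() / n) * (p.mean() - y.mean()).abs().item()
        return ece
```

Usage in that sketch would be `stat = CalibrationError(); stat.update(labels, pred_probs); stat.compute()`.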

tests/test_math.py: review thread (outdated, resolved)
AlexTMallen (Collaborator) left a comment
I looked over the paper and the code, and this looks good to me.

@@ -190,7 +199,6 @@ def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult:
# makes `num_variants` copies of each label, all within a single
AlexTMallen (Collaborator): This comment is no longer attached to its code.

norabelrose (Member, Author): Fixed.
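A hypothetical illustration (shapes assumed, not taken from the diff) of the label broadcasting that the truncated comment above describes:

```python
import torch

labels = torch.tensor([0, 1, 1])                    # shape (n,)
num_variants = 2

# Each label is repeated once per variant; all copies live in one
# flat dimension so they line up with flattened per-variant predictions.
broadcast = labels.repeat_interleave(num_variants)  # tensor([0, 0, 1, 1, 1, 1])
```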

@@ -182,6 +184,13 @@ def score(self, labels: Tensor, x_pos: Tensor, x_neg: Tensor) -> EvalResult:

pred_probs = self.predict(x_pos, x_neg)
AlexTMallen (Collaborator): So we are implicitly averaging over all the heads and variants? Also, just looking at this now, I'm not sure how it works when num_variants > 1 and num_heads > 1.

norabelrose (Member, Author): No, this doesn't average over heads and variants. We actually need to fully support the num_heads > 1 case; we don't right now.
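One possible way to support multiple heads (a sketch under assumed shapes, reusing the CalibrationError sketch above; not the PR's code) is to compute a separate ECE per head:

```python
from torch import Tensor


def per_head_ece(labels: Tensor, pred_probs: Tensor, num_bins: int = 15) -> list[float]:
    """Assumed shapes: labels is (n,), pred_probs is (n, num_variants, num_heads)."""
    n, num_variants, num_heads = pred_probs.shape
    scores = []
    for h in range(num_heads):
        stat = CalibrationError(num_bins)
        # One label copy per variant, aligned with the row-major flatten below.
        stat.update(
            labels.repeat_interleave(num_variants),
            pred_probs[..., h].flatten(),
        )
        scores.append(stat.compute())
    return scores
```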

@norabelrose norabelrose merged commit bb8fadf into main Mar 16, 2023
@norabelrose norabelrose deleted the calibration branch March 16, 2023 06:35