Add Platt scaling for Burns + fix for LEACE #288
Conversation
elk/training/train.py
Outdated
        rearrange(first_train_h, "n v k d -> (n v k) d"),
    )
-   labels = to_one_hot(train_gt, k)
+   labels = repeat(train_gt, "n -> n v k", v=v, k=k)
Why do we repeat the labels along the choices dimension? I think this means all dimensions of the one-hot prediction are expected to have the same value.
CCS only supports k=2, so we can just handle that case after asserting it to be true.
I think this line should be labels = to_one_hot(repeat(train_gt, "n -> (n v)", v=v), 2).flatten()
labels = to_one_hot(repeat(train_gt, "n -> (n v)", v=v), 2).flatten()
That wouldn't work, since the dimensions have to be the same for the labels and first_train_h, which is a 3D tensor in the case of CCS. But yeah, let me see what I can do instead.
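The shape mismatch can be seen on toy sizes (a sketch with made-up dimensions; numpy stands in for the einops/torch calls, and to_one_hot is approximated with np.eye):

```python
import numpy as np

n, v, k = 4, 3, 2  # hypothetical sizes: examples, variants, choices
train_gt = np.array([0, 1, 1, 0])  # toy ground-truth labels, shape (n,)

# The suggested expression: repeat "n -> (n v)", one-hot with 2 classes, flatten
repeated = np.repeat(train_gt, v)        # shape (n*v,) = (12,)
one_hot = np.eye(2)[repeated]            # shape (n*v, 2) = (12, 2)
flat_labels = one_hot.flatten()          # shape (24,), a 1-D tensor

# But the CCS reporter's raw scores follow first_train_h's leading dims:
scores_shape = (n, v, k)                 # a 3-D tensor (also 24 elements)
print(flat_labels.shape, scores_shape)   # (24,) (4, 3, 2)
```

The element counts agree, but a flat 1-D label tensor cannot be paired elementwise with the 3-D score tensor without a matching reshape.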
elk/training/ccs_reporter.py
Outdated
@@ -88,6 +88,8 @@ def __init__(
        num_variants: int = 1,
    ):
        super().__init__()
+       self._is_training = True
If self._is_training is False during the call to platt_scale, it won't work, because the Platt scaling parameters need to be used in the forward pass for them to be updated in the backward pass.
Does it not?
    if self._is_training:
        return raw_scores
    else:
        platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
        return platt_scaled_scores
Anyway, this seems simple to test by comparing the results with and without the Platt scaling params added.
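The concern can be checked directly: autograd only produces gradients for parameters that participate in the forward computation. A minimal sketch (assuming PyTorch; w is a hypothetical stand-in for the probe weights, not the repo's code):

```python
import torch

x = torch.randn(8, 4)
w = torch.randn(4, requires_grad=True)      # stand-in for the probe weights
scale = torch.ones(1, requires_grad=True)   # Platt scaling parameters
bias = torch.zeros(1, requires_grad=True)

# Case 1: Platt params participate in the forward pass
loss = ((x @ w) * scale + bias).square().mean()
loss.backward()
grads_flow = scale.grad is not None and bias.grad is not None
print(grads_flow)  # True -- an optimizer step would update the params

scale.grad, bias.grad = None, None
# Case 2: the `if self._is_training: return raw_scores` branch skips them,
# so the backward pass never produces gradients for them
loss = (x @ w).square().mean()
loss.backward()
print(scale.grad, bias.grad)  # None None
```

So if the training branch returns raw_scores, scale and bias are left out of the graph and stay at their initial values.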
elk/training/train.py
Outdated
        rearrange(first_train_h, "n v k d -> (n v k) d"),
    )
-   labels = to_one_hot(train_gt, k)
+   labels = repeat(train_gt, "n -> n v k", v=v, k=k)
- 🔴 This implementation is incorrect. Suggested change:
      - labels = repeat(train_gt, "n -> n v k", v=v, k=k)
      + labels = to_one_hot(repeat(train_gt, "n -> n v", v=v), k)
- Yours replicates the label values along both the second and third dimensions.
- The corrected version first replicates the label values along the second dimension, then converts them into a one-hot representation in the third dimension.
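The difference is visible on a toy tensor (numpy broadcasting standing in for einops.repeat, np.eye for to_one_hot; sizes are made up):

```python
import numpy as np

n, v, k = 2, 2, 2
train_gt = np.array([0, 1])  # toy labels

# "n -> n v k": the raw label value fills every slot of the k dimension
wrong = np.broadcast_to(train_gt[:, None, None], (n, v, k)).copy()
print(wrong[1])   # [[1 1]
                  #  [1 1]] -- not a one-hot encoding

# "n -> n v" then one-hot over k: a proper one-hot per variant
right = np.eye(k)[np.broadcast_to(train_gt[:, None], (n, v))]
print(right[1])   # [[0. 1.]
                  #  [0. 1.]]
```

With the first form, every class slot carries the same value, so a one-hot prediction can never match it.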
elk/training/ccs_reporter.py
Outdated
@@ -88,6 +88,8 @@ def __init__(
        num_variants: int = 1,
    ):
        super().__init__()
+       self._is_training = True
🟡 I think this approach basically works, but perhaps the wrapper approach that EigenFitter/Reporter uses is cleaner.
In fact, maybe we should just do that and make CcsReporter.fit() return a Reporter. We may need to add a param that disables the eraser.
Yeah, I removed _is_training again.
> In fact, maybe we should just do that and make CcsReporter.fit() return a Reporter. We may need to add a param that disables the eraser.
We can take a look at that, either in a new pull request or this one...
lgtm
elk/training/train.py
Outdated
        rearrange(first_train_h, "n v k d -> (n v k) d"),
    )
    labels = to_one_hot(train_gt, k)
+   labels = repeat(labels, "n k -> n v k", v=v)
🟢 nitpick: bugs like the one made here are the reason I dislike reassigning things to the same variable. Either the name should be changed:
one_hotted_labels = to_one_hot(train_gt, k)
or it should just be inlined
Fixes Platt scaling for LEACE in CCS, so that the Platt scaling parameters are multiplied into and added to raw_scores even during training of the reporters.
Adds Platt scaling when using Burns normalization.
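For reference, Platt scaling itself is just fitting a scale and bias on top of frozen scores by minimizing the logistic (BCE) loss. A minimal numpy sketch with made-up data (the repo's platt_scale uses its own optimizer, so this is only illustrative of the idea):

```python
import numpy as np

rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200).astype(float)   # toy binary labels
# Over-confident raw scores: right sign on average, too-large magnitude
raw = (2 * labels - 1) * 5.0 + rng.normal(0, 4.0, size=200)

def bce(scale, bias):
    """Binary cross-entropy of sigmoid(raw * scale + bias) vs labels."""
    p = 1 / (1 + np.exp(-(raw * scale + bias)))
    return -np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p))

scale, bias, lr = 1.0, 0.0, 0.05
before = bce(scale, bias)
for _ in range(300):  # plain gradient descent on the two scalars
    p = 1 / (1 + np.exp(-(raw * scale + bias)))
    grad = p - labels                 # dBCE/dlogit
    scale -= lr * np.mean(grad * raw)
    bias -= lr * np.mean(grad)
print(bce(scale, bias) < before)  # True -- calibration loss went down
```

Only scale and bias are trained; the raw scores (and, in the reporter's case, the probe that produced them) are held fixed.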