
add platt scaling for burns + fix for leace #288

Merged 6 commits into main on Aug 28, 2023

Conversation

lauritowal (Collaborator)

  • Fixes Platt scaling for LEACE in CCS, where the Platt scaling parameters were being multiplied into and added onto the raw_scores even during reporter training.

  • Adds Platt scaling when using Burns normalization (a minimal sketch of what Platt scaling does is shown below).
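
For context, here is a minimal sketch of the Platt-scaling step: a learnable scale and bias applied to the reporter's raw scores. The class and parameter names are illustrative rather than the exact elk implementation.

import torch
import torch.nn as nn

class PlattScaler(nn.Module):
    """Learnable affine recalibration of raw reporter scores (illustrative, not the exact elk API)."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, raw_scores: torch.Tensor) -> torch.Tensor:
        # Same affine form as in the CCS reporter: raw_scores * scale + bias,
        # which the caller feeds into a sigmoid / BCE-with-logits loss.
        return raw_scores.mul(self.scale).add(self.bias)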

@lauritowal marked this pull request as ready for review August 23, 2023 15:20

elk/training/train.py (outdated)
rearrange(first_train_h, "n v k d -> (n v k) d"),
)
labels = to_one_hot(train_gt, k)
labels = repeat(train_gt, "n -> n v k", v=v, k=k)
Collaborator

Why do we repeat the labels along the choices dimension? I think this means all entries of the one-hot prediction are expected to have the same value.

CCS only supports k=2, so we can just deal with that case after asserting it to be true.

I think this line should be labels = to_one_hot(repeat(train_gt, "n -> (n v)", v=v), 2).flatten()

lauritowal (Collaborator, Author) Aug 23, 2023

labels = to_one_hot(repeat(train_gt, "n -> (n v)", v=v), 2).flatten()
That wouldn't work, since the dimensions have to be the same for the labels and first_train_h, which is a 3D tensor in the case of CCS. But yeah, let me see what I can do instead.

@@ -88,6 +88,8 @@ def __init__(
num_variants: int = 1,
):
super().__init__()
self._is_training = True
Collaborator

If self._is_training is False during the call to platt_scale, then it won't work, because the Platt scaling parameters need to be used in the forward pass in order to be updated in the backward pass.

Collaborator

Does it not?

if self._is_training:
    return raw_scores
else:
    platt_scaled_scores = raw_scores.mul(self.scale).add(self.bias).squeeze(-1)
    return platt_scaled_scores

Anyway, this seems simple to test by comparing the results with and without the Platt scaling params added.
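
To make the concern concrete, here is a small self-contained check (illustrative names, not the actual elk code): a parameter that is bypassed in the forward pass receives no gradient, so it can never be updated.

import torch
import torch.nn as nn

scale = nn.Parameter(torch.ones(1))
bias = nn.Parameter(torch.zeros(1))
raw_scores = torch.randn(8, requires_grad=True)
labels = torch.randint(0, 2, (8,)).float()

def loss_fn(scores):
    return nn.functional.binary_cross_entropy_with_logits(scores, labels)

# Case 1: the loss only sees the raw scores, so scale/bias get no gradients.
loss_fn(raw_scores).backward()
print(scale.grad, bias.grad)  # None None

# Case 2: the loss sees the Platt-scaled scores, so both parameters get gradients.
loss_fn(raw_scores.mul(scale).add(bias)).backward()
print(scale.grad, bias.grad)  # two non-None tensors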

elk/training/train.py (outdated)
rearrange(first_train_h, "n v k d -> (n v k) d"),
)
labels = to_one_hot(train_gt, k)
labels = repeat(train_gt, "n -> n v k", v=v, k=k)
derpyplops (Collaborator) Aug 23, 2023

  • 🔴 This implementation is incorrect.

    Suggested change:
    - labels = repeat(train_gt, "n -> n v k", v=v, k=k)
    + labels = to_one_hot(repeat(train_gt, "n -> n v", v=v), k)

  • Yours replicates the label values in both the second and third dimensions.

  • The corrected one first replicates the label values in the second dimension, then converts them into a one-hot representation in the third dimension (see the shape check below).
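
A quick shape check of the two constructions, with torch.nn.functional.one_hot standing in for elk's to_one_hot and made-up tensor values:

import torch
from einops import repeat

n, v, k = 3, 2, 2
train_gt = torch.tensor([0, 1, 1])  # hypothetical ground-truth labels

# Original line: the label value is copied along the k (choice) dimension,
# so every position along k holds the same value instead of a one-hot vector.
wrong = repeat(train_gt, "n -> n v k", v=v, k=k)
print(wrong[1, 0])  # tensor([1, 1]) -- not one-hot

# Suggested change: repeat across variants first, then one-hot over choices.
right = torch.nn.functional.one_hot(repeat(train_gt, "n -> n v", v=v), k)
print(right[1, 0])  # tensor([0, 1]) -- one-hot along the last dimension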

@@ -88,6 +88,8 @@ def __init__(
num_variants: int = 1,
):
super().__init__()
self._is_training = True
Collaborator

🟡 I think this approach basically works, but perhaps the wrapper approach that the EigenFitter/Reporter uses is cleaner.

Collaborator

In fact, maybe we should just do that and make CcsReporter.fit() return a Reporter. We might need to add a param that disables the eraser.

lauritowal (Collaborator, Author)

Yeah, I removed _is_training again.

> In fact, maybe we should just do that and make CcsReporter.fit() return a Reporter. We might need to add a param that disables the eraser.

We can take a look at that. Could be a new pull request or the same one...

AlexTMallen (Collaborator) left a comment

lgtm

rearrange(first_train_h, "n v k d -> (n v k) d"),
)
labels = to_one_hot(train_gt, k)
labels = repeat(labels, "n k -> n v k", v=v)
Collaborator

🟢 nitpick: bugs like the one made here are the reason I dislike reassigning things to the same variable. Either the name should be changed:
one_hotted_labels = to_one_hot(train_gt, k)
or it should just be inlined
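
For illustration, here are the two refactors the nitpick suggests, with a local stand-in for elk's to_one_hot and hypothetical shapes; both produce the same (n, v, k) labels tensor as the merged code.

import torch
from einops import repeat

n, v, k = 3, 2, 2
train_gt = torch.tensor([0, 1, 1])  # hypothetical ground-truth labels

def to_one_hot(labels: torch.Tensor, n_classes: int) -> torch.Tensor:
    # Local stand-in for elk's helper of the same name.
    return torch.nn.functional.one_hot(labels, n_classes)

# Option 1: give the intermediate result its own name.
one_hotted_labels = to_one_hot(train_gt, k)               # (n, k)
labels = repeat(one_hotted_labels, "n k -> n v k", v=v)   # (n, v, k)

# Option 2: inline the call so `labels` is assigned exactly once.
labels = repeat(to_one_hot(train_gt, k), "n k -> n v k", v=v)
assert labels.shape == (n, v, k)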

@lauritowal requested review from norabelrose and removed the request for norabelrose August 27, 2023 22:19
@lauritowal merged commit 4a6b654 into main Aug 28, 2023
6 checks passed
@lauritowal deleted the fix-platt-scaling-ccs branch August 28, 2023 13:19