Use concept-erasure implementation of LEACE and SAL #252

Merged: norabelrose merged 8 commits into main from leace on Jul 10, 2023

Conversation

norabelrose
Member

Now that concept-erasure is on PyPI, we can outsource our ConceptEraser implementation to that repo.

This PR makes LEACE, rather than SAL, the default method for pseudolabel and prompt template normalization. I should probably add a config option to change it though.
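
A minimal sketch of what such a config option could look like. The `NormConfig` class and its `erasure_method` field are hypothetical names used for illustration, not part of elk or concept-erasure; only the two method names, LEACE and SAL, come from this PR.

```python
from dataclasses import dataclass
from typing import Literal, get_args

# The two erasure methods named in this PR.
ErasureMethod = Literal["leace", "sal"]


@dataclass
class NormConfig:
    """Hypothetical config object; the name and field are assumptions."""

    # LEACE is the default, matching the behavior this PR introduces.
    erasure_method: ErasureMethod = "leace"

    def __post_init__(self) -> None:
        # Literal is not enforced at runtime, so validate explicitly.
        allowed = get_args(ErasureMethod)
        if self.erasure_method not in allowed:
            raise ValueError(
                f"erasure_method must be one of {allowed}, "
                f"got {self.erasure_method!r}"
            )
```

With this shape, `NormConfig()` selects LEACE and `NormConfig(erasure_method="sal")` restores the previous default; any other string fails fast at construction time.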

@@ -265,12 +265,12 @@ def fit(self, hiddens: Tensor) -> float:
     self.norm.update(
         x=x_neg,
         # Independent indicator for each (template, pseudo-label) pair
-        y=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
+        z=torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1),
Collaborator

fixed; replaced y with z for ccs
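
To make the "independent indicator for each (template, pseudo-label) pair" concrete, here is a plain-Python sketch that uses lists instead of tensors. It assumes `prompt_ids` is a one-hot vector over prompt templates. The layout for the negative half mirrors the `torch.cat` call in the diff; the layout for the positive half is an assumption, since that call is not shown in this hunk. `pair_indicator` is a hypothetical helper, not an elk function.

```python
from typing import List


def pair_indicator(template: int, num_templates: int, negative: bool) -> List[int]:
    """Return a one-hot of length 2 * num_templates for one (template, label) pair."""
    one_hot = [1 if i == template else 0 for i in range(num_templates)]
    zeros = [0] * num_templates
    if negative:
        # Mirrors torch.cat([torch.zeros_like(prompt_ids), prompt_ids], dim=-1)
        return zeros + one_hot
    # ASSUMPTION: the positive half uses the mirrored layout [prompt_ids, zeros].
    return one_hot + zeros


# With 3 templates there are 6 distinct (template, pseudo-label) pairs,
# and each one gets its own indicator column.
all_pairs = [
    pair_indicator(t, 3, negative=neg) for neg in (False, True) for t in range(3)
]
```

Because every pair has its own column, the eraser can remove information about the template and the pseudo-label jointly rather than about each one separately.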

Collaborator

AlexTMallen left a comment

LGTM!

Collaborator

nice :) no need for pseudolabels at inference time

@@ -40,8 +39,7 @@ def apply_to_layer(
     experiment_dir = elk_reporter_dir() / self.source
 
     reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
     reporter = Reporter.load(reporter_path, map_location=device)
-    reporter.eval()
Collaborator

Isn't the eval() here still needed?

Member Author

No, because CcsReporter doesn't actually have any submodules like nn.BatchNorm or nn.Dropout whose behavior changes due to eval().
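
The reasoning can be illustrated with a PyTorch-free toy. The classes below are hypothetical stand-ins written for this sketch, not the real `torch.nn` modules: `eval()` only flips a `training` flag on a module and its children, so it changes outputs only for modules that actually branch on that flag.

```python
import random


class Module:
    """Toy stand-in for a module that tracks a `training` flag."""

    def __init__(self):
        self.training = True
        self.children = []

    def eval(self):
        # Flip the flag here and recursively on every child.
        self.training = False
        for child in self.children:
            child.eval()
        return self


class Linear(Module):
    """Never reads `training`, so eval() cannot change its output."""

    def __init__(self, weight, bias):
        super().__init__()
        self.weight, self.bias = weight, bias

    def __call__(self, x):
        return self.weight * x + self.bias


class Dropout(Module):
    """Branches on `training`, so eval() does change its output."""

    def __init__(self, p, seed=0):
        super().__init__()
        self.p = p
        self.rng = random.Random(seed)

    def __call__(self, x):
        if not self.training:
            return x  # identity in eval mode
        # In training mode: drop with probability p, otherwise rescale.
        return 0.0 if self.rng.random() < self.p else x / (1 - self.p)
```

A model built only from modules like the toy `Linear` gives the same output before and after `eval()`, which is the situation described for CcsReporter.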

@norabelrose norabelrose merged commit a88c01a into main Jul 10, 2023
6 checks passed
@norabelrose norabelrose deleted the leace branch July 10, 2023 05:18
@artkpv
Contributor

artkpv commented Jul 26, 2023

JFI, my probes / reporters now won't load with this PR because I used Reporter.load. https://github.com/EleutherAI/elk/pull/252/files#diff-d08b84a509f043deeb98c9c642f692fffbd1967486738d2ff242b7897eb0b1ae

@norabelrose
Member Author

> JFI, my probes / reporters now won't load with this PR because I used Reporter.load. https://github.com/EleutherAI/elk/pull/252/files#diff-d08b84a509f043deeb98c9c642f692fffbd1967486738d2ff242b7897eb0b1ae

Sorry about that, we can't really guarantee backward compatibility at this point. You should be able to load the reporters with an older commit and extract the raw weights if necessary.
