Measurement tampering #33
Conversation
We're now just padding all inputs to the maximum sequence length as a workaround. There's lots of room for improvement here: it would be great if TransformerLens let us pass kwargs through to to_tokens (to save a bunch of copied code), and ideally we'd use the maximum sequence length over all training examples instead. Memory usage also still seems to be increasing slightly, but it's now on the order of 200MB for one epoch on the IMDB dataset with Pythia-14m, so at least some local testing is totally feasible now. For serious training we'll use CUDA anyway (though if the remaining leak has a different source, it might affect that too, so we should keep an eye out).
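The "pad everything to a fixed maximum length" workaround can be sketched in isolation like this. This is a hypothetical, pure-Python illustration (the real code goes through the tokenizer / TransformerLens `to_tokens`); `pad_to_max_len` and its arguments are made up for the example:

```python
def pad_to_max_len(token_ids, max_len, pad_id=0):
    """Right-pad each token-id sequence with pad_id up to max_len,
    so every batch has the same shape regardless of its contents."""
    input_ids, attention_mask = [], []
    for ids in token_ids:
        ids = ids[:max_len]  # truncate overly long sequences
        n_pad = max_len - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```

Fixing `max_len` across all batches (rather than padding to the longest sequence per batch) keeps tensor shapes constant, which is what avoids the reallocation-driven memory growth.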
Mainly better padding, plus an option to freeze the model itself. Not sure yet whether either is necessary, but at least I now get 80% accuracy pretty quickly on a CPU; will run bigger experiments soon.
Left a few minor comments, but mostly looks good for now! (Of course there's stuff to finish up later.) In addition to the comments: I don't think demo.ipynb has any changes that should be committed? measurement tampering.ipynb also seems partially outdated, and we probably don't need it in the repository, but I do like the dataset exploration at the beginning; we can reuse that when we make a nicer demo notebook for measurement tampering at some point.
@dataclass
class TamperingDataConfig(DatasetConfig):
    n_sensors: ClassVar[int] = 3  # not configurable
    train: bool = True  # TODO: how does cupbearer use this?
Re how we use this: there was some magic going on to sometimes automatically get the validation split where this made sense (e.g. if you train a backdoored classifier and then evaluate a detector on that classifier, cupbearer would guess that you want to evaluate the detector on the train=False split, and would check whether the training data config had a train field). But #30 removes that along with most of the other behind-the-scenes magic like it; this kind of thing is now handled explicitly by the user.
from . import DatasetConfig

class TamperingDataset(torch.utils.data.Dataset):
To define a measurement tampering task, I see two approaches (with the task format from #30):
- Add options to TamperingDataset to only return a selection of the data (e.g. only easy training data, or only validation data that has tampering), then use Task.from_separate_data() to recombine them. The advantage is that we get some control over the mixing ratio and use the standard interface; the downside is that it's kind of indirect and complicated.
- Pass trusted_data, untrusted_train_data, and test_data directly to Task, without using MixedData at all. This is closer to the format the datasets are already in anyway, so probably simpler. We'll just have to make sure to return the right label format: trusted_data and untrusted_train_data should just return (text, measurements), whereas test_data should return ((text, measurements), is_clean), I think.

Leaning a bit towards 2. overall, but open to either approach or some third one.
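The label formats in option 2 could look roughly like this. A hypothetical sketch only; `TamperingSplit` and the record layout are invented for illustration and aren't part of the actual codebase:

```python
class TamperingSplit:
    """Wraps (text, measurements, is_clean) records and yields the
    label format each split would hand to Task (option 2 above)."""

    def __init__(self, records, split):
        assert split in ("trusted", "untrusted_train", "test")
        self.records = records
        self.split = split

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        text, measurements, is_clean = self.records[i]
        if self.split == "test":
            # only test data carries the clean/tampered label
            return (text, measurements), is_clean
        # trusted and untrusted train data: just (text, measurements)
        return text, measurements
```

The point of the split-dependent return type is that detectors never see `is_clean` at training time; it only exists for evaluation.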
To be clear, no need to do this before merging, just noting down my thoughts
# get embeddings
embeddings = self.get_embeddings(tokens)

# TODO (odk) se store (doesn't this slow down training?)
(doesn't this slow down training?)
Calling self.store() should be ~free outside a self.capture() context manager.
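The store/capture pattern being described can be sketched as follows. This is a hypothetical reconstruction of the idea, not cupbearer's actual implementation; the class and method names are illustrative:

```python
from contextlib import contextmanager

class ActivationRecorder:
    """store() is a cheap no-op unless a capture() context is active."""

    def __init__(self):
        self._capturing = False
        self.cache = {}

    @contextmanager
    def capture(self):
        # activations stored inside this context end up in self.cache
        self._capturing = True
        try:
            yield self.cache
        finally:
            self._capturing = False

    def store(self, name, value):
        if self._capturing:  # ~free when not capturing: one bool check
            self.cache[name] = value
```

Under this design, sprinkling `store()` calls through a forward pass costs only a boolean check per call during normal training, which is why it shouldn't slow training down.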
(But also, like we discussed, we should switch to pytorch hooks anyway, so it won't be each model's responsibility anymore to expose activations.)
    b, self.n_sensors, self.embed_dim
)
# last token embedding (for aggregate measurement)
last_token_ind = tokens["attention_mask"].sum(dim=1) - 1
We should probably add a check somewhere that padding is on the right just to avoid nasty surprises
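One cheap version of that check: with right padding, the attention mask is non-increasing along the sequence dimension, so any 0 → 1 transition means real tokens appear after padding. A sketch (the helper name is hypothetical):

```python
import torch

def assert_right_padded(attention_mask):
    """Fail loudly if any sequence has padding before a real token."""
    diffs = attention_mask[:, 1:] - attention_mask[:, :-1]
    # a +1 step means a 0 -> 1 transition, i.e. left/interior padding
    assert (diffs <= 0).all(), "found padding on the left or in the middle"

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
assert_right_padded(mask)
# with right padding guaranteed, this indexing is correct:
last_token_ind = mask.sum(dim=1) - 1
```

Without the check, left padding would silently make `sum(dim=1) - 1` point at a padding position instead of the final real token.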
last_embs = embeddings[torch.arange(b), last_token_ind]
probe_embs = torch.concat([sensor_embs, last_embs.unsqueeze(dim=1)], axis=1)
assert probe_embs.shape == (b, self.n_probes, self.embed_dim)
logits = torch.concat(
Probably simpler to make self.probes a single nn.Linear with n_probes output logits? Though if we ever have non-linear probes, then it actually makes a difference (do we share most of their body between measurements or not); not sure which option we'd want in that case.
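The single-layer variant could look like this. A sketch under the assumption that each probe only reads its own embedding: the rows of one `nn.Linear`'s weight matrix act as the separate probes, and an einsum keeps probe i paired with embedding i (sizes are illustrative):

```python
import torch
from torch import nn

embed_dim, n_probes, b = 16, 4, 3  # illustrative sizes

# one Linear whose n_probes weight rows are the individual linear probes
probes = nn.Linear(embed_dim, n_probes)

probe_embs = torch.randn(b, n_probes, embed_dim)
# contract weight row p with embedding p only (no cross terms),
# instead of looping over a ModuleList of per-measurement probes
logits = torch.einsum("bpe,pe->bp", probe_embs, probes.weight) + probes.bias
```

Note that naively applying the Linear to `probe_embs` would give a `(b, n_probes, n_probes)` tensor (every probe applied to every embedding); the einsum takes just the diagonal of that, which is the behavior equivalent to separate linear probes.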
@oliveradk I merged it after making a few small changes (you might want to double-check them). The remaining open comments seem less important; just take a look at them at some point.
(all tests passing)