diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 39de74ac..3d2939bb 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -74,7 +74,9 @@ def __post_init__(self, layer_stride: int): config = assert_type( PretrainedConfig, AutoConfig.from_pretrained(self.model) ) - self.layers = tuple(range(0, config.num_hidden_layers, layer_stride)) + # Note that we always include 0 which is the embedding layer + layer_range = range(1, config.num_hidden_layers + 1, layer_stride) + self.layers = (0,) + tuple(layer_range) def explode(self) -> list["Extract"]: """Explode this config into a list of configs, one for each layer.""" @@ -136,8 +138,8 @@ def extract_hiddens( world_size=world_size, ) - # Iterating over questions - layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers)) + # Add one to the number of layers to account for the embedding layer + layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers + 1)) global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1] # break `max_examples` among the processes roughly equally @@ -229,9 +231,6 @@ def extract_hiddens( hiddens = ( outputs.get("decoder_hidden_states") or outputs["hidden_states"] ) - # First element of list is the input embeddings - hiddens = hiddens[1:] - # Throw out layers we don't care about hiddens = [hiddens[i] for i in layer_indices] @@ -320,7 +319,8 @@ def get_splits() -> SplitDict: dtype="int16", shape=(num_variants, num_classes, model_cfg.hidden_size), ) - for layer in cfg.layers or range(model_cfg.num_hidden_layers) + # Add 1 to include the embedding layer + for layer in cfg.layers or range(model_cfg.num_hidden_layers + 1) } other_cols = { "variant_ids": Sequence( diff --git a/elk/run.py b/elk/run.py index e246860a..34afc6f5 100644 --- a/elk/run.py +++ b/elk/run.py @@ -153,6 +153,9 @@ def apply_to_layers( # Make sure the CSV is written even if we crash or get interrupted if df_buf: df = pd.concat(df_buf).sort_values(by="layer") + + # Rename layer 0 to "input" to make it more clear + df["layer"].replace(0, "input", inplace=True) df.round(4).to_csv(f, index=False) if self.cfg.debug: save_debug_log(self.datasets, self.out_dir)