From c2d22291c5a0fea79b9761b87e3ea8bce8e5749a Mon Sep 17 00:00:00 2001 From: michaelbyun Date: Wed, 31 May 2023 16:44:37 -0700 Subject: [PATCH] Add support for Pythia models Don't try to access logits from Pythia model outputs --- elk/extraction/extraction.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 39de74ac..2c961bcd 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -40,6 +40,8 @@ from .generator import _GeneratorBuilder from .prompt_loading import PromptConfig, load_prompts +import re + @dataclass class Extract(Serializable): @@ -122,7 +124,7 @@ def extract_hiddens( model = assert_type(PreTrainedModel, model.get_encoder()) is_enc_dec = False - has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) + has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) and not re.match(r"EleutherAI/pythia", cfg.model) if has_lm_preds and rank == 0: print("Model has language model head, will store predictions.") @@ -337,7 +339,7 @@ def get_splits() -> SplitDict: } # Only add model_logits if the model is an autoregressive model - if is_autoregressive(model_cfg, not cfg.use_encoder_states): + if is_autoregressive(model_cfg, not cfg.use_encoder_states) and not re.match(r"EleutherAI/pythia", cfg.model): other_cols["model_logits"] = Array2D( shape=(num_variants, num_classes), dtype="float32",