diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 39de74ac..2c961bcd 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -40,6 +40,8 @@ from .generator import _GeneratorBuilder from .prompt_loading import PromptConfig, load_prompts +import re + @dataclass class Extract(Serializable): @@ -122,7 +124,7 @@ def extract_hiddens( model = assert_type(PreTrainedModel, model.get_encoder()) is_enc_dec = False - has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) + has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) and not re.match(r"EleutherAI/pythia", cfg.model) if has_lm_preds and rank == 0: print("Model has language model head, will store predictions.") @@ -337,7 +339,7 @@ def get_splits() -> SplitDict: } # Only add model_logits if the model is an autoregressive model - if is_autoregressive(model_cfg, not cfg.use_encoder_states): + if is_autoregressive(model_cfg, not cfg.use_encoder_states) and not re.match(r"EleutherAI/pythia", cfg.model): other_cols["model_logits"] = Array2D( shape=(num_variants, num_classes), dtype="float32",