Skip to content

Commit

Permalink
Add support for Pythia models
Browse files (browse the repository at this point in the history)
Don't try to access logits from Pythia model outputs
  • Loading branch information
michaelbyun committed May 31, 2023
1 parent d48b34a commit c2d2229
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions elk/extraction/extraction.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,6 +40,8 @@
from .generator import _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts

import re


@dataclass
class Extract(Serializable):
Expand Down Expand Up @@ -122,7 +124,7 @@ def extract_hiddens(
model = assert_type(PreTrainedModel, model.get_encoder())
is_enc_dec = False

has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states) and not re.match(r"EleutherAI/pythia", cfg.model)
if has_lm_preds and rank == 0:
print("Model has language model head, will store predictions.")

Expand Down Expand Up @@ -337,7 +339,7 @@ def get_splits() -> SplitDict:
}

# Only add model_logits if the model is an autoregressive model
if is_autoregressive(model_cfg, not cfg.use_encoder_states):
if is_autoregressive(model_cfg, not cfg.use_encoder_states) and not re.match(r"EleutherAI/pythia", cfg.model):
other_cols["model_logits"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
Expand Down

0 comments on commit c2d2229

Please sign in to comment.