Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cluster bootstrap for metrics; refactor metric computations into evaluate_preds #197

Merged
merged 13 commits into from
Apr 19, 2023
Merged
Prev Previous commit
Next Next commit
Don't normalize LM probs twice
  • Loading branch information
norabelrose committed Apr 18, 2023
commit 8e7dfff64da4cd43202fc3d1e3ff0baaeaf56404
12 changes: 6 additions & 6 deletions elk/extraction/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def extract_hiddens(
)
for layer_idx in layer_indices
}
lm_preds = torch.empty(
lm_logits = torch.empty(
num_variants,
num_choices,
device=device,
Expand Down Expand Up @@ -205,13 +205,13 @@ def extract_hiddens(
dim=-1
)
tokens = inputs.input_ids[..., start:end, None]
lm_preds[i, j] = log_p.gather(-1, tokens).sum()
lm_logits[i, j] = log_p.gather(-1, tokens).sum()

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = inputs.labels.shape[-1]
lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length

hiddens = (
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
Expand Down Expand Up @@ -244,7 +244,7 @@ def extract_hiddens(
**hidden_dict,
)
if has_lm_preds:
out_record["model_preds"] = lm_preds.softmax(dim=-1)
out_record["model_logits"] = lm_logits

yield out_record

Expand Down Expand Up @@ -319,9 +319,9 @@ def get_splits() -> SplitDict:
),
}

# Only add model_preds if the model is an autoregressive model
# Only add model_logits if the model is an autoregressive model
if is_autoregressive(model_cfg):
other_cols["model_preds"] = Array2D(
other_cols["model_logits"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
)
Expand Down
4 changes: 2 additions & 2 deletions elk/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def prepare_data(
val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))

with split.formatted_as("torch", device=device):
has_preds = "model_preds" in split.features
lm_preds = split["model_preds"] if has_preds else None
has_preds = "model_logits" in split.features
lm_preds = split["model_logits"] if has_preds else None

ds_name = get_dataset_name(ds)
out[ds_name] = (val_h, labels.to(val_h.device), lm_preds)
Expand Down