Binarized meta-templates; some extraction refactoring #218

Status: Closed (wanted to merge 24 commits)

Commits (24); the diff below shows changes from 1 commit, c9e62ea:
bbee489  Initial support for FEVER (norabelrose, Apr 22, 2023)
5ba1ddd  Start saving and fitting a reporter to the input embeddings (norabelrose, Apr 22, 2023)
3b1f74d  Merge branch 'input-embeddings' into template-filtering (norabelrose, Apr 22, 2023)
51ba54f  Rename layer 0 to 'input' to make it more clear (norabelrose, Apr 22, 2023)
544b485  Actually rename layer 0 correctly (norabelrose, Apr 22, 2023)
43da44e  Handle layer_stride correctly (norabelrose, Apr 22, 2023)
9056e00  Merge branch 'input-embeddings' into template-filtering (norabelrose, Apr 22, 2023)
756fa53  label_choices (norabelrose, Apr 22, 2023)
93b7ae0  Clean up train and eval commands; do transfer in sweep (norabelrose, Apr 22, 2023)
57d0b8b  Support INLP and split eval output into multiple CSVs (norabelrose, Apr 22, 2023)
228a6a0  Merge branch 'inlp' into template-filtering (norabelrose, Apr 22, 2023)
b086f0b  Merge branch 'inlp' into template-filtering (norabelrose, Apr 25, 2023)
934cd54  Log ensembled metrics (norabelrose, Apr 26, 2023)
dff69bf  Fixing pyright version (norabelrose, Apr 26, 2023)
b181d3e  Merge remote-tracking branch 'origin/main' into ensembling (norabelrose, Apr 26, 2023)
15254bf  Merge main (norabelrose, Apr 26, 2023)
69c2d55  Tons of stuff, preparing for sciq_binary experiment (norabelrose, Apr 27, 2023)
960ff01  Support --binarize again (norabelrose, Apr 27, 2023)
c9e62ea  Partial support for truthful_qa (norabelrose, Apr 27, 2023)
eb71a6c  Merge branch 'main' into template-filtering (norabelrose, Apr 29, 2023)
88bb15e  Merge remote-tracking branch 'origin/main' into template-filtering (norabelrose, Apr 29, 2023)
c648ff0  Remove crap (norabelrose, Apr 29, 2023)
ef12130  EleutherAI/truthful_qa_mc (norabelrose, Apr 29, 2023)
5d60ebd  Update templates (norabelrose, Apr 30, 2023)
Partial support for truthful_qa
norabelrose committed Apr 27, 2023
commit c9e62ea22138b2e30f0cf31f249258fbcca830b5
13 changes: 4 additions & 9 deletions elk/evaluation/evaluate.py
@@ -10,25 +10,21 @@
 from ..metrics import evaluate_preds
 from ..run import Run
 from ..training import Reporter
+from ..utils import Color


-@dataclass
+@dataclass(kw_only=True)
 class Eval(Run):
     """Full specification of a reporter evaluation run."""

-    # Using None as a default here is a hack; we actually raise an error if it's not
-    # specified in __post_init__. TODO: Maybe this is an indication we should be using
-    # composition and not inheritance here?
-    source: Path | None = field(default=None, positional=True)
+    source: Path = field(positional=True)
     skip_supervised: bool = False

     def __post_init__(self):
-        assert self.source, "Must specify a source experiment."
-
         if not self.out_dir:
             self.out_dir = self.source / "transfer" / "+".join(self.data.datasets)

-    def execute(self, highlight_color: str = "cyan"):
+    def execute(self, highlight_color: Color = "cyan"):
         return super().execute(highlight_color, split_type="val")

     @torch.inference_mode()
@@ -39,7 +35,6 @@ def apply_to_layer(
         device = self.get_device(devices, world_size)
         val_output = self.prepare_data(device, layer, "val")

-        assert self.source, "Must specify a source experiment."
        experiment_dir = elk_reporter_dir() / self.source

        reporter_path = experiment_dir / "reporters" / f"layer_{layer}.pt"
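For context on the @dataclass(kw_only=True) change above: keyword-only fields may follow the base class's defaulted fields, which is what makes the old default-None hack and the __post_init__ assert unnecessary. A minimal stdlib-only sketch of that behaviour (the real code uses simple-parsing's field(positional=True) for CLI parsing; the class and path names below are illustrative):

from dataclasses import dataclass
from pathlib import Path


@dataclass
class RunSketch:
    # Stand-in for the real Run base class, which already defines defaulted fields.
    out_dir: Path | None = None


@dataclass(kw_only=True)
class EvalSketch(RunSketch):
    # With kw_only=True this required field is allowed to follow the base
    # class's defaulted fields, so no None default or assert is needed.
    source: Path
    skip_supervised: bool = False


print(EvalSketch(source=Path("sweep-1/sciq")))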
2 changes: 1 addition & 1 deletion elk/extraction/extraction.py
@@ -253,7 +253,7 @@ def extract_hiddens(
         input_ids = input_ids[..., -min(cur_len, max_len) :]

         # Make sure we only pass the arguments that the model expects
-        inputs = dict(input_ids=input_ids)
+        inputs = dict(input_ids=input_ids.long())
         if is_enc_dec:
             inputs["labels"] = answer
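The .long() cast above guards against token ids arriving in a narrower integer dtype than the model's embedding layer accepts (only int32/int64 indices are valid). A minimal, hypothetical illustration, not the repo's code:

import torch

embed = torch.nn.Embedding(num_embeddings=100, embedding_dim=8)
input_ids = torch.tensor([[1, 2, 3]], dtype=torch.int16)
hidden = embed(input_ids.long())  # embed(input_ids) raises: int16 indices are rejected
print(hidden.shape)  # torch.Size([1, 3, 8])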
10 changes: 7 additions & 3 deletions elk/extraction/prompt_loading.py
@@ -104,13 +104,11 @@ def load_prompts(
             print("No label column found, not balancing")
         ds = ds.to_iterable_dataset()

-    if rank == 0:
-        print(f"Label choices: {label_choices}")
-
     for example in ds:
         yield _convert_to_prompts(
             example,
             binarize=binarize,
+            choices_column=prompter.choices_column,
             label_column=label_column,
             label_choices=label_choices,  # type: ignore[arg-type]
             num_variants=num_variants,
@@ -124,6 +122,7 @@ def _convert_to_prompts(
     example: dict[str, Any],
     prompter: DatasetTemplates,
     binarize: bool,
+    choices_column: str | None,
     label_column: str,
     label_choices: list[bool | int | str],
     num_variants: int,
@@ -144,6 +143,11 @@ def qa_cat(q: str, a: str) -> str:
     # For sanity checking that prompts are unique
     prompt_counter = Counter()
     label = example[label_column]
+    if choices_column:
+        label_choices = example[choices_column]
+    if isinstance(label, int):
+        label_choices = list(range(len(label_choices)))
+
     if binarize:
         # Replace the full list of possibilities with a randomly sampled false label
         # and the correct label, as done in the DLK paper. Note that this does add some
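The new choices_column path lets a dataset such as truthful_qa_mc supply its answer options per example rather than through template-level label_choices; when the label is an integer index, the choices collapse to their positions. A self-contained sketch of that resolution step (the function name and example values are made up for illustration):

from typing import Any


def resolve_label_choices(
    example: dict[str, Any],
    label_column: str,
    choices_column: str | None,
    template_choices: list,
) -> list:
    # Per-example choices override the template-level list; an integer label
    # means the "choices" are simply the indices 0..n-1.
    label = example[label_column]
    label_choices = example[choices_column] if choices_column else template_choices
    if isinstance(label, int):
        label_choices = list(range(len(label_choices)))
    return label_choices


example = {"label": 2, "choices": ["Paris", "London", "Berlin"]}
print(resolve_label_choices(example, "label", "choices", []))  # [0, 1, 2]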
18 changes: 6 additions & 12 deletions elk/promptsource/templates.py
@@ -270,13 +270,7 @@ class DatasetTemplates:
     helper functions necessary to read/write to the yaml file
     """

-    TEMPLATES_KEY = "templates"
-    DATASET_KEY = "dataset"
-    SUBSET_KEY = "subset"
-    LABEL_COLUMN_KEY = "label_column"
-    LABEL_CHOICES_KEY = "label_choices"
-    TEMPLATE_FILENAME = "templates.yaml"
-
+    choices_column: str | None
     label_column: str | None
     label_choices: list[int | str]
     templates: dict[str, Template]
@@ -289,11 +283,11 @@ def __init__(self, dataset_name: str, subset_name: str | None = None):
            yaml_dict = yaml.load(f, Loader=yaml.FullLoader)

        # Required field; contains all the templates keyed by ID
-       self.templates = yaml_dict[self.TEMPLATES_KEY]
+       self.templates = yaml_dict["templates"]

        # Optional fields; may be None
-       self.label_column = yaml_dict.get(self.LABEL_COLUMN_KEY)
-       self.label_choices = yaml_dict.get(self.LABEL_CHOICES_KEY, [])
+       self.choices_column = yaml_dict.get("choices_column")
+       self.label_column = yaml_dict.get("label_column")
+       self.label_choices = yaml_dict.get("label_choices", [])

    def drop_non_mc_templates(self) -> int:
        """Drop all templates that aren't multiple choice, return the number dropped"""
@@ -326,7 +320,7 @@ def folder_path(self) -> str:

    @property
    def yaml_path(self) -> str:
-       path = os.path.join(self.folder_path, self.TEMPLATE_FILENAME)
+       path = os.path.join(self.folder_path, "templates.yaml")
        if not os.path.exists(path):
            raise ValueError(f"Expected prompt templates to exist at {path}")
118 changes: 118 additions & 0 deletions elk/promptsource/templates/norabelrose/truthful_qa/templates.yaml
@@ -0,0 +1,118 @@
# Adapted from the ARC-Easy templates
dataset: norabelrose/truthful_qa_mc
choices_column: choices
templates:
  a1: !Template
    answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N'
    id: a1
    jinja: 'Pick the most correct option to answer the following question.


      {{question}}


      Options:

      {% for letter, t in zip(answer_choices, choices) %}

      - {{letter}}: {{t}}

      {% endfor %} |||{{answer_choices[label]}}'
    metadata: !TemplateMetadata
      choices_in_prompt: true
      languages:
      - en
      metrics:
      - Accuracy
      original_task: true
    name: pick_the_most_correct_option
    reference: ''
  a2: !Template
    answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N'
    id: a2
    jinja: '{{question}}


      Options:

      {% for letter, t in zip(answer_choices, choices) %}

      - {{letter}}: {{t}}

      {% endfor %}

      A:|||{{answer_choices[label]}}'
    metadata: !TemplateMetadata
      choices_in_prompt: true
      languages:
      - en
      metrics:
      - Accuracy
      original_task: true
    name: qa_options
    reference: ''
  a3: !Template
    answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N'
    id: a3
    jinja: 'I am hesitating between 4 options to answer the following question, which
      option should I choose?

      Question: {{question}}

      Possibilities:

      {% for letter, t in zip(answer_choices, choices) %}

      - {{letter}}: {{t}}

      {% endfor %}

      A:|||{{answer_choices[label]}}'
    metadata: !TemplateMetadata
      choices_in_prompt: true
      languages:
      - en
      metrics:
      - Accuracy
      original_task: true
    name: i_am_hesitating
    reference: ''
  a4: !Template
    answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N'
    id: a4
    jinja: 'I gave my students this multiple choice question: {{question}}


      Only one answer is correct among these 4 choices:

      {% for letter, t in zip(answer_choices, choices) %}

      - {{letter}}: {{t}}

      {% endfor %}

      A:|||{{answer_choices[label]}}'
    metadata: !TemplateMetadata
      choices_in_prompt: true
      languages:
      - en
      metrics:
      - Accuracy
      original_task: true
    name: multiple_choice
    reference: ''
  a6: !Template
    answer_choices: 'A|||B|||C|||D|||E|||F|||G|||H|||I|||J|||K|||L|||M|||N'
    id: a6
    jinja: "Here's a problem to solve: {{question}}\n\nAmong the 4 following options,\
      \ which is the correct answer?\n{% for letter, t in zip(answer_choices, choices)\
      \ %}\n- {{letter}}: {{t}}\n {% endfor %}A:|||{{answer_choices[label]}}"
    metadata: !TemplateMetadata
      choices_in_prompt: true
      languages:
      - en
      metrics:
      - Accuracy
      original_task: true
    name: heres_a_problem
    reference: ''
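Each template above pairs the letter-style answer_choices with the dataset's per-example choices column via zip(), and everything after the ||| separator is the target completion. A hedged rendering sketch: it assumes zip is exposed to the Jinja environment (these promptsource-style templates rely on that) and that answer_choices arrives as a list after splitting on '|||'; the question and choices are made-up example values:

from jinja2 import Environment

env = Environment()
env.globals["zip"] = zip  # the templates call zip(), which is not a Jinja builtin

tmpl = env.from_string(
    "{{question}}\n\nOptions:\n"
    "{% for letter, t in zip(answer_choices, choices) %}- {{letter}}: {{t}}\n{% endfor %}"
    "\nA:|||{{answer_choices[label]}}"
)
print(tmpl.render(
    question="What is the capital of France?",
    answer_choices=["A", "B", "C"],
    choices=["Paris", "London", "Berlin"],
    label=0,
))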
4 changes: 3 additions & 1 deletion elk/run.py
@@ -21,6 +21,7 @@
 from .extraction.dataset_name import DatasetDictWithName
 from .files import elk_reporter_dir, memorably_named_dir
 from .utils import (
+    Color,
     assert_type,
     get_layer_indices,
     int16_to_float32,
@@ -48,7 +49,7 @@ class Run(ABC, Serializable):

     def execute(
         self,
-        highlight_color: str = "cyan",
+        highlight_color: Color = "cyan",
         split_type: Literal["train", "val", None] = None,
     ):
         self.datasets = [
@@ -127,6 +128,7 @@ def prepare_data(

         split = ds[key].with_format("torch", device=device, dtype=torch.int16)
         labels = assert_type(Tensor, split["label"])
+        breakpoint()
         val_h = int16_to_float32(assert_type(Tensor, split[f"hidden_{layer}"]))

         with split.formatted_as("torch", device=device):
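The highlight_color parameter is now typed as Color rather than a bare str. Purely as an illustration of the pattern; the real alias lives in elk/utils and its exact members are an assumption here:

from typing import Literal

Color = Literal["red", "green", "yellow", "blue", "magenta", "cyan", "white"]  # hypothetical members


def execute(highlight_color: Color = "cyan") -> None:
    print(f"highlighting output in {highlight_color}")


execute("magenta")  # fine; a typo such as "cian" would be caught by pyright, not at runtime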
5 changes: 1 addition & 4 deletions elk/utils/data_utils.py
@@ -118,10 +118,7 @@ def infer_num_classes(label_feature: Any) -> int:
     elif isinstance(label_feature, Value) and label_feature.dtype == "bool":
         return 2
     else:
-        raise ValueError(
-            f"Can't infer number of classes from label column of type {label_feature}. "
-            f"Please update the num_classes field in the prompt template yaml file."
-        )
+        return -1


 def get_layer_indices(ds: DatasetDict) -> list[int]:
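infer_num_classes now degrades gracefully: a label feature it cannot interpret yields the sentinel -1 instead of raising, so callers can fall back to per-example choices such as a choices_column. A sketch of the resulting behaviour; the ClassLabel branch is assumed from context since the hunk starts below it:

from typing import Any

from datasets import ClassLabel, Value


def infer_num_classes(label_feature: Any) -> int:
    if isinstance(label_feature, ClassLabel):
        return label_feature.num_classes  # assumed branch, not shown in the hunk
    elif isinstance(label_feature, Value) and label_feature.dtype == "bool":
        return 2
    else:
        return -1  # sentinel: "unknown", handled by the caller


print(infer_num_classes(ClassLabel(names=["no", "yes"])))  # 2
print(infer_num_classes(Value("string")))                  # -1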
8 changes: 4 additions & 4 deletions elk/utils/typing.py
@@ -14,14 +14,14 @@ def assert_type(typ: Type[T], obj: Any) -> T:


 def float32_to_int16(x: torch.Tensor) -> torch.Tensor:
-    """Converts float32 to float16, then reinterprets as int16."""
-    downcast = x.type(torch.float16)
+    """Converts float32 to bfloat16, then reinterprets as int16."""
+    downcast = x.type(torch.bfloat16)
     if not downcast.isfinite().all():
         raise ValueError("Cannot convert to 16 bit: values are not finite")

     return downcast.view(torch.int16)


 def int16_to_float32(x: torch.Tensor) -> torch.Tensor:
-    """Converts int16 to float16, then reinterprets as float32."""
-    return x.view(torch.float16).type(torch.float32)
+    """Converts int16 to bfloat16, then reinterprets as float32."""
+    return x.view(torch.bfloat16).type(torch.float32)
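Switching the 16-bit carrier from float16 to bfloat16 keeps float32's exponent range, so finite float32 activations stay finite after the downcast and no longer trip the isfinite check, at the cost of a coarser mantissa. A standalone usage sketch of the round trip (helper bodies condensed from the diff above, finiteness check omitted):

import torch


def float32_to_int16(x: torch.Tensor) -> torch.Tensor:
    return x.type(torch.bfloat16).view(torch.int16)


def int16_to_float32(x: torch.Tensor) -> torch.Tensor:
    return x.view(torch.bfloat16).type(torch.float32)


x = torch.randn(2, 3) * 1e5          # magnitudes like this overflow float16 (max ~65504)
packed = float32_to_int16(x)         # compact int16 bit patterns for storage
restored = int16_to_float32(packed)  # lossy (roughly 0.4% relative error) but finite
print(packed.dtype, restored.dtype, (x - restored).abs().max())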