Don't try loading models from the cwd ever #223

Merged · 4 commits · Apr 29, 2023
elk/extraction/dataset_name.py (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@

 def extract_dataset_name_and_config(dataset_config_str: str) -> tuple[str, str]:
     """Extract the dataset name and config name from the dataset prompt."""
-    ds_name, _, config_name = dataset_config_str.partition(" ")
+    ds_name, _, config_name = dataset_config_str.partition(":")
     return ds_name, config_name


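For reference, str.partition splits on the first occurrence of the separator and returns an empty tail when the separator is absent, so bare dataset names without a config keep working. A quick sketch of the new behavior:

    assert "super_glue:boolq".partition(":") == ("super_glue", ":", "boolq")
    # No config name: the tail is simply empty.
    assert "imdb".partition(":") == ("imdb", "", "")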
elk/extraction/extraction.py (4 changes: 3 additions & 1 deletion)

@@ -35,6 +35,7 @@
     instantiate_model,
     instantiate_tokenizer,
     is_autoregressive,
+    prevent_name_conflicts,
     select_train_val_splits,
     select_usable_devices,
 )

@@ -306,7 +307,8 @@ def get_splits() -> SplitDict:
         dataset_name=available_splits.dataset_name,
     )

-    model_cfg = AutoConfig.from_pretrained(cfg.model)
+    with prevent_name_conflicts():
+        model_cfg = AutoConfig.from_pretrained(cfg.model)

     ds_name, config_name = extract_dataset_name_and_config(
         dataset_config_str=cfg.prompts.datasets[0]
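The guard matters because from_pretrained resolves its argument as a local path before falling back to the Hub, so a stray directory in the working directory can shadow a model name. A minimal repro sketch (the local "gpt2" folder is purely illustrative, and the exact exception type may vary across transformers versions):

    import os
    from transformers import AutoConfig

    os.makedirs("gpt2", exist_ok=True)  # unrelated folder named like a Hub model
    try:
        # Resolves "gpt2" to ./gpt2 and fails (no config.json there) instead
        # of fetching the real config from the Hub.
        AutoConfig.from_pretrained("gpt2")
    except OSError as e:
        print(f"Shadowed by cwd: {e}")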
elk/extraction/prompt_loading.py (6 changes: 3 additions & 3 deletions)

@@ -28,7 +28,7 @@ class PromptConfig(Serializable):
"""
Args:
dataset: List of space-delimited names of the HuggingFace dataset to use, e.g.
`"super_glue boolq"` or `"imdb"`.
`"super_glue:boolq"` or `"imdb"`.
data_dir: The directory to use for caching the dataset. Defaults to
`~/.cache/huggingface/datasets`.
label_column: The column containing the labels. By default, we infer this from
Expand Down Expand Up @@ -119,7 +119,7 @@ def load_prompts(

     Args:
         ds_string: Space-delimited name of the HuggingFace dataset to use,
-            e.g. `"super_glue boolq"` or `"imdb"`.
+            e.g. `"super_glue:boolq"` or `"imdb"`.
         label_column: The column containing the labels. By default, we infer this from
             the datatypes of the columns in the dataset.
         num_classes: The number of classes in the dataset. If zero, we infer this from

@@ -135,7 +135,7 @@
     Returns:
         An iterable of prompt dictionaries.
     """
-    ds_name, _, config_name = ds_string.partition(" ")
+    ds_name, _, config_name = ds_string.partition(":")
     prompter = DatasetTemplates(ds_name, config_name)

     ds_dict = assert_type(
elk/utils/__init__.py (2 changes: 2 additions & 0 deletions)

@@ -5,6 +5,7 @@
     has_multiple_configs,
     infer_label_column,
     infer_num_classes,
+    prevent_name_conflicts,
     select_train_val_splits,
 )
 from .gpu_utils import select_usable_devices

@@ -30,6 +31,7 @@
"instantiate_tokenizer",
"int16_to_float32",
"is_autoregressive",
"prevent_name_conflicts",
"pytree_map",
"select_train_val_splits",
"select_usable_devices",
Expand Down
elk/utils/data_utils.py (15 changes: 15 additions & 0 deletions)
@@ -1,6 +1,9 @@
 import copy
+import os
+from contextlib import contextmanager
 from functools import cache
 from random import Random
+from tempfile import TemporaryDirectory
 from typing import Any, Iterable

 from datasets import (

@@ -32,6 +35,18 @@ def has_multiple_configs(ds_name: str) -> bool:
     return len(get_dataset_config_names(ds_name)) > 1


+@contextmanager
+def prevent_name_conflicts():
+    """Temporarily change cwd to a temporary directory, to prevent name conflicts."""
+    with TemporaryDirectory() as tmp:
+        old_cwd = os.getcwd()
+        try:
+            os.chdir(tmp)
+            yield
+        finally:
+            os.chdir(old_cwd)


 def select_train_val_splits(raw_splits: Iterable[str]) -> tuple[str, str]:
     """Return splits to use for train and validation, given an Iterable of splits."""

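The context manager is self-contained, so a usage sketch is straightforward (the "imdb" folder here is hypothetical, standing in for any directory that shadows a Hub name):

    import os
    from elk.utils import prevent_name_conflicts

    # Suppose the working directory contains a folder named "imdb". Inside the
    # guard, the process cwd is an empty temporary directory, so HuggingFace
    # loaders can no longer mistake that folder for the Hub dataset.
    with prevent_name_conflicts():
        assert not os.path.exists("imdb")
    print(os.getcwd())  # the original cwd is restored afterwards

Changing the process-wide cwd is a blunt but effective fix: it sidesteps the local-path check without touching transformers internals, at the cost of not being safe under concurrent threads that rely on relative paths.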
elk/utils/hf_utils.py (44 changes: 24 additions & 20 deletions)

@@ -8,6 +8,8 @@
     PreTrainedTokenizerBase,
 )

+from .data_utils import prevent_name_conflicts
+
 # Ordered by preference
 _DECODER_ONLY_SUFFIXES = [
     "CausalLM",
@@ -19,31 +21,33 @@

 def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel:
     """Instantiate a model string with the appropriate `Auto` class."""
-    model_cfg = AutoConfig.from_pretrained(model_str)
-    archs = model_cfg.architectures
-    if not isinstance(archs, list):
-        return AutoModel.from_pretrained(model_str, **kwargs)
+    with prevent_name_conflicts():
+        model_cfg = AutoConfig.from_pretrained(model_str)
+        archs = model_cfg.architectures
+        if not isinstance(archs, list):
+            return AutoModel.from_pretrained(model_str, **kwargs)

-    for suffix in _AUTOREGRESSIVE_SUFFIXES:
-        # Check if any of the architectures in the config end with the suffix.
-        # If so, return the corresponding model class.
-        for arch_str in archs:
-            if arch_str.endswith(suffix):
-                model_cls = getattr(transformers, arch_str)
-                return model_cls.from_pretrained(model_str, **kwargs)
+        for suffix in _AUTOREGRESSIVE_SUFFIXES:
+            # Check if any of the architectures in the config end with the suffix.
+            # If so, return the corresponding model class.
+            for arch_str in archs:
+                if arch_str.endswith(suffix):
+                    model_cls = getattr(transformers, arch_str)
+                    return model_cls.from_pretrained(model_str, **kwargs)

-    return AutoModel.from_pretrained(model_str, **kwargs)
+        return AutoModel.from_pretrained(model_str, **kwargs)


 def instantiate_tokenizer(model_str: str, **kwargs) -> PreTrainedTokenizerBase:
     """Instantiate a tokenizer, using the fast one iff it exists."""
-    try:
-        return AutoTokenizer.from_pretrained(model_str, use_fast=True, **kwargs)
-    except Exception as e:
-        if kwargs.get("verbose", True):
-            print(f"Falling back to slow tokenizer; fast one failed to load: '{e}'")
-
-        return AutoTokenizer.from_pretrained(model_str, use_fast=False, **kwargs)
+    with prevent_name_conflicts():
+        try:
+            return AutoTokenizer.from_pretrained(model_str, use_fast=True, **kwargs)
+        except Exception as e:
+            if kwargs.get("verbose", True):
+                print(f"Falling back to slow tokenizer; fast one failed: '{e}'")
+
+            return AutoTokenizer.from_pretrained(model_str, use_fast=False, **kwargs)


def is_autoregressive(model_cfg: PretrainedConfig, include_enc_dec: bool) -> bool:
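The dispatch works because config.architectures names concrete classes exported by transformers, so getattr can recover the right class from the string. A rough illustration of the same lookup, using gpt2 (whose config should list GPT2LMHeadModel):

    import transformers
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("gpt2")
    print(cfg.architectures)  # expected: ['GPT2LMHeadModel']
    # Resolve the architecture string to the actual class and load with it.
    model_cls = getattr(transformers, cfg.architectures[0])
    model = model_cls.from_pretrained("gpt2")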
tests/super_glue_prompts.yaml (4 changes: 2 additions & 2 deletions)

@@ -1,7 +1,7 @@
 balance: true
 datasets:
-- "super_glue boolq"
-- "super_glue copa"
+- "super_glue:boolq"
+- "super_glue:copa"
 label_column: null
 max_examples:
 - 5
tests/test_load_prompts.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@ def test_single_split(cfg: PromptConfig, split_type: Literal["train", "val"]):
     ds_string = cfg.datasets[0]
     prompt_ds = load_prompts(ds_string, split_type=split_type)

-    ds_name, _, config_name = ds_string.partition(" ")
+    ds_name, _, config_name = ds_string.partition(":")
     prompter = DatasetTemplates(ds_name, config_name or None)

     limit = cfg.max_examples[0 if split_type == "train" else 1]
tests/test_smoke_eval.py (2 changes: 1 addition & 1 deletion)

@@ -106,6 +106,6 @@ def test_smoke_eval_run_tiny_gpt2_eigen(tmp_path: Path):

 def test_smoke_multi_eval_run_tiny_gpt2_ccs(tmp_path: Path):
     elicit = setup_elicit(tmp_path)
-    transfer_datasets = ["christykoh/imdb_pt", "super_glue boolq"]
+    transfer_datasets = ["christykoh/imdb_pt", "super_glue:boolq"]
     eval_run(elicit, transfer_datasets=transfer_datasets)
     eval_assert_files_created(elicit, transfer_datasets=transfer_datasets)