[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Apr 27, 2023
1 parent 72fc632 commit 56618ed
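
pre-commit.ci runs the hooks declared in the repository's .pre-commit-config.yaml and commits whatever they rewrite. The actual config file is not part of this diff; the following is only a minimal sketch of the kind of setup that yields import sorting and reformatting like the hunks below (the hook choice and revs are assumptions, not the repo's real pins):

# Hypothetical .pre-commit-config.yaml; the real file is not shown in this commit.
repos:
  - repo: https://github.com/PyCQA/isort
    rev: 5.12.0   # assumed revision
    hooks:
      - id: isort
  - repo: https://github.com/psf/black
    rev: 23.3.0   # assumed revision
    hooks:
      - id: black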
Showing 4 changed files with 18 additions and 30 deletions.
20 changes: 6 additions & 14 deletions elk/extraction/extraction.py
@@ -3,10 +3,8 @@
import os
from copy import copy
from dataclasses import InitVar, dataclass
from functools import partial
from itertools import islice
from multiprocessing.pool import ThreadPool
from typing import Any, Callable, Iterable, Literal, Sequence, NewType
from typing import Any, Iterable, Literal, NewType
from warnings import filterwarnings

import torch
@@ -21,26 +19,21 @@
from torch import Tensor
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.utils import ModelOutput

from .dataset_name import (
DatasetDictWithName,
extract_dataset_name_and_config,
)
from .prompt_loading import PromptConfig, load_prompts
from ..multiprocessing import evaluate_with_processes
from ..utils import (
assert_type,
colorize,
float32_to_int16,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
from ..utils.data_utils import flatten_list
from ..utils.fsdp import InferenceServer
from .dataset_name import (
DatasetDictWithName,
extract_dataset_name_and_config,
)
from .prompt_loading import PromptConfig, load_prompts


@dataclass
@@ -538,7 +531,6 @@ def get_splits() -> SplitDict:
)
info = get_dataset_config_info(ds_name, config_name or None)


split_names = list(get_splits().keys())
server = InferenceServer(
model_str=cfg.model, fsdp=True, cpu_offload=False, num_workers=6
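The import reshuffling in this file is characteristic of an import-sorting hook such as isort (an assumption; the hook list is not shown): duplicate imports are removed and the survivors are regrouped. A small illustrative sketch of the default grouping, not code from this repo:

# isort's default layout: sections alphabetized and
# separated by a single blank line.
import os  # 1. standard library

import torch  # 2. third-party packages

from .prompt_loading import load_prompts  # 3. local / first-party code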
12 changes: 6 additions & 6 deletions elk/utils/fsdp.py
@@ -1,22 +1,22 @@
from dataclasses import dataclass
from functools import partial
from itertools import cycle
import logging
import multiprocessing as std_mp
import os
import socket
import warnings
from functools import partial
from itertools import cycle
from typing import Any, Callable, Iterable, Type, cast

import dill
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from datasets import Dataset
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from typing import Any, Callable, Iterable, Type, cast

from ..multiprocessing import A
from ..utils import instantiate_model, pytree_map, select_usable_devices
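One detail in this hunk: the combined import from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP comes out as two statements. That is consistent with isort's default combine_as_imports = False, which gives as-aliased imports their own line (again assuming isort is the hook at work):

# Before: a plain import and an aliased one share a line.
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# After isort with combine_as_imports = False (the default):
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP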
4 changes: 1 addition & 3 deletions elk/utils/hf_utils.py
@@ -1,5 +1,4 @@
import transformers
from accelerate import infer_auto_device_map
from transformers import (
AutoConfig,
AutoModel,
@@ -19,8 +18,7 @@


def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel:
"""Instantiate a model string with the appropriate `Auto` class.
"""
"""Instantiate a model string with the appropriate `Auto` class."""
model_cfg = AutoConfig.from_pretrained(model_str)

archs = model_cfg.architectures
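The body of instantiate_model is truncated in this view right after archs = model_cfg.architectures. Purely as a hypothetical sketch of the pattern such a helper typically follows (a guess, not the repo's actual implementation; the AutoModel fallback in particular is an assumption):

import transformers
from transformers import AutoConfig, AutoModel, PreTrainedModel


def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel:
    """Instantiate a model string with the appropriate `Auto` class."""
    model_cfg = AutoConfig.from_pretrained(model_str)

    # config.architectures lists concrete class names, e.g. ["GPT2LMHeadModel"].
    archs = model_cfg.architectures
    if archs:
        for arch in archs:
            if hasattr(transformers, arch):
                # Use the first architecture that transformers actually exposes.
                return getattr(transformers, arch).from_pretrained(model_str, **kwargs)

    # Hypothetical fallback when no listed architecture resolves.
    return AutoModel.from_pretrained(model_str, **kwargs)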
12 changes: 5 additions & 7 deletions tests/test_extract_ids.py
@@ -5,20 +5,18 @@
from elk import Extract
from elk.extraction import PromptConfig
from elk.extraction.extraction import extract_input_ids
from elk.training import CcsReporterConfig
from elk.training.train import Elicit


def test_extract(tmp_path: Path):
# we need about 5 mb of gpu memory to run this test
model_path, min_mem = "sshleifer/tiny-gpt2", 10 * 1024**2
dataset_name = "imdb"
extract = Extract(
model=model_path,
prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]),
# run on all layers, tiny-gpt only has 2 layers
)
model=model_path,
prompts=PromptConfig(datasets=[dataset_name], max_examples=[10]),
# run on all layers, tiny-gpt only has 2 layers
)
model = AutoModelForCausalLM.from_pretrained(model_path)

result = extract_input_ids(cfg=extract, model=model, split_type="train")
extract_input_ids(cfg=extract, model=model, split_type="train")
print("ok")
