Revert "Add Python 3.11 to CI tests (#91)"
This reverts commit a3fadd2.
norabelrose committed Feb 19, 2023
1 parent a3fadd2 commit 5bd59e4
Showing 5 changed files with 164 additions and 211 deletions.
22 changes: 0 additions & 22 deletions .github/workflows/cpu_ci.yml
@@ -46,25 +46,3 @@ jobs:

      - name: Run CPU Tests
        run: pytest -m cpu

  run-tests-python3_11:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3

      - name: Install Python
        uses: actions/setup-python@v4
        with:
          python-version: "3.11"

      - name: Upgrade Pip
        run: python -m pip install --upgrade pip

      - name: Install Dependencies
        run: pip install -e .[dev]

      - name: Type Checking
        uses: jakebailey/pyright-action@v1

      - name: Run CPU Tests
        run: pytest -m cpu
64 changes: 41 additions & 23 deletions elk/__main__.py
@@ -4,6 +4,8 @@
from .files import args_to_uuid
from .list import list_runs
from argparse import ArgumentParser
from contextlib import nullcontext, redirect_stdout
import logging
import warnings


@@ -38,7 +40,6 @@ def run():
list_runs(args)
return

# Import here and not at the top to speed up `elk list`
from transformers import AutoConfig, PretrainedConfig

config = AutoConfig.from_pretrained(args.model)
@@ -66,30 +67,47 @@ def run():
# Import here and not at the top to speed up `elk list`
from .extraction.extraction_main import run as run_extraction
from .training.train import train
import os
import torch.distributed as dist

# Print CLI arguments to stdout
for key, value in vars(args).items():
print(f"{key}: {value}")

if args.command == "extract":
run_extraction(args)
elif args.command == "elicit":
# The user can specify a name for the run, but by default we use the
# MD5 hash of the arguments to ensure the name is unique
if not args.name:
args.name = args_to_uuid(args)

try:
train(args)
except (EOFError, FileNotFoundError):
run_extraction(args)
train(args)
# Check if we were called with torchrun or not
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
dist.init_process_group("nccl")
local_rank = int(local_rank)

with redirect_stdout(None) if local_rank else nullcontext():
# Print CLI arguments to stdout
for key, value in vars(args).items():
print(f"{key}: {value}")

elif args.command == "eval":
# TODO: Implement evaluation script
raise NotImplementedError
else:
raise ValueError(f"Unknown command {args.command}")
if local_rank:
logging.getLogger("transformers").setLevel(logging.CRITICAL)

if args.command == "extract":
run_extraction(args)
elif args.command == "elicit":
# The user can specify a name for the run, but by default we use the
# MD5 hash of the arguments to ensure the name is unique
if not args.name:
args.name = args_to_uuid(args)

try:
train(args)
except (EOFError, FileNotFoundError):
run_extraction(args)

# Ensure the extraction is finished before starting training
if dist.is_initialized():
dist.barrier()

train(args)

elif args.command == "eval":
# TODO: Implement evaluation script
raise NotImplementedError
else:
raise ValueError(f"Unknown command {args.command}")


if __name__ == "__main__":
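The restored launch logic in __main__.py decides at runtime whether the process was started with torchrun by checking for the LOCAL_RANK environment variable, silences stdout on non-zero ranks, and uses a barrier to keep ranks in sync. The following is a minimal, self-contained sketch of that pattern, not elk's actual entry point; the helper names are illustrative only.

# Minimal sketch of the torchrun-detection pattern restored in __main__.py.
# Assumes the script may be launched either directly (no LOCAL_RANK set) or
# via `torchrun`, which exports LOCAL_RANK for every worker process.
import os
from contextlib import nullcontext, redirect_stdout

import torch.distributed as dist


def maybe_init_distributed() -> int:
    """Initialize the NCCL process group when running under torchrun; return the local rank."""
    local_rank = os.environ.get("LOCAL_RANK")
    if local_rank is None:
        return 0  # single-process run

    dist.init_process_group("nccl")
    return int(local_rank)


def main() -> None:
    local_rank = maybe_init_distributed()

    # Only rank 0 prints; the other ranks send stdout to None so logs aren't duplicated.
    with redirect_stdout(None) if local_rank else nullcontext():
        print(f"running with local_rank={local_rank}")

        # ... per-rank work goes here ...

        # Wait for every rank before moving on, mirroring the dist.barrier()
        # call that separates extraction from training in the diff above.
        if dist.is_initialized():
            dist.barrier()


if __name__ == "__main__":
    main()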
170 changes: 46 additions & 124 deletions elk/extraction/extraction.py
@@ -2,36 +2,19 @@

from ..utils import pytree_map
from .prompt_collator import Prompt, PromptCollator
from dataclasses import dataclass
from einops import rearrange
from torch.utils.data import DataLoader
from transformers import (
BatchEncoding,
PreTrainedModel,
PreTrainedTokenizerBase,
AutoModel,
)
from typing import cast, Literal, Sequence
import logging
import numpy as np
from tqdm.auto import tqdm
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase
from typing import cast, Iterable, Literal, Sequence
import torch
import torch.multiprocessing as mp


@dataclass
class ExtractionParameters:
model_str: str
tokenizer: PreTrainedTokenizerBase
collator: PromptCollator
batch_size: int = 1
layers: Sequence[int] = ()
prompt_suffix: str = ""
token_loc: Literal["first", "last", "mean"] = "last"
use_encoder_states: bool = False
import torch.distributed as dist


@torch.autocast("cuda", enabled=torch.cuda.is_available()) # type: ignore
@torch.no_grad()
def extract_hiddens(
model_str: str,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
collator: PromptCollator,
*,
@@ -41,85 +24,40 @@ def extract_hiddens(
prompt_suffix: str = "",
token_loc: Literal["first", "last", "mean"] = "last",
use_encoder_states: bool = False,
):
"""Run inference on a model with a set of prompts, yielding the hidden states."""

ctx = mp.get_context("spawn")
queue = ctx.Queue()

num_gpus = torch.cuda.device_count()
params = ExtractionParameters(
model_str=model_str,
tokenizer=tokenizer,
collator=collator,
batch_size=batch_size,
layers=layers,
prompt_suffix=prompt_suffix,
token_loc=token_loc,
use_encoder_states=use_encoder_states,
)

# Spawn a process for each GPU
ctx = torch.multiprocessing.spawn(
_extract_hiddens_process,
args=(num_gpus, queue, params),
nprocs=num_gpus,
join=False,
)
assert ctx is not None

# Yield results from the queue
for _ in range(len(collator)):
yield queue.get()

# Clean up
ctx.join()


@torch.no_grad()
def _extract_hiddens_process(
rank: int,
world_size: int,
queue: mp.Queue,
params: ExtractionParameters,
):
) -> Iterable[tuple[torch.Tensor, list[int]]]:
"""Run inference on a model with a set of prompts, yielding the hidden states.
Args:
model: The model to run inference on.
tokenizer: The tokenizer to use for tokenization.
collator: The PromptCollator to use for generating prompts.
batch_size: The batch size to use for inference.
layers (Sequence[int]): The layers to extract hidden states from.
prompt_suffix (str): A string to append to the end of each prompt.
token_loc: The location of the token to extract hidden states from.
can be either "first", "last", or "mean". Defaults to "last".
use_encoder_states: Whether to use the encoder states instead of the
decoder states. This allows simplification from an encoder-decoder
model to an encoder-only model. Defaults to False.
"""
Do inference on a model with a set of prompts on a single process.
To be passed to Dataset.from_generator.
"""
print(f"Process with rank={rank}")
if rank != 0:
logging.getLogger("transformers").setLevel(logging.CRITICAL)

num_choices = params.collator.num_classes
shards = np.array_split(np.arange(len(params.collator)), world_size)
params.collator.select_(shards[rank])

# AutoModel should do the right thing here in nearly all cases. We don't actually
# care what head the model has, since we are just extracting hidden states.
model = AutoModel.from_pretrained(params.model_str, torch_dtype="auto").to(
f"cuda:{rank}"
)

if params.use_encoder_states and not model.config.is_encoder_decoder:
raise ValueError(
"use_encoder_states is only compatible with encoder-decoder models."
)
device = model.device
num_choices = collator.num_classes

# TODO: Make this configurable or something
# Token used to separate the question from the answer
sep_token = params.tokenizer.sep_token or "\n"
sep_token = tokenizer.sep_token or "\n"

# TODO: Maybe also make this configurable?
# We want to make sure the answer is never truncated
params.tokenizer.truncation_side = "left"
if not params.tokenizer.pad_token:
params.tokenizer.pad_token = params.tokenizer.eos_token
tokenizer.truncation_side = "left"
if not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.eos_token

def tokenize(strings: list[str]):
return pytree_map(
lambda x: x.to(f"cuda:{rank}"),
params.tokenizer(
lambda x: x.to(device),
tokenizer(
strings,
padding=True,
return_tensors="pt",
@@ -131,7 +69,7 @@ def tokenize(strings: list[str]):
# each question-answer pair. After inference we need to reshape the results.
def collate(prompts: list[Prompt]) -> tuple[BatchEncoding, list[int]]:
choices = [
prompt.to_string(i, sep=sep_token) + params.prompt_suffix
prompt.to_string(i, sep=sep_token) + prompt_suffix
for prompt in prompts
for i in range(num_choices)
]
@@ -145,7 +83,7 @@ def collate_enc_dec(
)
tokenized_answers = tokenize(
[
prompt.answers[i] + params.prompt_suffix
prompt.answers[i] + prompt_suffix
for prompt in prompts
for i in range(num_choices)
]
@@ -161,22 +99,22 @@ def reduce_seqs(
# Unflatten the hiddens
hiddens = [rearrange(h, "(b c) l d -> b c l d", c=num_choices) for h in hiddens]

if params.token_loc == "first":
if token_loc == "first":
hiddens = [h[..., 0, :] for h in hiddens]
elif params.token_loc == "last":
elif token_loc == "last":
# Because of padding, the last token is going to be at a different index
# for each example, so we use gather.
B, C, _, D = hiddens[0].shape
lengths = attention_mask.sum(dim=-1).view(B, C, 1, 1)
indices = lengths.sub(1).expand(B, C, 1, D)
hiddens = [h.gather(index=indices, dim=-2).squeeze(-2) for h in hiddens]
elif params.token_loc == "mean":
elif token_loc == "mean":
hiddens = [h.mean(dim=-2) for h in hiddens]
else:
raise ValueError(f"Invalid token_loc: {params.token_loc}")
raise ValueError(f"Invalid token_loc: {token_loc}")

if params.layers:
hiddens = [hiddens[i] for i in params.layers]
if layers:
hiddens = [hiddens[i] for i in layers]

# [batch size, layers, num choices, hidden size]
return torch.stack(hiddens, dim=1)
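The token_loc == "last" branch above has to cope with right-padding: the last real token sits at a different position in every sequence, so the code gathers it using lengths derived from the attention mask. The snippet below is a toy, self-contained illustration of that gather; the shapes and values are made up for demonstration.

# Toy illustration of the gather used for token_loc == "last" in reduce_seqs.
import torch

B, C, L, D = 2, 3, 7, 16                      # batch, choices, seq len, hidden size
hiddens = torch.randn(B, C, L, D)             # one layer's hidden states
attention_mask = torch.ones(B * C, L, dtype=torch.long)
attention_mask[0, 5:] = 0                     # pretend the first sequence is padded

lengths = attention_mask.sum(dim=-1).view(B, C, 1, 1)   # real length of each sequence
indices = lengths.sub(1).expand(B, C, 1, D)             # index of the last real token
last_token = hiddens.gather(index=indices, dim=-2).squeeze(-2)  # [B, C, D]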
@@ -185,7 +123,7 @@ def reduce_seqs(
# we don't need to run the decoder at all. Just strip it off, making the problem
# equivalent to a regular encoder-only model.
is_enc_dec = model.config.is_encoder_decoder
if is_enc_dec and params.use_encoder_states:
if is_enc_dec and use_encoder_states:
# This isn't actually *guaranteed* by HF, but it's true for all existing models
if not hasattr(model, "get_encoder") or not callable(model.get_encoder):
raise ValueError(
@@ -196,16 +134,17 @@ def reduce_seqs(

# Whether to concatenate the question and answer before passing to the model.
# If False pass them to the encoder and decoder separately.
should_concat = not is_enc_dec or params.use_encoder_states
should_concat = not is_enc_dec or use_encoder_states

dl = DataLoader(
params.collator,
batch_size=params.batch_size,
collator,
batch_size=batch_size,
collate_fn=collate if should_concat else collate_enc_dec,
)

# Iterating over questions
for batch in dl:
rank = dist.get_rank() if dist.is_initialized() else 0
for batch in tqdm(dl, position=rank):
# Condition 1: Encoder-decoder transformer, with answer in the decoder
if not should_concat:
questions, answers, labels = batch
@@ -215,29 +154,12 @@ def reduce_seqs(
output_hidden_states=True,
)
# [batch_size, num_layers, num_choices, hidden_size]
# need to convert hidden states to numpy array first or
# you get a ConnectionResetErrror
queue.put(
{
"hiddens": torch.stack(outputs.decoder_hidden_states, dim=2)
.cpu()
.numpy(),
"labels": labels,
}
)
yield torch.stack(outputs.decoder_hidden_states, dim=2), labels

# Condition 2: Either a decoder-only transformer or a transformer encoder
else:
choices, labels = batch

# Skip the input embeddings which are unlikely to be interesting
h = model(**choices, output_hidden_states=True).hidden_states[1:]

# need to convert hidden states to numpy array first or
# you get a ConnectionResetErrror
queue.put(
{
"hiddens": reduce_seqs(h, choices["attention_mask"]).cpu().numpy(),
"labels": labels,
}
)
yield reduce_seqs(h, choices["attention_mask"]), labels
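After this revert, extract_hiddens is an ordinary generator that yields (hidden_states, labels) pairs per batch instead of pushing NumPy arrays through a multiprocessing queue. Below is a hypothetical consumption sketch: the model, tokenizer, and collator arguments stand in for objects constructed elsewhere in elk, and the keyword values are only illustrative.

# Hypothetical consumption sketch for the generator above; not a real elk entry point.
import torch


def collect_hiddens(model, tokenizer, collator):
    """Gather every yielded batch into one tensor plus a flat label list."""
    all_hiddens, all_labels = [], []
    for hiddens, labels in extract_hiddens(
        model,
        tokenizer,
        collator,
        batch_size=4,          # illustrative values only
        token_loc="last",
    ):
        all_hiddens.append(hiddens.cpu())
        all_labels.extend(labels)

    # [num examples, num layers, num choices, hidden size] on the decoder-only path
    return torch.cat(all_hiddens), all_labels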
