Add Python 3.11 to CI tests #91

Merged
merged 18 commits into from
Feb 19, 2023
Changes from 1 commit
Switch to torch.multiprocessing.spawn
norabelrose committed Feb 19, 2023
commit 5e9662e12f4dd73d7c7b1f87d43a40e9c2362bde
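For reference, torch.multiprocessing.spawn calls the target once per process, passing the process index (rank) as the first positional argument, and with join=False it returns a process context whose join() blocks until every worker exits. A minimal sketch of that pattern, with an illustrative worker body rather than the PR's actual code:

import torch.multiprocessing as mp

def _worker(rank: int, world_size: int, queue) -> None:
    # spawn() prepends the rank to the args tuple given below.
    queue.put({"rank": rank, "world_size": world_size})

if __name__ == "__main__":
    world_size = 2
    queue = mp.get_context("spawn").Queue()
    # join=False hands back a ProcessContext so the parent can drain the
    # queue before waiting on the workers.
    ctx = mp.spawn(_worker, args=(world_size, queue), nprocs=world_size, join=False)
    for _ in range(world_size):
        print(queue.get())
    assert ctx is not None
    ctx.join()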
82 changes: 42 additions & 40 deletions elk/extraction/extraction.py
@@ -5,9 +5,14 @@
from dataclasses import dataclass
from einops import rearrange
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase, AutoModel
from typing import cast, Literal, Iterator, Sequence
from transformers import (
BatchEncoding,
PreTrainedModel,
PreTrainedTokenizerBase,
AutoModel,
)
from typing import cast, Literal, Sequence
import logging
import numpy as np
import torch
import torch.multiprocessing as mp
@@ -36,71 +41,66 @@ def extract_hiddens(
prompt_suffix: str = "",
token_loc: Literal["first", "last", "mean"] = "last",
use_encoder_states: bool = False,
seed_start: int = 42,
):
"""Run inference on a model with a set of prompts, yielding the hidden states."""

ctx = mp.get_context("spawn")
queue = ctx.Queue()

# use different random seed for each process
curr_seed = seed_start
workers = []

# Start the workers.
num_gpus = torch.cuda.device_count()
shards = np.array_split(np.arange(len(collator)), num_gpus)
for rank, proc_indices in enumerate(shards):
params = ExtractionParameters(
model_str=model_str,
tokenizer=tokenizer,
collator=collator.split_and_copy(proc_indices, curr_seed),
batch_size=batch_size,
layers=layers,
prompt_suffix=prompt_suffix,
token_loc=token_loc,
use_encoder_states=use_encoder_states,
)

worker = ctx.Process(
target=_extract_hiddens_process,
args=(queue, params, rank),
)
worker.start()
workers.append(worker)
params = ExtractionParameters(
model_str=model_str,
tokenizer=tokenizer,
collator=collator,
batch_size=batch_size,
layers=layers,
prompt_suffix=prompt_suffix,
token_loc=token_loc,
use_encoder_states=use_encoder_states,
)

curr_seed += 1
# Spawn a process for each GPU
ctx = torch.multiprocessing.spawn(
_extract_hiddens_process,
args=(num_gpus, queue, params),
nprocs=num_gpus,
join=False,
)
assert ctx is not None

# Consume the results from the queue
# Yield results from the queue
for _ in range(len(collator)):
yield queue.get()

# Clean up
for worker in workers:
worker.join()
ctx.join()


@torch.no_grad()
def _extract_hiddens_process(
rank: int,
world_size: int,
queue: mp.Queue,
params: ExtractionParameters,
rank: int,
) -> Iterator[dict]:
):
"""
Do inference on a model with a set of prompts on a single process.
To be passed to Dataset.from_generator.
"""
print(f"Process with rank={rank}")
if rank != 0:
logging.getLogger("transformers").setLevel(logging.CRITICAL)

shards = np.array_split(np.arange(len(params.collator)), world_size)
params.collator.select_(shards[rank])

# AutoModel should do the right thing here in nearly all cases. We don't actually
# care what head the model has, since we are just extracting hidden states.
print(f"Rank={rank}: Loading model '{params.model_str}'...")
model = AutoModel.from_pretrained(params.model_str, torch_dtype="auto")
print(f"Rank={rank}: Done. Model class: '{model.__class__.__name__}'")

if params.use_encoder_states and not model.config.is_encoder_decoder:
raise ValueError(
"--use_encoder_states is only compatible with encoder-decoder models."
"use_encoder_states is only compatible with encoder-decoder models."
)

model = model.to(f"cuda:{rank}")
@@ -205,7 +205,7 @@ def reduce_seqs(
)

# Iterating over questions
for batch in tqdm(dl, position=rank):
for batch in dl:
# Condition 1: Encoder-decoder transformer, with answer in the decoder
if not should_concat:
questions, answers, labels = batch
@@ -219,7 +219,9 @@ def reduce_seqs(
# you get a ConnectionResetError
queue.put(
{
"hiddens": torch.stack(outputs.decoder_hidden_states, dim=2).cpu().numpy(),
"hiddens": torch.stack(outputs.decoder_hidden_states, dim=2)
.cpu()
.numpy(),
"labels": labels,
}
)
@@ -230,7 +232,7 @@ def reduce_seqs(

# Skip the input embeddings which are unlikely to be interesting
h = model(**choices, output_hidden_states=True).hidden_states[1:]

# need to convert hidden states to numpy array first or
# you get a ConnectionResetError
queue.put(
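The "ConnectionResetError" comments above refer to putting GPU tensors on the inter-process queue: a CPU numpy array is pickled by value, while torch tensors presumably travel through shared-memory handles that can break once the worker exits. A small illustrative helper, not part of the PR:

import torch

def queue_payload(hiddens: torch.Tensor, labels) -> dict:
    # Move the activations off the GPU and hand the queue a plain numpy
    # copy instead of a shared torch tensor.
    return {"hiddens": hiddens.detach().cpu().numpy(), "labels": labels}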
28 changes: 13 additions & 15 deletions elk/extraction/extraction_main.py
@@ -3,12 +3,11 @@
from .extraction import extract_hiddens, PromptCollator
from ..files import args_to_uuid, elk_cache_dir
from ..training.preprocessing import silence_datasets_messages
from ..utils import maybe_all_gather
from transformers import AutoConfig, AutoTokenizer
import json
import torch
from datasets import Dataset


def run(args):
def extract(args, split: str):
"""Extract hidden states for a given split.
@@ -37,18 +36,18 @@ def extract(args, split: str):
print(f"Randomizing over {len(prompt_names)} prompts: {prompt_names}")
else:
raise ValueError(f"Unknown prompt strategy: {args.prompts}")

return Dataset.from_generator(
extract_hiddens,
gen_kwargs = {
'model_str': args.model,
'tokenizer': tokenizer,
'collator': collator,
'layers': args.layers,
'prompt_suffix': args.prompt_suffix,
'token_loc': args.token_loc,
'use_encoder_states': args.use_encoder_states,
}
gen_kwargs={
"model_str": args.model,
"tokenizer": tokenizer,
"collator": collator,
"layers": args.layers,
"prompt_suffix": args.prompt_suffix,
"token_loc": args.token_loc,
"use_encoder_states": args.use_encoder_states,
},
)

print("Loading tokenizer...")
@@ -70,9 +69,8 @@ def extract(args, split: str):
with open(save_dir / "args.json", "w") as f:
json.dump(vars(args), f)

config = AutoConfig.from_pretrained(args.model)

with open(save_dir / "model_config.json", "w") as f:
config = AutoConfig.from_pretrained(args.model)
json.dump(config.to_dict(), f)

return train_dset, valid_dset
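Dataset.from_generator forwards gen_kwargs to the generator function, which is how extract_hiddens receives its model, tokenizer, and collator in the diff above. A toy example of the same API (column names here are made up):

from datasets import Dataset

def gen(n_rows: int):
    # The generator yields one example (a dict of column -> value) at a time.
    for i in range(n_rows):
        yield {"idx": i, "square": i * i}

dset = Dataset.from_generator(gen, gen_kwargs={"n_rows": 4})
print(dset[0])  # {'idx': 0, 'square': 0}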
21 changes: 2 additions & 19 deletions elk/extraction/prompt_collator.py
@@ -155,22 +155,5 @@ def set_labels(self):
)
self.label_fracs = counts / counts.sum()

def split_and_copy(self, indices, new_seed):
"""
To avoid copying the entire dataset num_processes times when multiprocessing,
this makes a shallow copy of self, but with self.dataset split
according to given indices.
"""
dataset_split = self.dataset.select(indices)

# only shallow copy is needed -- multiprocess will pickle (dill) objects
self_copy = copy.copy(self)
self_copy.dataset = dataset_split

# redo counts based on new split
self_copy.set_labels()

# give copy a new rng
self_copy.rng = Random(new_seed)

return self_copy
def select_(self, indices):
self.dataset = self.dataset.select(indices)
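With split_and_copy gone, each spawned worker narrows the shared collator to its own shard in place, mirroring the np.array_split call in extraction.py above. A minimal sketch of that pattern, assuming only that the collator exposes __len__ and the new select_ method:

import numpy as np

def shard_in_place(collator, rank: int, world_size: int) -> None:
    # Every worker receives the full collator and keeps only its slice of
    # example indices, so the parent never builds per-worker copies.
    shards = np.array_split(np.arange(len(collator)), world_size)
    collator.select_(shards[rank])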