Add Python 3.11 to CI tests #91

Merged: 18 commits, Feb 19, 2023
Changes from 1 commit
Replace custom method with np.array_split
norabelrose committed Feb 18, 2023
commit 3862e926fe638d6733bd1db21cc807f31d84fda9
69 changes: 27 additions & 42 deletions elk/extraction/extraction.py
@@ -1,14 +1,16 @@
 from ..utils import pytree_map
 from .prompt_collator import Prompt, PromptCollator
-from dataclasses import dataclass
-from datasets import Dataset
 from einops import rearrange
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase
 from typing import cast, Literal, Iterator, Sequence
-import torch
+from dataclasses import dataclass
+from datasets import Dataset
+import multiprocess as mp
+import numpy as np
+import torch
 
 
 @dataclass
 class ExtractionParameters:
@@ -27,26 +29,12 @@ def get_device_ids(local_rank: int, local_world_size: int) -> list[int]:
     Splits devices among local_world_size processes and returns their ids.
     """
     devices_per_proc = torch.cuda.device_count() // local_world_size
-    device_ids = list(range(local_rank * devices_per_proc, (local_rank + 1) * devices_per_proc))
+    device_ids = list(
+        range(local_rank * devices_per_proc, (local_rank + 1) * devices_per_proc)
+    )
     return device_ids
 
 
-def uniform_split(elements: list, num_splits: int) -> Iterator[list]:
-    """
-    Splits input list as evenly as possible among num_splits splits. No elements are excluded.
-    """
-    num_per_split = [len(elements) // num_splits] * num_splits
-
-    remaining = len(elements) % num_splits
-    for i in range(remaining):
-        num_per_split[i] += 1
-
-    start_idx = 0
-    for split_size in num_per_split:
-        yield elements[start_idx : start_idx + split_size]
-        start_idx += split_size
-
-
 def extract_hiddens(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizerBase,
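
As an aside, get_device_ids hands each process a contiguous block of GPU ids by integer division; a quick worked example of the arithmetic (hypothetical numbers: 8 visible GPUs, 4 processes), not part of the diff:

    devices, world_size = 8, 4
    per_proc = devices // world_size  # 2 GPUs per process
    for rank in range(world_size):
        # same arithmetic as get_device_ids, minus the torch.cuda call
        print(rank, list(range(rank * per_proc, (rank + 1) * per_proc)))
    # 0 [0, 1]
    # 1 [2, 3]
    # 2 [4, 5]
    # 3 [6, 7]

Any GPUs left over by the integer division (e.g. 7 GPUs across 4 processes) are simply left unused.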
@@ -59,16 +47,15 @@ def extract_hiddens(
     token_loc: Literal["first", "last", "mean"] = "last",
     use_encoder_states: bool = False,
     num_procs: int = 1,
-    seed_start: int = 42
+    seed_start: int = 42,
 ):
     """Run inference on a model with a set of prompts, yielding the hidden states."""
 
     # Dataset.from_generator expects a list >= num_procs
-    # This wraps given parameters into a list of ExtractionParameters with length num_procs
-    all_indices = list(range(len(collator)))
-
+    # This wraps given parameters into a list of ExtractionParameters with length
+    # num_procs
     # Samples are split among processes here, instead of through Dataset.from_generator
-    all_proc_indices = uniform_split(all_indices, num_procs)
+    all_proc_indices = np.array_split(np.arange(len(collator)), num_procs)
 
     all_params = []
 
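For reference, np.array_split reproduces the removed uniform_split helper: it splits as evenly as possible and drops no elements, the only difference being that it returns numpy arrays rather than lists. A quick check (illustrative, not part of the diff), splitting 10 samples across 3 processes:

    import numpy as np

    parts = np.array_split(np.arange(10), 3)
    print([p.tolist() for p in parts])
    # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]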
@@ -84,38 +71,36 @@ def extract_hiddens(
             layers=layers,
             prompt_suffix=prompt_suffix,
             token_loc=token_loc,
-            use_encoder_states=use_encoder_states
+            use_encoder_states=use_encoder_states,
         )
 
         all_params.append(params)
 
         curr_seed += 1
 
     # each list needs to have length num_proc
     multiprocess_kwargs = {
-        'wrapped_params': all_params,
-        'wrapped_rank': list(range(num_procs)),
-        'wrapped_num_procs': [num_procs] * num_procs
+        "wrapped_params": all_params,
+        "wrapped_rank": list(range(num_procs)),
+        "wrapped_num_procs": [num_procs] * num_procs,
     }
 
     # CUDA needs os.spawn instead of the default os.fork (on Unix)
     mp.set_start_method("spawn")
 
     return Dataset.from_generator(
-        _extract_hiddens_process,
-        gen_kwargs=multiprocess_kwargs,
-        num_proc=num_procs
+        _extract_hiddens_process, gen_kwargs=multiprocess_kwargs, num_proc=num_procs
     )
 
 
 @torch.no_grad()
 def _extract_hiddens_process(
     wrapped_params: list[ExtractionParameters],
     wrapped_rank: list[int],
-    wrapped_num_procs: list[int]
+    wrapped_num_procs: list[int],
 ) -> Iterator[dict]:
     """
-    Internal function for inference on a model with a set of prompts on a single process.
+    Do inference on a model with a set of prompts on a single process.
     To be passed to Dataset.from_generator.
     """
 
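The wrapped_* lists exist because Dataset.from_generator shards list-valued gen_kwargs across workers when num_proc > 1; with every list of length num_procs, each worker receives a single-element shard it can unwrap with [0]. A minimal sketch of the pattern (toy generator, assumed sharding behavior of datasets):

    from datasets import Dataset

    def gen(wrapped_seed: list[int]):
        # Each worker's shard should hold exactly one element here,
        # since num_proc == len(wrapped_seed).
        seed = wrapped_seed[0]
        for i in range(3):
            yield {"value": 100 * seed + i}

    ds = Dataset.from_generator(
        gen,
        gen_kwargs={"wrapped_seed": [1, 2]},
        num_proc=2,  # one seed per worker process
    )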
@@ -126,11 +111,11 @@ def _extract_hiddens_process(
 
     device_ids = get_device_ids(rank, num_procs)
 
-    print(f'Process with rank={rank} using GPUs with ids={device_ids}')
+    print(f"Process with rank={rank} using GPUs with ids={device_ids}")
 
     # TODO: multi-GPU processes (i.e., sharded models) support
     if len(device_ids) > 1:
-        raise ValueError('Only one GPU per process is supported.')
+        raise ValueError("Only one GPU per process is supported.")
 
     model = params.model.to(device_ids[0])
 
@@ -246,8 +231,8 @@ def reduce_seqs(
         )
         # [batch_size, num_layers, num_choices, hidden_size]
         yield {
-            'hiddens': torch.stack(outputs.decoder_hidden_states, dim=2),
-            'labels': labels
+            "hiddens": torch.stack(outputs.decoder_hidden_states, dim=2),
+            "labels": labels,
         }
 
     # Either a decoder-only transformer or a transformer encoder
@@ -257,6 +242,6 @@ def reduce_seqs(
         # Skip the input embeddings which are unlikely to be interesting
         h = model(**choices, output_hidden_states=True).hidden_states[1:]
         yield {
-            'hiddens': reduce_seqs(h, choices["attention_mask"]),
-            'labels': labels
+            "hiddens": reduce_seqs(h, choices["attention_mask"]),
+            "labels": labels,
         }