Use one file per layer #59

Closed · wants to merge 17 commits
68 changes: 0 additions & 68 deletions .vscode/launch.json

This file was deleted.

137 changes: 93 additions & 44 deletions elk/__main__.py
@@ -1,16 +1,18 @@
from elk.files import args_to_uuid, elk_cache_dir
from .extraction.extraction_main import run as run_extraction
from .extraction.parser import (
add_saveable_args,
add_unsaveable_args,
get_extraction_parser,
)
from .training.parser import add_train_args, get_training_parser
from .training.train import train
import json
import os
from argparse import ArgumentParser
from pathlib import Path
from contextlib import nullcontext, redirect_stdout
from typing import Optional

import torch.distributed as dist
from transformers import AutoConfig, PretrainedConfig
import json

from elk.files import args_to_uuid, elk_cache_dir, get_hiddens_path

from .extraction.extraction_main import run as run_extraction
from .extraction.parser import get_extraction_parser
from .training.parser import get_training_parser
from .training.train import train


def run():
@@ -45,21 +47,66 @@ def run():
)
args = parser.parse_args()

normalize_args_inplace(args)

# Support both distributed and non-distributed training
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
dist.init_process_group("nccl")
local_rank = int(local_rank)

# Prevent printing from processes other than the first one
with redirect_stdout(None) if local_rank not in (None, 0) else nullcontext():
for key in list(vars(args).keys()):
print("{}: {}".format(key, vars(args)[key]))

# TODO: Implement the rest of the CLI
if args.command == "extract":
run_extraction(args)
elif args.command == "train":
train(args)
elif args.command == "elicit":
# Extract the hidden states if they're not already there
args.name = args_to_uuid(args)
cache_dir = elk_cache_dir() / args.name
missing_layers = find_missing_layers(args)
if missing_layers:
if cache_dir.exists():
print(
f"Found cache dir \033[1m{cache_dir}\033[0m"
f" but it's missing layers {', '.join(missing_layers)}"
)

old_layers = args.layers
args.layers = missing_layers
run_extraction(args)
args.layers = old_layers

# Ensure the extraction is finished before starting training
if dist.is_initialized():
dist.barrier()

train(args)

elif args.command == "eval":
raise NotImplementedError
else:
raise ValueError(f"Unknown command {args.command}")
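
One non-obvious detail in run() above: redirect_stdout(None) works as a print silencer because print() is a no-op when sys.stdout is None. A minimal standalone sketch of the pattern, assuming a torchrun-style launch that sets LOCAL_RANK:

import os
from contextlib import nullcontext, redirect_stdout

# LOCAL_RANK is set by torchrun; None means a single, non-distributed process.
local_rank = os.environ.get("LOCAL_RANK")
quiet = local_rank is not None and int(local_rank) != 0

with redirect_stdout(None) if quiet else nullcontext():
    print("only rank 0 (or a single process) sees this")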


def normalize_args_inplace(args):
# Default to CUDA iff available
if args.device is None:
import torch

args.device = "cuda" if torch.cuda.is_available() else "cpu"
if not torch.cuda.is_available():
args.device = "cpu"
else:
rank = dist.get_rank() if dist.is_initialized() else 0
args.device = f"cuda:{rank}"

if model := getattr(args, "model", None):
config_path = Path(__file__).parent / "default_config.json"
with open(config_path, "r") as f:
default_config = json.load(f)
model_shortcuts = default_config["model_shortcuts"]

# Dereference shortcut
args.model = model_shortcuts.get(model, model)
config = AutoConfig.from_pretrained(args.model)
config = AutoConfig.from_pretrained(model)
assert isinstance(config, PretrainedConfig)

num_layers = getattr(config, "num_layers", config.num_hidden_layers)
@@ -70,31 +117,33 @@ def run():
"Cannot use both --layers and --layer-stride. Please use only one."
)
elif args.layer_stride > 1:
args.layers = list(range(0, num_layers, args.layer_stride))

for key in list(vars(args).keys()):
print("{}: {}".format(key, vars(args)[key]))

# TODO: Implement the rest of the CLI
if args.command == "extract":
run_extraction(args)
elif args.command == "train":
train(args)
elif args.command == "elicit":
args.name = args_to_uuid(args)
cache_dir = elk_cache_dir() / args.name
if not cache_dir.exists():
run_extraction(args)
else:
print(
f"Cache dir \033[1m{cache_dir}\033[0m exists, "
"skip extraction of hidden states"
) # bold
train(args)
elif args.command == "eval":
raise NotImplementedError
# the last layer is often the most interesting, so stride backward from it:
# layers = [..., num_layers - 1 - layer_stride, num_layers - 1]
args.layers = list(reversed(range(num_layers - 1, -1, -args.layer_stride)))
else:
raise ValueError(f"Unknown command {args.command}")
# Without --model we are reusing a cached run, so the name is required
# and the model config is read back from the cache.
assert args.name is not None
with open(elk_cache_dir() / args.name / "model_config.json", "r") as f:
config = json.load(f)
num_layers = config.get("num_layers", config.get("num_hidden_layers"))

args.layers = normalized_layers(args.layers, num_layers)


def normalized_layers(layers: Optional[list[int]], num_layers: int) -> list[int]:
layers = layers or list(range(num_layers))
return [layer if layer >= 0 else num_layers + layer for layer in layers]
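
A quick usage check of the layer selection above, with assumed values (num_layers=12, layer_stride=3): the stride walks backward from the final layer so that layer num_layers - 1 is always included, and normalized_layers resolves negative indices the way Python list indexing does.

num_layers, layer_stride = 12, 3
print(list(reversed(range(num_layers - 1, -1, -layer_stride))))  # [2, 5, 8, 11]
print(normalized_layers([0, -1], num_layers))  # [0, 11]
print(normalized_layers(None, num_layers=3))  # [0, 1, 2]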


def find_missing_layers(args):
cache_dir = elk_cache_dir() / args.name
missing_layers = []
for layer in args.layers:
train_layer_path = get_hiddens_path(cache_dir, "train", layer)
validation_layer_path = get_hiddens_path(cache_dir, "validation", layer)
if not train_layer_path.exists() or not validation_layer_path.exists():
missing_layers.append(layer)
return missing_layers


if __name__ == "__main__":
run()
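
get_hiddens_path and get_labels_path are imported from elk.files but defined outside this diff. A plausible sketch of the one-file-per-layer layout they imply — the exact filenames here are an assumption, not the repo's actual scheme:

from pathlib import Path

def get_hiddens_path(cache_dir: Path, split: str, layer: int) -> Path:
    # hypothetical naming: one tensor file per (split, layer) pair
    return cache_dir / f"{split}_hiddens_{layer}.pt"

def get_labels_path(cache_dir: Path, split: str) -> Path:
    # hypothetical naming: labels are shared across layers
    return cache_dir / f"{split}_labels.pt"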
42 changes: 0 additions & 42 deletions elk/default_config.json

This file was deleted.

2 changes: 2 additions & 0 deletions elk/extraction/__init__.py
@@ -0,0 +1,2 @@
from .extraction import extract_hiddens
from .prompt_collator import PromptCollator
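
The re-exports above let downstream code import from the package root instead of the submodules, e.g. (illustrative only):

from elk.extraction import PromptCollator, extract_hiddens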
6 changes: 4 additions & 2 deletions elk/extraction/extraction.py
@@ -6,9 +6,10 @@
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase
from typing import cast, Iterable, Literal, Sequence
import torch
import torch.distributed as dist


@torch.autocast("cuda", enabled=torch.cuda.is_available())
@torch.autocast("cuda", enabled=torch.cuda.is_available()) # type: ignore
@torch.no_grad()
def extract_hiddens(
model: PreTrainedModel,
@@ -125,7 +126,8 @@ def reduce_seqs(
)

# Iterating over questions
for batch in tqdm(dl):
rank = dist.get_rank() if dist.is_initialized() else 0
for batch in tqdm(dl, position=rank):
# Condition 1: Encoder-decoder transformer, with answer in the decoder
if not should_concat:
questions, answers, labels = batch
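
The position=rank change above is what keeps progress bars legible under distributed training: tqdm offsets each bar by `position` terminal lines, so every worker owns one line instead of all ranks overwriting the same bar. A minimal standalone sketch, where the sleep stands in for real work:

import time

import torch.distributed as dist
from tqdm import tqdm

rank = dist.get_rank() if dist.is_initialized() else 0
for _ in tqdm(range(100), position=rank, desc=f"rank {rank}"):
    time.sleep(0.01)  # stand-in for a forward pass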
37 changes: 26 additions & 11 deletions elk/extraction/extraction_main.py
@@ -1,12 +1,17 @@
from pathlib import Path
from .extraction import extract_hiddens, PromptCollator
from ..files import args_to_uuid, elk_cache_dir
from ..files import args_to_uuid, elk_cache_dir, get_hiddens_path, get_labels_path
from ..training.preprocessing import silence_datasets_messages
from ..utils import maybe_all_gather
from transformers import AutoModel, AutoTokenizer
import json
import torch
import torch.distributed as dist


def run(args):
rank = dist.get_rank() if dist.is_initialized() else 0

def extract(args, split: str):
frac = 1 - args.val_frac if split == "train" else args.val_frac

@@ -29,7 +34,7 @@ def extract(args, split: str):
raise ValueError(f"Unknown prompt strategy: {args.prompts}")

items = [
(features.cpu(), labels)
(features, labels)
for features, labels in extract_hiddens(
model,
tokenizer,
@@ -42,11 +47,20 @@ def extract(args, split: str):
]
save_dir.mkdir(parents=True, exist_ok=True)

with open(save_dir / f"{split}_hiddens.pt", "wb") as f:
hidden_batches, label_batches = zip(*items)
hiddens = torch.cat(hidden_batches) # type: ignore
labels = sum(label_batches, [])
torch.save((hiddens, labels), f)
hidden_batches, label_batches = zip(*items)
hiddens = maybe_all_gather(torch.cat(hidden_batches)) # type: ignore

# Moving labels to GPU just to be able to use maybe_all_gather
labels = torch.tensor(sum(label_batches, []), device=hiddens.device)
labels = maybe_all_gather(labels) # type: ignore

if rank == 0:
for layer in args.layers:
hiddens_at_l = hiddens[:, layer, :, :]
with open(get_hiddens_path(save_dir, split, layer), "wb") as f:
torch.save(hiddens_at_l, f)
with open(get_labels_path(save_dir, split), "wb") as f:
torch.save(labels, f)

# AutoModel should do the right thing here in nearly all cases. We don't actually
# care what head the model has, since we are just extracting hidden states.
@@ -75,8 +89,9 @@ def extract(args, split: str):
extract(args, "train")
extract(args, "validation")

with open(save_dir / "args.json", "w") as f:
json.dump(vars(args), f)
if rank == 0:
with open(save_dir / "args.json", "w") as f:
json.dump(vars(args), f)

with open(save_dir / "model_config.json", "w") as f:
json.dump(model.config.to_dict(), f)
with open(save_dir / "model_config.json", "w") as f:
json.dump(model.config.to_dict(), f)
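
maybe_all_gather is imported from ..utils, but its body isn't part of this diff. A plausible sketch — an assumption about the helper, not the repo's actual code: gather the tensor from every rank when torch.distributed is initialized, otherwise pass it through unchanged. Note that dist.all_gather requires the tensor to have the same shape on every rank.

import torch
import torch.distributed as dist

def maybe_all_gather(x: torch.Tensor) -> torch.Tensor:
    # Single-process fallback: nothing to gather.
    if not dist.is_initialized():
        return x
    # Collect one buffer per rank, then concatenate into a single tensor.
    buffers = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, x)
    return torch.cat(buffers)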
1 change: 0 additions & 1 deletion elk/extraction/parser.py
@@ -27,7 +27,6 @@ def add_saveable_args(parser):
)
parser.add_argument(
"--max-examples",
default=1000,
type=int,
help="Maximum number of examples to use from each dataset.",
)
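
With default=1000 removed, argparse leaves max_examples as None when the flag is omitted; presumably the extraction code treats None as "no limit" (an assumption — the consuming code isn't shown in this diff). A quick check of the argparse behavior:

from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument(
    "--max-examples",
    type=int,
    help="Maximum number of examples to use from each dataset.",
)
args = parser.parse_args([])
assert args.max_examples is None  # omitted flag now means "use all examples"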