Initial implementation of semi-supervised training (EleutherAI#58)
* Initial implementation of semi-supervised training

* Full DDP support

* Bug fixes

* Add .item() where necessary

* Use maybe all gather instead of cat

* Shuffle before sharding

* Fix typing problems

* Revert Fabien's change to shuffle/shard ordering

* Send labels to cpu before Linear Regression

* Revert "Send labels to cpu before Linear Regression"

This reverts commit 47a6c76.

* Update README for multi-gpu

* Fix README typo

* First shuffle, then select max_examples examples

* Fix the fix to the typo

* Add --supervised-weight cli arg

* Silence HF warnings from other ranks

---------

Co-authored-by: Fabien Roger <[email protected]>
norabelrose and FabienRoger committed Feb 14, 2023
1 parent da4e023 commit 32e633b
Showing 16 changed files with 339 additions and 290 deletions.
68 changes: 0 additions & 68 deletions .vscode/launch.json

This file was deleted.

14 changes: 12 additions & 2 deletions README.md
@@ -16,13 +16,13 @@ First install the package with `pip install -e .` in the root directory, or `pip
To extract the hidden states for one model `model` and the dataset `dataset` *and* train a probe on these extracted hidden states, run:

```bash
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb --max-examples 1000
elk elicit microsoft/deberta-v2-xxlarge-mnli imdb
```

To only extract the hidden states for one model `model` and the dataset `dataset`, run:

```bash
elk extract microsoft/deberta-v2-xxlarge-mnli imdb --max-examples 1000
elk extract microsoft/deberta-v2-xxlarge-mnli imdb
```

To only train a CCS model and a logistic regression model
@@ -35,6 +35,16 @@ and evaluate on different datasets: [WIP]

Once finished, results will be saved in `~/.cache/elk/{model}_{prefix}_{seed}.csv`

### Distributed hidden state extraction

You can run the hidden state extraction in parallel on multiple GPUs with `torchrun`. For example, to use all the GPUs on a node:

```bash
torchrun --nproc_per_node gpu -m elk extract microsoft/deberta-v2-xxlarge-mnli imdb
```

Distributed training of the probe isn't fully supported yet; running `elk train` under `torchrun` tends to hang. We're working on fixing this.

### Development

Use `pip install pre-commit && pre-commit install` in the root folder before your first commit.
95 changes: 51 additions & 44 deletions elk/__main__.py
@@ -1,16 +1,14 @@
from elk.files import args_to_uuid, elk_cache_dir
from .extraction.extraction_main import run as run_extraction
from .extraction.parser import (
add_saveable_args,
add_unsaveable_args,
get_extraction_parser,
)
from .training.parser import add_train_args, get_training_parser
from .extraction.parser import get_extraction_parser
from .training.parser import get_training_parser
from .training.train import train
from argparse import ArgumentParser
from pathlib import Path
from contextlib import nullcontext, redirect_stdout
from elk.files import args_to_uuid
from transformers import AutoConfig, PretrainedConfig
import json
import logging
import os
import torch.distributed as dist


def run():
@@ -45,21 +43,8 @@ def run():
)
args = parser.parse_args()

# Default to CUDA iff available
if args.device is None:
import torch

args.device = "cuda" if torch.cuda.is_available() else "cpu"

if model := getattr(args, "model", None):
config_path = Path(__file__).parent / "default_config.json"
with open(config_path, "r") as f:
default_config = json.load(f)
model_shortcuts = default_config["model_shortcuts"]

# Dereference shortcut
args.model = model_shortcuts.get(model, model)
config = AutoConfig.from_pretrained(args.model)
config = AutoConfig.from_pretrained(model)
assert isinstance(config, PretrainedConfig)

num_layers = getattr(config, "num_layers", config.num_hidden_layers)
@@ -72,29 +57,51 @@ def run():
elif args.layer_stride > 1:
args.layers = list(range(0, num_layers, args.layer_stride))

for key in list(vars(args).keys()):
print("{}: {}".format(key, vars(args)[key]))

# TODO: Implement the rest of the CLI
if args.command == "extract":
run_extraction(args)
elif args.command == "train":
train(args)
elif args.command == "elicit":
args.name = args_to_uuid(args)
cache_dir = elk_cache_dir() / args.name
if not cache_dir.exists():
# Support both distributed and non-distributed training
local_rank = os.environ.get("LOCAL_RANK")
if local_rank is not None:
dist.init_process_group("nccl")
local_rank = int(local_rank)

# Default to CUDA iff available
if args.device is None:
import torch

if not torch.cuda.is_available():
args.device = "cpu"
else:
args.device = f"cuda:{local_rank or 0}"

# Prevent printing from processes other than the first one
with redirect_stdout(None) if local_rank != 0 else nullcontext():
for key in list(vars(args).keys()):
print("{}: {}".format(key, vars(args)[key]))

if local_rank != 0:
logging.getLogger("transformers").setLevel(logging.ERROR)

# TODO: Implement the rest of the CLI
if args.command == "extract":
run_extraction(args)
elif args.command == "train":
train(args)
elif args.command == "elicit":
args.name = args_to_uuid(args)
try:
train(args)
except (EOFError, FileNotFoundError):
run_extraction(args)

# Ensure the extraction is finished before starting training
if dist.is_initialized():
dist.barrier()

train(args)

elif args.command == "eval":
raise NotImplementedError
else:
print(
f"Cache dir \033[1m{cache_dir}\033[0m exists, "
"skip extraction of hidden states"
) # bold
train(args)
elif args.command == "eval":
raise NotImplementedError
else:
raise ValueError(f"Unknown command {args.command}")
raise ValueError(f"Unknown command {args.command}")


if __name__ == "__main__":
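For reference, the distributed bootstrap that the hunks above add to `elk/__main__.py` boils down to the pattern below. This is a condensed sketch distilled from the diff, not a drop-in replacement: the helper name `setup_distributed` is ours, and the actual code inlines this logic directly inside `run()`.

```python
import logging
import os

import torch
import torch.distributed as dist


def setup_distributed() -> tuple[int, str]:
    """Return (local_rank, device), initializing NCCL when launched via torchrun."""
    env_rank = os.environ.get("LOCAL_RANK")
    if env_rank is not None:
        # torchrun sets LOCAL_RANK; a plain `python -m elk` invocation does not.
        dist.init_process_group("nccl")
        local_rank = int(env_rank)
    else:
        local_rank = 0

    # Each process drives its own GPU; fall back to CPU when CUDA is absent.
    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"

    # Keep the console quiet on every rank except the first.
    if local_rank != 0:
        logging.getLogger("transformers").setLevel(logging.ERROR)

    return local_rank, device
```

The `elicit` path builds on this: it first tries to train on cached hidden states, falls back to extraction when the cache is missing (`EOFError` or `FileNotFoundError`), and calls `dist.barrier()` so that training only starts once every rank has finished writing its shard.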
42 changes: 0 additions & 42 deletions elk/default_config.json

This file was deleted.

2 changes: 2 additions & 0 deletions elk/extraction/__init__.py
@@ -0,0 +1,2 @@
from .extraction import extract_hiddens
from .prompt_collator import PromptCollator
6 changes: 4 additions & 2 deletions elk/extraction/extraction.py
@@ -6,9 +6,10 @@
from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase
from typing import cast, Iterable, Literal, Sequence
import torch
import torch.distributed as dist


@torch.autocast("cuda", enabled=torch.cuda.is_available())
@torch.autocast("cuda", enabled=torch.cuda.is_available()) # type: ignore
@torch.no_grad()
def extract_hiddens(
model: PreTrainedModel,
@@ -125,7 +126,8 @@ def reduce_seqs(
)

# Iterating over questions
for batch in tqdm(dl):
rank = dist.get_rank() if dist.is_initialized() else 0
for batch in tqdm(dl, position=rank):
# Condition 1: Encoder-decoder transformer, with answer in the decoder
if not should_concat:
questions, answers, labels = batch
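A side note on the progress-bar change above: passing `position=rank` to `tqdm` pins each process's bar to its own terminal line, so bars from different ranks don't overwrite one another. A minimal standalone sketch of the same idea (not code from the repo):

```python
import torch.distributed as dist
from tqdm import tqdm


def rank_aware_progress(iterable):
    # Treat single-process runs as rank 0.
    rank = dist.get_rank() if dist.is_initialized() else 0
    # position=rank gives each process its own line in the terminal.
    return tqdm(iterable, position=rank)
```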
26 changes: 18 additions & 8 deletions elk/extraction/extraction_main.py
@@ -1,12 +1,16 @@
from .extraction import extract_hiddens, PromptCollator
from ..files import args_to_uuid, elk_cache_dir
from ..training.preprocessing import silence_datasets_messages
from ..utils import maybe_all_gather
from transformers import AutoModel, AutoTokenizer
import json
import torch
import torch.distributed as dist


def run(args):
rank = dist.get_rank() if dist.is_initialized() else 0

def extract(args, split: str):
frac = 1 - args.val_frac if split == "train" else args.val_frac

@@ -29,7 +33,7 @@ def extract(args, split: str):
raise ValueError(f"Unknown prompt strategy: {args.prompts}")

items = [
(features.cpu(), labels)
(features, labels)
for features, labels in extract_hiddens(
model,
tokenizer,
@@ -44,9 +48,14 @@

with open(save_dir / f"{split}_hiddens.pt", "wb") as f:
hidden_batches, label_batches = zip(*items)
hiddens = torch.cat(hidden_batches) # type: ignore
labels = sum(label_batches, [])
torch.save((hiddens, labels), f)
hiddens = maybe_all_gather(torch.cat(hidden_batches)) # type: ignore

# Moving labels to GPU just to be able to use maybe_all_gather
labels = torch.tensor(sum(label_batches, []), device=hiddens.device)
labels = maybe_all_gather(labels) # type: ignore

if rank == 0:
torch.save((hiddens.cpu(), labels.cpu()), f)

# AutoModel should do the right thing here in nearly all cases. We don't actually
# care what head the model has, since we are just extracting hidden states.
@@ -75,8 +84,9 @@ def extract(args, split: str):
extract(args, "train")
extract(args, "validation")

with open(save_dir / "args.json", "w") as f:
json.dump(vars(args), f)
if rank == 0:
with open(save_dir / "args.json", "w") as f:
json.dump(vars(args), f)

with open(save_dir / "model_config.json", "w") as f:
json.dump(model.config.to_dict(), f)
with open(save_dir / "model_config.json", "w") as f:
json.dump(model.config.to_dict(), f)
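The hunk above imports `maybe_all_gather` from `elk/utils.py`, which isn't part of this diff. Judging from how it's used (each rank concatenates its own hidden-state batches and only rank 0 saves the gathered result), a plausible sketch of such a helper looks like the following. This is an assumption about its behavior, not the repo's actual implementation, and it assumes every rank contributes a tensor of the same shape:

```python
import torch
import torch.distributed as dist


def maybe_all_gather(x: torch.Tensor) -> torch.Tensor:
    """Concatenate `x` across ranks along dim 0, or return it unchanged
    when not running under torch.distributed."""
    if not dist.is_initialized():
        return x

    # NCCL all_gather expects equal-sized CUDA tensors on every rank.
    buffers = [torch.empty_like(x) for _ in range(dist.get_world_size())]
    dist.all_gather(buffers, x)
    return torch.cat(buffers)
```

This also explains the comment about moving labels to the GPU first: NCCL collectives only operate on CUDA tensors, so the labels are gathered on the GPU and moved back to the CPU before saving.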
1 change: 0 additions & 1 deletion elk/extraction/parser.py
@@ -27,7 +27,6 @@ def add_saveable_args(parser):
)
parser.add_argument(
"--max-examples",
default=1000,
type=int,
help="Maximum number of examples to use from each dataset.",
)