From 3a97312c8e5be1c016e00419a088032a4b944164 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Mon, 13 Feb 2023 17:35:01 +0000 Subject: [PATCH 01/15] Initial implementation of semi-supervised training --- elk/extraction/extraction.py | 4 +- elk/extraction/extraction_main.py | 21 +++++--- elk/extraction/parser.py | 1 - elk/extraction/prompt_collator.py | 8 ++-- elk/training/ccs.py | 80 ++++++++++++++++++++++++------- elk/training/losses.py | 46 +++++++----------- elk/training/parser.py | 8 +++- elk/training/train.py | 60 +++++++++++++---------- elk/utils.py | 20 ++++++++ 9 files changed, 165 insertions(+), 83 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index 853a8fcf..f3a3eaab 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -6,6 +6,7 @@ from transformers import BatchEncoding, PreTrainedModel, PreTrainedTokenizerBase from typing import cast, Iterable, Literal, Sequence import torch +import torch.distributed as dist @torch.autocast("cuda", enabled=torch.cuda.is_available()) @@ -125,7 +126,8 @@ def reduce_seqs( ) # Iterating over questions - for batch in tqdm(dl): + rank = dist.get_rank() if dist.is_initialized() else 0 + for batch in tqdm(dl, position=rank): # Condition 1: Encoder-decoder transformer, with answer in the decoder if not should_concat: questions, answers, labels = batch diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index 8ab2ec5b..38b8fe0b 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -1,12 +1,16 @@ from .extraction import extract_hiddens, PromptCollator from ..files import args_to_uuid, elk_cache_dir from ..training.preprocessing import silence_datasets_messages +from ..utils import maybe_all_cat, maybe_all_gather_lists from transformers import AutoModel, AutoTokenizer import json import torch +import torch.distributed as dist def run(args): + rank = dist.get_rank() if dist.is_initialized() else 0 + def extract(args, split: str): frac = 1 - args.val_frac if split == "train" else args.val_frac @@ -43,9 +47,11 @@ def extract(args, split: str): with open(save_dir / f"{split}_hiddens.pt", "wb") as f: hidden_batches, label_batches = zip(*items) - hiddens = torch.cat(hidden_batches) # type: ignore - labels = sum(label_batches, []) - torch.save((hiddens, labels), f) + hiddens = maybe_all_cat(torch.cat(hidden_batches)) # type: ignore + labels = maybe_all_gather_lists(sum(label_batches, [])) # type: ignore + + if rank == 0: + torch.save((hiddens, labels), f) # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. @@ -74,8 +80,9 @@ def extract(args, split: str): extract(args, "train") extract(args, "validation") - with open(save_dir / "args.json", "w") as f: - json.dump(vars(args), f) + if rank == 0: + with open(save_dir / "args.json", "w") as f: + json.dump(vars(args), f) - with open(save_dir / "model_config.json", "w") as f: - json.dump(model.config.to_dict(), f) + with open(save_dir / "model_config.json", "w") as f: + json.dump(model.config.to_dict(), f) diff --git a/elk/extraction/parser.py b/elk/extraction/parser.py index 30ed95a0..6ad16d09 100644 --- a/elk/extraction/parser.py +++ b/elk/extraction/parser.py @@ -27,7 +27,6 @@ def add_saveable_args(parser): ) parser.add_argument( "--max-examples", - default=1000, type=int, help="Maximum number of examples to use from each dataset.", ) diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 42c7c17a..0439dc9d 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -4,6 +4,7 @@ from random import Random from typing import Literal, Optional import numpy as np +import torch.distributed as dist @dataclass @@ -38,10 +39,8 @@ def __init__( if not others: print("Creating a train-test split...") data = data[train_name].train_test_split( - seed=seed, stratify_by_column=label_column + seed=seed, shuffle=False, stratify_by_column=label_column ) - else: - data = data.shuffle(seed) if split not in data and split == "validation": print("No validation split found, using test split instead") @@ -54,7 +53,10 @@ def __init__( raise ValueError(f"Dataset {path}/{name} has only one label") if max_examples: self.dataset = self.dataset.select(range(max_examples)) + if dist.is_initialized(): + self.dataset = self.dataset.shard(dist.get_world_size(), dist.get_rank()) + self.dataset = self.dataset.shuffle(seed=seed) self.label_column = label_column self.prompter = DatasetTemplates(path, subset_name=name) # type: ignore self.rng = Random(seed) diff --git a/elk/training/ccs.py b/elk/training/ccs.py index be9b209b..d9f96c9e 100644 --- a/elk/training/ccs.py +++ b/elk/training/ccs.py @@ -2,7 +2,9 @@ from copy import deepcopy from pathlib import Path from sklearn.metrics import roc_auc_score +from torch.nn.functional import binary_cross_entropy as bce from typing import cast, Literal, NamedTuple, Optional, Type, Union +import math import torch import torch.nn as nn @@ -25,7 +27,7 @@ def __init__( bias: bool = True, device: str = "cuda", hidden_size: Optional[int] = None, - init: Literal["default", "spherical"] = "default", + init: Literal["default", "spherical", "zero"] = "zero", loss: Literal["js", "squared"] = "squared", num_layers: int = 1, pre_ln: bool = False, @@ -63,7 +65,7 @@ def __init__( self.init = init self.device = device - self.loss = js_loss if loss == "js" else ccs_squared_loss + self.unsupervised_loss = js_loss if loss == "js" else ccs_squared_loss def reset_parameters(self): # Mathematically equivalent to the unusual initialization scheme used in the @@ -85,6 +87,9 @@ def reset_parameters(self): for layer in self.probe: if isinstance(layer, nn.Linear): layer.reset_parameters() + elif self.init == "zero": + for param in self.parameters(): + param.data.zero_() else: raise ValueError(f"Unknown init: {self.init}") @@ -101,12 +106,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Return the raw score output of the probe on `x`.""" return self.probe(x) + def loss( + self, + logit0: torch.Tensor, + logit1: torch.Tensor, + labels: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Return the loss of the probe on the contrast pair `(x0, x1)`.""" + loss = self.unsupervised_loss(logit0, logit1) + + if labels is not None: + # If labels are provided, use them to compute a supervised loss + num_labels = len(labels) + label_frac = num_labels / len(logit0) + assert label_frac <= 1.0, "Too many labels provided" + p0 = logit0[:num_labels].sigmoid() + p1 = logit1[:num_labels].sigmoid() + preds = 0.5 * (p0 + (1 - p1)) + loss += bce(preds.squeeze(-1), labels.type_as(preds)) + + return loss + def validate_data(self, data): assert len(data) == 2 and data[0].shape == data[1].shape def fit( self, - data: tuple[torch.Tensor, torch.Tensor], + contrast_pair: tuple[torch.Tensor, torch.Tensor], + labels: Optional[torch.Tensor] = None, + *, lr: float = 1e-2, num_epochs: int = 1000, num_tries: int = 10, @@ -114,22 +142,24 @@ def fit( verbose: bool = False, weight_decay: float = 0.01, ) -> float: - self.validate_data(data) + self.validate_data(contrast_pair) if verbose: print(f"Fitting CCS probe; {num_epochs=}, {num_tries=}, {lr=}") # Record the best acc, loss, and params found so far best_loss = torch.inf best_state: dict[str, torch.Tensor] = {} # State dict of the best run - x0, x1 = data + x0, x1 = contrast_pair for _ in range(num_tries): self.reset_parameters() if optimizer == "lbfgs": - loss = self.train_loop_lbfgs(x0, x1, num_epochs, weight_decay) + loss = self.train_loop_lbfgs(x0, x1, labels, num_epochs, weight_decay) elif optimizer == "adam": - loss = self.train_loop_adam(x0, x1, lr, num_epochs, weight_decay) + loss = self.train_loop_adam( + x0, x1, labels, lr, num_epochs, weight_decay + ) else: raise ValueError(f"Optimizer {optimizer} is not supported") @@ -140,18 +170,21 @@ def fit( best_loss = loss best_state = deepcopy(self.state_dict()) + if not math.isfinite(best_loss): + raise RuntimeError("Got NaN/infinite loss during training") + self.load_state_dict(best_state) return best_loss @torch.no_grad() def score( self, - data: tuple[torch.Tensor, torch.Tensor], + contrast_pair: tuple[torch.Tensor, torch.Tensor], labels: torch.Tensor, ) -> EvalResult: - self.validate_data(data) + self.validate_data(contrast_pair) - logit0, logit1 = map(self, data) + logit0, logit1 = map(self, contrast_pair) p0, p1 = logit0.sigmoid(), logit1.sigmoid() pred_probs = 0.5 * (p0 + (1 - p1)) @@ -165,13 +198,21 @@ def score( raw_acc = raw_preds.eq(labels.reshape(-1)).float().mean() return EvalResult( - loss=self.loss(logit0, logit1).item(), + loss=self.loss(logit0, logit1, labels).item(), acc=torch.max(raw_acc, 1 - raw_acc).item(), cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), auroc=max(auroc, 1 - auroc), ) - def train_loop_adam(self, x0, x1, lr: float, num_epochs: int, wd: float) -> float: + def train_loop_adam( + self, + x0, + x1, + labels: Optional[torch.Tensor], + lr: float, + num_epochs: int, + wd: float, + ) -> float: """Adam train loop, returning the final loss. Modifies params in-place.""" optimizer = torch.optim.AdamW( self.parameters(), @@ -183,15 +224,16 @@ def train_loop_adam(self, x0, x1, lr: float, num_epochs: int, wd: float) -> floa for _ in range(num_epochs): optimizer.zero_grad() - logit0, logit1 = self(x0), self(x1) - loss = self.loss(logit0, logit1) + loss = self.loss(self(x0), self(x1), labels) loss.backward() optimizer.step() return float(loss) - def train_loop_lbfgs(self, x0, x1, max_iter: int, l2_penalty: float) -> float: + def train_loop_lbfgs( + self, x0, x1, labels: Optional[torch.Tensor], max_iter: int, l2_penalty: float + ) -> float: """LBFGS train loop, returning the final loss. Modifies params in-place.""" optimizer = torch.optim.LBFGS( @@ -208,8 +250,7 @@ def closure(): nonlocal loss optimizer.zero_grad() - logit0, logit1 = self(x0), self(x1) - loss = self.loss(logit0, logit1) + loss = self.loss(self(x0), self(x1), labels) regularizer = 0.0 # We explicitly add L2 regularization to the loss, since LBFGS @@ -218,7 +259,10 @@ def closure(): regularizer += l2_penalty * param.norm() ** 2 / 2 regularized = loss + regularizer - regularized.backward() + if regularized.isfinite(): + regularized.backward() + else: + print("Got NaN loss; skipping backward pass") return float(regularized) diff --git a/elk/training/losses.py b/elk/training/losses.py index 93bbeb23..5bce6ea1 100644 --- a/elk/training/losses.py +++ b/elk/training/losses.py @@ -1,43 +1,33 @@ -from torch.nn.functional import binary_cross_entropy_with_logits as bce_logits +from torch import Tensor import math import torch -def bernoulli_js( - logit0: torch.Tensor, logit1: torch.Tensor, base: float = 2.0, alpha: float = 1.0 -): - """Jensen-Shannon divergence between Bernoulli distributions. - - Note that by default we use the base 2 logarithm, so the value is measured in bits. - This ensures the divergence is in the range [0, 1]. - """ - # Arithmetic mixture of the two distributions. For numerical stability, we - # do the operation in log space. - log_p_bar = torch.stack([logit0, logit1]).logsumexp(0) - math.log(2) - - H_bar = bce_logits(log_p_bar, log_p_bar, reduction="none") - H0 = bce_logits(logit0, logit0, reduction="none") - H1 = bce_logits(logit1, logit1, reduction="none") - - nats = H_bar - alpha * (H0 + H1) / 2 - return nats / math.log(base) +def H(p: Tensor) -> Tensor: + """Entropy of Bernoulli distribution(s) with success probability `p`.""" + return torch.nn.functional.binary_cross_entropy(p, p) def js_loss( - logit0: torch.Tensor, - logit1: torch.Tensor, - alpha: float = 0.5, + logit0: Tensor, + logit1: Tensor, + confidence: float = 0.0, base: float = 2.0, -) -> torch.Tensor: - """Consistency and confidence loss based on the Jensen-Shannon divergence.""" - return bernoulli_js(logit0, logit1, alpha=alpha, base=base).mean() +) -> Tensor: + """Consistency and confidence loss based on the Jensen-Shannon divergence. + + Note that by default we use the base 2 logarithm, so the value is measured in bits. + This ensures the divergence is in the range [0, 1].""" + p0, neg_p1 = logit0.sigmoid(), 1 - logit1.sigmoid() + nats = (1 + confidence) * H((p0 + neg_p1) / 2) - (H(p0) + H(neg_p1)) / 2 + return nats / math.log(base) -def ccs_squared_loss(logit0: torch.Tensor, logit1: torch.Tensor) -> torch.Tensor: +def ccs_squared_loss(logit0: Tensor, logit1: Tensor) -> Tensor: """CCS loss from original paper, with squared differences between probabilities.""" p0, p1 = logit0.sigmoid(), logit1.sigmoid() - consistency = torch.mean((p0 - (1 - p1)) ** 2, dim=0) - confidence = torch.mean(torch.min(p0, p1) ** 2, dim=0) + consistency = p0.sub(1 - p1).square().mean() + confidence = torch.min(p0, p1).square().mean() return consistency + confidence diff --git a/elk/training/parser.py b/elk/training/parser.py index b62b6210..b86213a8 100644 --- a/elk/training/parser.py +++ b/elk/training/parser.py @@ -27,9 +27,15 @@ def add_train_args(parser: ArgumentParser): "--init", type=str, default="default", - choices=("default", "spherical"), + choices=("default", "spherical", "zero"), help="Initialization for CCS probe.", ) + parser.add_argument( + "--label-frac", + type=float, + default=0.0, + help="Fraction of labeled data to use for training.", + ) parser.add_argument( "--loss", type=str, diff --git a/elk/training/train.py b/elk/training/train.py index 7af2ab81..65457a85 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,20 +1,18 @@ -import csv -import pickle -import random - -import numpy as np -import torch -from sklearn.linear_model import LogisticRegression -from sklearn.metrics import accuracy_score, roc_auc_score -from tqdm.auto import tqdm - from ..files import elk_cache_dir from .ccs import CCS from .parser import get_training_parser from .preprocessing import load_hidden_states, normalize +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import accuracy_score, roc_auc_score +from tqdm.auto import tqdm +import csv +import numpy as np +import pickle +import random +import torch -@torch.autocast("cuda", enabled=torch.cuda.is_available()) +# @torch.autocast("cuda", enabled=torch.cuda.is_available()) def train(args): # Reproducibility np.random.seed(args.seed) @@ -65,25 +63,25 @@ def train(args): ) for train_h, val_h in pbar: - # TODO: Once we implement cross-validation for CCS, we should benchmark against - # LogisticRegressionCV here. - pbar.set_description("Fitting LR") - lr_model = LogisticRegression(max_iter=10000, n_jobs=1, C=0.1) - lr_model.fit(train_h, train_labels) + x0, x1 = train_h.to(args.device).float().chunk(2, dim=-1) + val_x0, val_x1 = val_h.to(args.device).float().chunk(2, dim=-1) - lr_preds = lr_model.predict_proba(val_h)[:, 1] - lr_acc = accuracy_score(val_labels, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels, lr_preds) + train_labels_aug = train_labels + [1 - label for label in train_labels] + val_labels_aug = val_labels + [1 - label for label in val_labels] pbar.set_description("Fitting CCS") - x0, x1 = train_h.to(args.device).chunk(2, dim=-1) - val_x0, val_x1 = val_h.to(args.device).chunk(2, dim=-1) - ccs_model = CCS( in_features=x0.shape[-1], device=args.device, init=args.init, loss=args.loss ) + if args.label_frac: + num_labels = round(args.label_frac * len(train_labels)) + labels = torch.tensor(train_labels[:num_labels], device=args.device) + else: + labels = None + train_loss = ccs_model.fit( - data=(x0, x1), + contrast_pair=(x0, x1), + labels=labels, num_tries=args.num_tries, optimizer=args.optimizer, weight_decay=args.weight_decay, @@ -92,7 +90,21 @@ def train(args): (val_x0, val_x1), torch.tensor(val_labels, device=args.device), ) - pbar.set_postfix(ccs_auroc=val_result.auroc, lr_auroc=lr_auroc) + pbar.set_postfix(train_loss=train_loss, ccs_auroc=val_result.auroc) + + # TODO: Once we implement cross-validation for CCS, we should benchmark against + # LogisticRegressionCV here. + pbar.set_description("Fitting LR") + lr_model = LogisticRegression(max_iter=10_000) + lr_model.fit(torch.cat([x0, x1]).cpu(), train_labels_aug) + + lr_preds = lr_model.predict_proba(torch.cat([val_x0, val_x1]).cpu())[:, 1] + lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) + lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + pbar.set_postfix( + train_loss=train_loss, ccs_auroc=val_result.auroc, lr_auroc=lr_auroc + ) + stats = [train_loss, *val_result, lr_auroc, lr_acc] writer.writerow([L - pbar.n] + [f"{s:.4f}" for s in stats]) diff --git a/elk/utils.py b/elk/utils.py index df1c1789..2299f857 100644 --- a/elk/utils.py +++ b/elk/utils.py @@ -1,4 +1,24 @@ +from torch import Tensor from typing import Callable, Mapping, TypeVar +import torch.distributed as dist + + +def maybe_all_cat(x: Tensor) -> Tensor: + if not dist.is_initialized(): + return x + + buffer = x.new_empty([dist.get_world_size() * x.shape[0], *x.shape[1:]]) + dist.all_gather_into_tensor(buffer, x) + return buffer + + +def maybe_all_gather_lists(lst: list) -> list: + if not dist.is_initialized(): + return lst + + lists = [[] for _ in range(dist.get_world_size())] + dist.all_gather_object(lists, lst) + return sum(lists, []) TreeType = TypeVar("TreeType") From 421996c5935e8251eeabd7249523c7c6b60a834c Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 14 Feb 2023 00:07:23 +0000 Subject: [PATCH 02/15] Full DDP support --- elk/__main__.py | 79 ++++++++++++++++------------- elk/extraction/__init__.py | 2 + elk/extraction/extraction_main.py | 11 ++-- elk/training/__init__.py | 0 elk/training/ccs.py | 40 ++++++++++----- elk/training/parser.py | 6 ++- elk/training/preprocessing.py | 4 +- elk/training/train.py | 84 +++++++++++++++++-------------- elk/utils.py | 21 +++++--- 9 files changed, 151 insertions(+), 96 deletions(-) create mode 100644 elk/training/__init__.py diff --git a/elk/__main__.py b/elk/__main__.py index 8bfc04fb..4f764e0a 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -1,16 +1,15 @@ -from elk.files import args_to_uuid, elk_cache_dir from .extraction.extraction_main import run as run_extraction -from .extraction.parser import ( - add_saveable_args, - add_unsaveable_args, - get_extraction_parser, -) -from .training.parser import add_train_args, get_training_parser +from .extraction.parser import get_extraction_parser +from .training.parser import get_training_parser from .training.train import train from argparse import ArgumentParser +from contextlib import nullcontext, redirect_stdout +from elk.files import args_to_uuid, elk_cache_dir from pathlib import Path from transformers import AutoConfig, PretrainedConfig import json +import os +import torch.distributed as dist def run(): @@ -45,12 +44,6 @@ def run(): ) args = parser.parse_args() - # Default to CUDA iff available - if args.device is None: - import torch - - args.device = "cuda" if torch.cuda.is_available() else "cpu" - if model := getattr(args, "model", None): config_path = Path(__file__).parent / "default_config.json" with open(config_path, "r") as f: @@ -72,29 +65,47 @@ def run(): elif args.layer_stride > 1: args.layers = list(range(0, num_layers, args.layer_stride)) - for key in list(vars(args).keys()): - print("{}: {}".format(key, vars(args)[key])) - - # TODO: Implement the rest of the CLI - if args.command == "extract": - run_extraction(args) - elif args.command == "train": - train(args) - elif args.command == "elicit": - args.name = args_to_uuid(args) - cache_dir = elk_cache_dir() / args.name - if not cache_dir.exists(): + # Support both distributed and non-distributed training + local_rank = os.environ.get("LOCAL_RANK") + if local_rank is not None: + dist.init_process_group("nccl") + local_rank = int(local_rank) + + # Default to CUDA iff available + if args.device is None: + import torch + + if not torch.cuda.is_available(): + args.device = "cpu" + else: + args.device = f"cuda:{local_rank or 0}" + + # Prevent printing from processes other than the first one + with redirect_stdout(None) if local_rank != 0 else nullcontext(): + for key in list(vars(args).keys()): + print("{}: {}".format(key, vars(args)[key])) + + # TODO: Implement the rest of the CLI + if args.command == "extract": run_extraction(args) + elif args.command == "train": + train(args) + elif args.command == "elicit": + args.name = args_to_uuid(args) + cache_dir = elk_cache_dir() / args.name + + if not cache_dir.exists(): + run_extraction(args) + else: + print( + f"Cache dir \033[1m{cache_dir}\033[0m exists, " + "skipping extraction of hidden states" + ) # bold + train(args) + elif args.command == "eval": + raise NotImplementedError else: - print( - f"Cache dir \033[1m{cache_dir}\033[0m exists, " - "skip extraction of hidden states" - ) # bold - train(args) - elif args.command == "eval": - raise NotImplementedError - else: - raise ValueError(f"Unknown command {args.command}") + raise ValueError(f"Unknown command {args.command}") if __name__ == "__main__": diff --git a/elk/extraction/__init__.py b/elk/extraction/__init__.py index e69de29b..e574c053 100644 --- a/elk/extraction/__init__.py +++ b/elk/extraction/__init__.py @@ -0,0 +1,2 @@ +from .extraction import extract_hiddens +from .prompt_collator import PromptCollator diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index 38b8fe0b..20e585a3 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -1,7 +1,7 @@ from .extraction import extract_hiddens, PromptCollator from ..files import args_to_uuid, elk_cache_dir from ..training.preprocessing import silence_datasets_messages -from ..utils import maybe_all_cat, maybe_all_gather_lists +from ..utils import maybe_all_cat from transformers import AutoModel, AutoTokenizer import json import torch @@ -32,7 +32,7 @@ def extract(args, split: str): raise ValueError(f"Unknown prompt strategy: {args.prompts}") items = [ - (features.cpu(), labels) + (features, labels) for features, labels in extract_hiddens( model, tokenizer, @@ -48,10 +48,13 @@ def extract(args, split: str): with open(save_dir / f"{split}_hiddens.pt", "wb") as f: hidden_batches, label_batches = zip(*items) hiddens = maybe_all_cat(torch.cat(hidden_batches)) # type: ignore - labels = maybe_all_gather_lists(sum(label_batches, [])) # type: ignore + + # Moving labels to GPU just to be able to use maybe_all_cat + labels = torch.tensor(sum(label_batches, []), device=hiddens.device) + labels = maybe_all_cat(labels) # type: ignore if rank == 0: - torch.save((hiddens, labels), f) + torch.save((hiddens.cpu(), labels.cpu()), f) # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. diff --git a/elk/training/__init__.py b/elk/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/elk/training/ccs.py b/elk/training/ccs.py index d9f96c9e..d28045af 100644 --- a/elk/training/ccs.py +++ b/elk/training/ccs.py @@ -1,4 +1,5 @@ from .losses import ccs_squared_loss, js_loss +from ..utils import maybe_ddp_wrap, maybe_all_cat from copy import deepcopy from pathlib import Path from sklearn.metrics import roc_auc_score @@ -25,12 +26,13 @@ def __init__( *, activation: Type[nn.Module] = nn.GELU, bias: bool = True, - device: str = "cuda", + device: Optional[str] = None, hidden_size: Optional[int] = None, init: Literal["default", "spherical", "zero"] = "zero", loss: Literal["js", "squared"] = "squared", num_layers: int = 1, pre_ln: bool = False, + supervised_weight: float = 0.0, ): super().__init__() @@ -66,6 +68,7 @@ def __init__( self.init = init self.device = device self.unsupervised_loss = js_loss if loss == "js" else ccs_squared_loss + self.supervised_weight = supervised_weight def reset_parameters(self): # Mathematically equivalent to the unusual initialization scheme used in the @@ -115,15 +118,22 @@ def loss( """Return the loss of the probe on the contrast pair `(x0, x1)`.""" loss = self.unsupervised_loss(logit0, logit1) + # If labels are provided, use them to compute a supervised loss if labels is not None: - # If labels are provided, use them to compute a supervised loss num_labels = len(labels) - label_frac = num_labels / len(logit0) - assert label_frac <= 1.0, "Too many labels provided" + assert num_labels <= len(logit0), "Too many labels provided" p0 = logit0[:num_labels].sigmoid() p1 = logit1[:num_labels].sigmoid() - preds = 0.5 * (p0 + (1 - p1)) - loss += bce(preds.squeeze(-1), labels.type_as(preds)) + + alpha = self.supervised_weight + preds = p0.add(1 - p1).mul(0.5).squeeze(-1) + bce_loss = bce(preds, labels.type_as(preds)) + loss = alpha * bce_loss + (1 - alpha) * loss + + elif self.supervised_weight > 0: + raise ValueError( + "Supervised weight > 0 but no labels provided to compute loss" + ) return loss @@ -188,8 +198,11 @@ def score( p0, p1 = logit0.sigmoid(), logit1.sigmoid() pred_probs = 0.5 * (p0 + (1 - p1)) + pred_probs = maybe_all_cat(pred_probs) + labels = maybe_all_cat(labels) + # Calibrated accuracy - cal_thresh = pred_probs.float().quantile(labels.float().mean()) + cal_thresh = pred_probs.float().quantile(labels.float.mean()) cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(int) raw_preds = pred_probs.gt(0.5).squeeze(1).to(int) @@ -214,8 +227,10 @@ def train_loop_adam( wd: float, ) -> float: """Adam train loop, returning the final loss. Modifies params in-place.""" + + probe = maybe_ddp_wrap(self) optimizer = torch.optim.AdamW( - self.parameters(), + probe.parameters(), lr=lr, weight_decay=wd, ) @@ -224,7 +239,7 @@ def train_loop_adam( for _ in range(num_epochs): optimizer.zero_grad() - loss = self.loss(self(x0), self(x1), labels) + loss = self.loss(probe(x0), probe(x1), labels) loss.backward() optimizer.step() @@ -236,8 +251,9 @@ def train_loop_lbfgs( ) -> float: """LBFGS train loop, returning the final loss. Modifies params in-place.""" + probe = maybe_ddp_wrap(self) optimizer = torch.optim.LBFGS( - self.parameters(), + probe.parameters(), line_search_fn="strong_wolfe", max_iter=max_iter, tolerance_change=torch.finfo(x0.dtype).eps, @@ -250,12 +266,12 @@ def closure(): nonlocal loss optimizer.zero_grad() - loss = self.loss(self(x0), self(x1), labels) + loss = self.loss(probe(x0), probe(x1), labels) regularizer = 0.0 # We explicitly add L2 regularization to the loss, since LBFGS # doesn't have a weight_decay parameter - for param in self.parameters(): + for param in probe.parameters(): regularizer += l2_penalty * param.norm() ** 2 / 2 regularized = loss + regularizer diff --git a/elk/training/parser.py b/elk/training/parser.py index b86213a8..2374cac9 100644 --- a/elk/training/parser.py +++ b/elk/training/parser.py @@ -1,5 +1,4 @@ from argparse import ArgumentParser -from ..extraction.parser import add_saveable_args def get_training_parser(name=True) -> ArgumentParser: @@ -57,6 +56,11 @@ def add_train_args(parser: ArgumentParser): help="Optimizer for CCS. Should be adam or lbfgs.", ) parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--skip-baseline", + action="store_true", + help="Skip training the logistic regression baseline.", + ) parser.add_argument( "--weight-decay", type=float, diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py index fd589d5f..e8fd1423 100644 --- a/elk/training/preprocessing.py +++ b/elk/training/preprocessing.py @@ -37,8 +37,10 @@ def normalize( return train_hiddens, val_hiddens -def load_hidden_states(path: Path): +def load_hidden_states(path: Path) -> tuple[torch.Tensor, torch.Tensor]: hiddens, labels = torch.load(path, map_location="cpu") + assert isinstance(hiddens, torch.Tensor) + assert isinstance(labels, torch.Tensor) # Concatenate the positive and negative examples together. return hiddens.flatten(start_dim=-2), labels diff --git a/elk/training/train.py b/elk/training/train.py index 65457a85..572c1b30 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -10,10 +10,14 @@ import pickle import random import torch +import torch.distributed as dist -# @torch.autocast("cuda", enabled=torch.cuda.is_available()) def train(args): + rank = dist.get_rank() if dist.is_initialized() else 0 + if dist.is_initialized() and not args.skip_baseline and rank == 0: + print("Skipping LR baseline during distributed training.") + # Reproducibility np.random.seed(args.seed) random.seed(args.seed) @@ -30,12 +34,16 @@ def train(args): assert len(set(train_labels)) > 1 assert len(set(val_labels)) > 1 - assert isinstance(val_hiddens, torch.Tensor) - assert isinstance(train_hiddens, torch.Tensor) - train_hiddens, val_hiddens = normalize( train_hiddens, val_hiddens, args.normalization ) + if dist.is_initialized(): + world_size = dist.get_world_size() + train_hiddens = train_hiddens.chunk(world_size)[rank] + train_labels = train_labels.chunk(world_size)[rank] + + val_hiddens = val_hiddens.chunk(world_size)[rank] + val_labels = val_labels.chunk(world_size)[rank] ccs_models = [] lr_models = [] @@ -47,27 +55,24 @@ def train(args): val_layers.reverse() train_layers.reverse() - pbar = tqdm(zip(train_layers, val_layers), total=L, unit="layer") + cols = ["layer", "train_loss", "loss", "acc", "cal_acc", "auroc"] + if not args.skip_baseline: + cols += ["lr_auroc", "lr_acc"] + writer = csv.writer(open(cache_dir / "eval.csv", "w")) - writer.writerow( - [ - "layer", - "train_loss", - "loss", - "acc", - "cal_acc", - "auroc", - "lr_auroc", - "lr_acc", - ] - ) + writer.writerow(cols) + pbar = tqdm(zip(train_layers, val_layers), total=L, unit="layer") for train_h, val_h in pbar: + # Note: currently we're just upcasting to float32 so we don't have to deal with + # grad scaling (which isn't supported for LBFGS), while the hidden states are + # saved in float16 to save disk space. In the future we could try to use mixed + # precision training in at least some cases. x0, x1 = train_h.to(args.device).float().chunk(2, dim=-1) val_x0, val_x1 = val_h.to(args.device).float().chunk(2, dim=-1) - train_labels_aug = train_labels + [1 - label for label in train_labels] - val_labels_aug = val_labels + [1 - label for label in val_labels] + train_labels_aug = torch.cat([train_labels, 1 - train_labels]) + val_labels_aug = torch.cat([val_labels, 1 - val_labels]) pbar.set_description("Fitting CCS") ccs_model = CCS( @@ -91,34 +96,37 @@ def train(args): torch.tensor(val_labels, device=args.device), ) pbar.set_postfix(train_loss=train_loss, ccs_auroc=val_result.auroc) + stats = [train_loss, *val_result] + + if not args.skip_baseline and not dist.is_initialized(): + # TODO: Once we implement cross-validation for CCS, we should benchmark + # against LogisticRegressionCV here. + pbar.set_description("Fitting LR") + lr_model = LogisticRegression(max_iter=10_000) + lr_model.fit(torch.cat([x0, x1]).cpu(), train_labels_aug) + + lr_preds = lr_model.predict_proba(torch.cat([val_x0, val_x1]).cpu())[:, 1] + lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) + lr_auroc = roc_auc_score(val_labels_aug, lr_preds) + pbar.set_postfix( + train_loss=train_loss, ccs_auroc=val_result.auroc, lr_auroc=lr_auroc + ) + lr_models.append(lr_model) + stats += [lr_auroc, lr_acc] - # TODO: Once we implement cross-validation for CCS, we should benchmark against - # LogisticRegressionCV here. - pbar.set_description("Fitting LR") - lr_model = LogisticRegression(max_iter=10_000) - lr_model.fit(torch.cat([x0, x1]).cpu(), train_labels_aug) - - lr_preds = lr_model.predict_proba(torch.cat([val_x0, val_x1]).cpu())[:, 1] - lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) - lr_auroc = roc_auc_score(val_labels_aug, lr_preds) - pbar.set_postfix( - train_loss=train_loss, ccs_auroc=val_result.auroc, lr_auroc=lr_auroc - ) - - stats = [train_loss, *val_result, lr_auroc, lr_acc] writer.writerow([L - pbar.n] + [f"{s:.4f}" for s in stats]) - - lr_models.append(lr_model) ccs_models.append(ccs_model) ccs_models.reverse() lr_models.reverse() path = elk_cache_dir() / args.name - with open(path / "lr_models.pkl", "wb") as file: - pickle.dump(lr_models, file) + if rank == 0: + torch.save(ccs_models, path / "ccs_models.pt") - torch.save(ccs_models, path / "ccs_models.pt") + if lr_models and rank == 0: + with open(path / "lr_models.pkl", "wb") as file: + pickle.dump(lr_models, file) if __name__ == "__main__": diff --git a/elk/utils.py b/elk/utils.py index 2299f857..9b2070eb 100644 --- a/elk/utils.py +++ b/elk/utils.py @@ -1,6 +1,8 @@ from torch import Tensor -from typing import Callable, Mapping, TypeVar +from torch.nn.parallel import DistributedDataParallel as DDP +from typing import cast, Callable, Mapping, TypeVar import torch.distributed as dist +import torch.nn as nn def maybe_all_cat(x: Tensor) -> Tensor: @@ -12,13 +14,20 @@ def maybe_all_cat(x: Tensor) -> Tensor: return buffer -def maybe_all_gather_lists(lst: list) -> list: +def maybe_all_reduce(x: Tensor) -> Tensor: if not dist.is_initialized(): - return lst + return x + + dist.all_reduce(x, op=dist.ReduceOp.SUM) + x /= dist.get_world_size() + return x + + +def maybe_ddp_wrap(model: nn.Module) -> nn.Module: + if not dist.is_initialized(): + return model - lists = [[] for _ in range(dist.get_world_size())] - dist.all_gather_object(lists, lst) - return sum(lists, []) + return DDP(model, device_ids=[dist.get_rank()]) TreeType = TypeVar("TreeType") From 2cadb28b33fb12b7f2f012f06b22faba692d1d06 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 14 Feb 2023 06:29:56 +0000 Subject: [PATCH 03/15] Bug fixes --- .vscode/launch.json | 68 ----------------------------------------- elk/__main__.py | 32 ++++++++----------- elk/default_config.json | 42 ------------------------- elk/training/ccs.py | 39 +++++++++++------------ elk/training/train.py | 57 +++++++++++++++++++--------------- elk/utils.py | 2 +- 6 files changed, 62 insertions(+), 178 deletions(-) delete mode 100644 .vscode/launch.json delete mode 100644 elk/default_config.json diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index f1b576e8..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Current File", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "justMyCode": true, - "args" : [ - "--model", "deberta-v2-xxlarge-mnli", - "--dataset","imdb", - "--prefix", "confusion", - "--device", "cuda", - "--num-data", "1000" - ] - }, - { - "name": "Python: Generation Cuda", - "type": "python", - "request": "launch", - "module": "elk.extraction_main", - "console": "integratedTerminal", - "justMyCode": true, - "args" : [ - "--model", "deberta-v2-xxlarge-mnli", - "--dataset","imdb", - "--prefix", "normal", - "--device", "cuda", - "--num-data", "1000" - ] - }, - { - "name": "Python: Evaluation Cuda", - "type": "python", - "request": "launch", - "module": "elk.evaluate", - "console": "integratedTerminal", - "justMyCode": true, - "args" : [ - "--model", "deberta-v2-xxlarge-mnli", - "--dataset","imdb", - "--prefix", "normal", - "--device", "cuda", - "--num-data", "1000" - ] - }, - { - "name": "Python: Training Cuda", - "type": "python", - "request": "launch", - "module": "elk.train.py", - "console": "integratedTerminal", - "justMyCode": true, - "args" : [ - "--model", "deberta-v2-xxlarge-mnli", - "--dataset","imdb", - "--prefix", "normal", - "--device", "cuda", - "--num-data", "1000" - ] - } - ] -} diff --git a/elk/__main__.py b/elk/__main__.py index 4f764e0a..2d9dd5c1 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -4,10 +4,8 @@ from .training.train import train from argparse import ArgumentParser from contextlib import nullcontext, redirect_stdout -from elk.files import args_to_uuid, elk_cache_dir -from pathlib import Path +from elk.files import args_to_uuid from transformers import AutoConfig, PretrainedConfig -import json import os import torch.distributed as dist @@ -45,14 +43,7 @@ def run(): args = parser.parse_args() if model := getattr(args, "model", None): - config_path = Path(__file__).parent / "default_config.json" - with open(config_path, "r") as f: - default_config = json.load(f) - model_shortcuts = default_config["model_shortcuts"] - - # Dereference shortcut - args.model = model_shortcuts.get(model, model) - config = AutoConfig.from_pretrained(args.model) + config = AutoConfig.from_pretrained(model) assert isinstance(config, PretrainedConfig) num_layers = getattr(config, "num_layers", config.num_hidden_layers) @@ -92,16 +83,17 @@ def run(): train(args) elif args.command == "elicit": args.name = args_to_uuid(args) - cache_dir = elk_cache_dir() / args.name - - if not cache_dir.exists(): + try: + train(args) + except (EOFError, FileNotFoundError): run_extraction(args) - else: - print( - f"Cache dir \033[1m{cache_dir}\033[0m exists, " - "skipping extraction of hidden states" - ) # bold - train(args) + + # Ensure the extraction is finished before starting training + if dist.is_initialized(): + dist.barrier() + + train(args) + elif args.command == "eval": raise NotImplementedError else: diff --git a/elk/default_config.json b/elk/default_config.json deleted file mode 100644 index f4b4a3a7..00000000 --- a/elk/default_config.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "datasets": [ - "imdb", - "amazon-polarity", - "ag-news", - "dbpedia-14", - "copa", - "rte", - "boolq", - "qnli", - "piqa", - "story-cloze" - ], - "model_shortcuts" : { - "t5-11b": "google/t5-11b", - "unifiedqa-t5-11b": "allenai/unifiedqa-t5-11b", - "T0pp": "bigscience/T0pp", - "gpt-j-6B": "EleutherAI/gpt-j-6B", - "deberta-v2-xxlarge-mnli": "microsoft/deberta-v2-xxlarge-mnli" - }, - "prefix": [ - "normal", - "confusion", - "confusion2", - "confusion3", - "confusion4", - "confusion6", - "confusion7", - "iamincorrect", - "dadnotrust", - "dadisincorrect", - "teachernoimitate" - ], - "models_layer_num": { - "t5-11b": 25, - "unifiedqa-t5-11b": 25, - "T0pp": 25, - "gpt-j-6B": 29, - "roberta-large-mnli": 25, - "deberta-v2-xxlarge-mnli": 49 - } -} diff --git a/elk/training/ccs.py b/elk/training/ccs.py index d28045af..35a629c1 100644 --- a/elk/training/ccs.py +++ b/elk/training/ccs.py @@ -1,5 +1,5 @@ from .losses import ccs_squared_loss, js_loss -from ..utils import maybe_ddp_wrap, maybe_all_cat +from ..utils import maybe_ddp_wrap, maybe_all_cat, maybe_all_reduce from copy import deepcopy from pathlib import Path from sklearn.metrics import roc_auc_score @@ -202,18 +202,18 @@ def score( labels = maybe_all_cat(labels) # Calibrated accuracy - cal_thresh = pred_probs.float().quantile(labels.float.mean()) - cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(int) - raw_preds = pred_probs.gt(0.5).squeeze(1).to(int) + cal_thresh = pred_probs.float().quantile(labels.float().mean()) + cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) + raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) - auroc = float(roc_auc_score(labels.tolist(), pred_probs.tolist())) + auroc = 0.0 # float(roc_auc_score(labels.cpu(), pred_probs.cpu())) cal_acc = cal_preds.eq(labels.reshape(-1)).float().mean() raw_acc = raw_preds.eq(labels.reshape(-1)).float().mean() return EvalResult( - loss=self.loss(logit0, logit1, labels).item(), - acc=torch.max(raw_acc, 1 - raw_acc).item(), - cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), + loss=self.loss(logit0, logit1), + acc=torch.max(raw_acc, 1 - raw_acc), + cal_acc=torch.max(cal_acc, 1 - cal_acc), auroc=max(auroc, 1 - auroc), ) @@ -229,18 +229,13 @@ def train_loop_adam( """Adam train loop, returning the final loss. Modifies params in-place.""" probe = maybe_ddp_wrap(self) - optimizer = torch.optim.AdamW( - probe.parameters(), - lr=lr, - weight_decay=wd, - ) + optimizer = torch.optim.AdamW(probe.parameters(), lr=lr, weight_decay=wd) loss = torch.inf for _ in range(num_epochs): optimizer.zero_grad() loss = self.loss(probe(x0), probe(x1), labels) - loss.backward() optimizer.step() @@ -251,9 +246,8 @@ def train_loop_lbfgs( ) -> float: """LBFGS train loop, returning the final loss. Modifies params in-place.""" - probe = maybe_ddp_wrap(self) optimizer = torch.optim.LBFGS( - probe.parameters(), + self.parameters(), line_search_fn="strong_wolfe", max_iter=max_iter, tolerance_change=torch.finfo(x0.dtype).eps, @@ -266,19 +260,20 @@ def closure(): nonlocal loss optimizer.zero_grad() - loss = self.loss(probe(x0), probe(x1), labels) + loss = self.loss(self(x0), self(x1), labels) regularizer = 0.0 # We explicitly add L2 regularization to the loss, since LBFGS # doesn't have a weight_decay parameter - for param in probe.parameters(): + for param in self.parameters(): regularizer += l2_penalty * param.norm() ** 2 / 2 regularized = loss + regularizer - if regularized.isfinite(): - regularized.backward() - else: - print("Got NaN loss; skipping backward pass") + regularized.backward() + + for p in self.parameters(): + if p.grad is not None: + maybe_all_reduce(p.grad) return float(regularized) diff --git a/elk/training/train.py b/elk/training/train.py index 572c1b30..c5f5341a 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -1,6 +1,5 @@ from ..files import elk_cache_dir from .ccs import CCS -from .parser import get_training_parser from .preprocessing import load_hidden_states, normalize from sklearn.linear_model import LogisticRegression from sklearn.metrics import accuracy_score, roc_auc_score @@ -55,15 +54,14 @@ def train(args): val_layers.reverse() train_layers.reverse() - cols = ["layer", "train_loss", "loss", "acc", "cal_acc", "auroc"] - if not args.skip_baseline: - cols += ["lr_auroc", "lr_acc"] - - writer = csv.writer(open(cache_dir / "eval.csv", "w")) - writer.writerow(cols) + iterator = zip(train_layers, val_layers) + pbar = None + if rank == 0: + pbar = tqdm(iterator, total=L, unit="layer") + iterator = pbar - pbar = tqdm(zip(train_layers, val_layers), total=L, unit="layer") - for train_h, val_h in pbar: + statistics = [] + for i, (train_h, val_h) in enumerate(iterator): # Note: currently we're just upcasting to float32 so we don't have to deal with # grad scaling (which isn't supported for LBFGS), while the hidden states are # saved in float16 to save disk space. In the future we could try to use mixed @@ -74,13 +72,14 @@ def train(args): train_labels_aug = torch.cat([train_labels, 1 - train_labels]) val_labels_aug = torch.cat([val_labels, 1 - val_labels]) - pbar.set_description("Fitting CCS") + if pbar: + pbar.set_description("Fitting CCS") ccs_model = CCS( in_features=x0.shape[-1], device=args.device, init=args.init, loss=args.loss ) if args.label_frac: num_labels = round(args.label_frac * len(train_labels)) - labels = torch.tensor(train_labels[:num_labels], device=args.device) + labels = train_labels[:num_labels].to(args.device) else: labels = None @@ -93,28 +92,31 @@ def train(args): ) val_result = ccs_model.score( (val_x0, val_x1), - torch.tensor(val_labels, device=args.device), + val_labels.to(args.device), ) - pbar.set_postfix(train_loss=train_loss, ccs_auroc=val_result.auroc) + if pbar: + pbar.set_postfix(train_loss=train_loss, ccs_auroc=val_result.auroc) stats = [train_loss, *val_result] if not args.skip_baseline and not dist.is_initialized(): # TODO: Once we implement cross-validation for CCS, we should benchmark # against LogisticRegressionCV here. - pbar.set_description("Fitting LR") + if pbar: + pbar.set_description("Fitting LR") lr_model = LogisticRegression(max_iter=10_000) lr_model.fit(torch.cat([x0, x1]).cpu(), train_labels_aug) lr_preds = lr_model.predict_proba(torch.cat([val_x0, val_x1]).cpu())[:, 1] lr_acc = accuracy_score(val_labels_aug, lr_preds > 0.5) lr_auroc = roc_auc_score(val_labels_aug, lr_preds) - pbar.set_postfix( - train_loss=train_loss, ccs_auroc=val_result.auroc, lr_auroc=lr_auroc - ) + if pbar: + pbar.set_postfix( + train_loss=train_loss, ccs_auroc=val_result.auroc, lr_auroc=lr_auroc + ) lr_models.append(lr_model) stats += [lr_auroc, lr_acc] - writer.writerow([L - pbar.n] + [f"{s:.4f}" for s in stats]) + statistics.append(stats) ccs_models.append(ccs_model) ccs_models.reverse() @@ -122,13 +124,18 @@ def train(args): path = elk_cache_dir() / args.name if rank == 0: - torch.save(ccs_models, path / "ccs_models.pt") + cols = ["layer", "train_loss", "loss", "acc", "cal_acc", "auroc"] + if not args.skip_baseline: + cols += ["lr_auroc", "lr_acc"] - if lr_models and rank == 0: - with open(path / "lr_models.pkl", "wb") as file: - pickle.dump(lr_models, file) + with open(cache_dir / "eval.csv", "w") as f: + writer = csv.writer(f) + writer.writerow(cols) + for i, stats in enumerate(statistics): + writer.writerow([L - i] + [f"{s:.4f}" for s in stats]) -if __name__ == "__main__": - args = get_training_parser().parse_args() - train(args) + torch.save(ccs_models, path / "ccs_models.pt") + if lr_models: + with open(path / "lr_models.pkl", "wb") as file: + pickle.dump(lr_models, file) diff --git a/elk/utils.py b/elk/utils.py index 9b2070eb..23afc067 100644 --- a/elk/utils.py +++ b/elk/utils.py @@ -1,6 +1,6 @@ from torch import Tensor from torch.nn.parallel import DistributedDataParallel as DDP -from typing import cast, Callable, Mapping, TypeVar +from typing import Callable, Mapping, TypeVar import torch.distributed as dist import torch.nn as nn From b3fa294a09508777270ebbaacd8759bf0ea85c1a Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 13:16:44 +0000 Subject: [PATCH 04/15] Reverse layer stride --- elk/__main__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/elk/__main__.py b/elk/__main__.py index 8bfc04fb..814ecf41 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -70,7 +70,9 @@ def run(): "Cannot use both --layers and --layer-stride. Please use only one." ) elif args.layer_stride > 1: - args.layers = list(range(0, num_layers, args.layer_stride)) + # the last layer is often the most interesting + # layers = [..., num_layers - 1 - layer_stride, num_layers - 1] + args.layers = list(range(num_layers - 1, -1, -args.layer_stride)).reverse() for key in list(vars(args).keys()): print("{}: {}".format(key, vars(args)[key])) From 5eb2811fd76dd143b297496be932f2be21a79239 Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 13:31:07 +0000 Subject: [PATCH 05/15] Save one file per layer --- elk/extraction/extraction_main.py | 24 +++++++++++++++++++----- elk/training/preprocessing.py | 9 +++++++-- elk/training/train.py | 4 ++-- 3 files changed, 28 insertions(+), 9 deletions(-) diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index f85bc391..527e1944 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -1,3 +1,4 @@ +from pathlib import Path from .extraction import extract_hiddens, PromptCollator from ..files import args_to_uuid, elk_cache_dir from ..training.preprocessing import silence_datasets_messages @@ -42,11 +43,16 @@ def extract(args, split: str): ] save_dir.mkdir(parents=True, exist_ok=True) - with open(save_dir / f"{split}_hiddens.pt", "wb") as f: - hidden_batches, label_batches = zip(*items) - hiddens = torch.cat(hidden_batches) # type: ignore - labels = sum(label_batches, []) - torch.save((hiddens, labels), f) + hidden_batches, label_batches = zip(*items) + hiddens = torch.cat(hidden_batches) # type: ignore + labels = sum(label_batches, []) + + for layer in args.layers: + hiddens_at_l = hiddens[:, layer, :, :] + with open(get_hiddens_path(save_dir, split, layer), "wb") as f: + torch.save(hiddens_at_l, f) + with open(get_labels_path(save_dir, split), "wb") as f: + torch.save(labels, f) # AutoModel should do the right thing here in nearly all cases. We don't actually # care what head the model has, since we are just extracting hidden states. @@ -80,3 +86,11 @@ def extract(args, split: str): with open(save_dir / "model_config.json", "w") as f: json.dump(model.config.to_dict(), f) + + +def get_hiddens_path(dir: Path, split: str, layer: int): + return dir / f"{split}_hiddens_l{layer}.pt" + + +def get_labels_path(dir: Path, split: str): + return dir / f"{split}_labels.pt" diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py index fd589d5f..2e644f1d 100644 --- a/elk/training/preprocessing.py +++ b/elk/training/preprocessing.py @@ -3,6 +3,8 @@ from typing import Literal import torch +from elk.extraction.extraction_main import get_hiddens_path, get_labels_path + def normalize( train_hiddens: torch.Tensor, @@ -37,8 +39,11 @@ def normalize( return train_hiddens, val_hiddens -def load_hidden_states(path: Path): - hiddens, labels = torch.load(path, map_location="cpu") +def load_hidden_states(dir: Path, split: str, layers: list[int]): + labels = torch.load(get_labels_path(dir, split)) + hiddens_list = [torch.load(get_hiddens_path(dir, split, layer)) for layer in layers] + + hiddens = torch.stack(hiddens_list, dim=1) # Concatenate the positive and negative examples together. return hiddens.flatten(start_dim=-2), labels diff --git a/elk/training/train.py b/elk/training/train.py index 7af2ab81..6f8ba1d1 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -24,10 +24,10 @@ def train(args): # load the hidden states extracted from the model cache_dir = elk_cache_dir() / args.name train_hiddens, train_labels = load_hidden_states( - path=cache_dir / "train_hiddens.pt" + dir=cache_dir, split="train", layers=args.layers ) val_hiddens, val_labels = load_hidden_states( - path=cache_dir / "validation_hiddens.pt" + dir=cache_dir, split="validation", layers=args.layers ) assert len(set(train_labels)) > 1 assert len(set(val_labels)) > 1 From 962fd829dbb8fde503e6f2400e37371bb4a4a7b2 Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 13:38:33 +0000 Subject: [PATCH 06/15] Add automatic layer extraction to main --- elk/__main__.py | 72 ++++++++++++++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/elk/__main__.py b/elk/__main__.py index 814ecf41..af6c576e 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -1,5 +1,5 @@ from elk.files import args_to_uuid, elk_cache_dir -from .extraction.extraction_main import run as run_extraction +from .extraction.extraction_main import get_hiddens_path, run as run_extraction from .extraction.parser import ( add_saveable_args, add_unsaveable_args, @@ -45,6 +45,47 @@ def run(): ) args = parser.parse_args() + normalize_args_inplace(args) + + for key in list(vars(args).keys()): + print("{}: {}".format(key, vars(args)[key])) + + # TODO: Implement the rest of the CLI + if args.command == "extract": + run_extraction(args) + elif args.command == "train": + train(args) + elif args.command == "elicit": + # Extract the hidden states if they're not already there + args.name = args_to_uuid(args) + cache_dir = elk_cache_dir() / args.name + missing_layers = find_missing_layers(args) + if missing_layers: + if cache_dir.exists(): + print( + f"Found cache dir \033[1m{cache_dir}\033[0m" + f" but it's missing layers {', '.join(missing_layers)}" + ) + + old_layers = args.layers + args.layers = missing_layers + run_extraction(args) + args.layers = old_layers + else: + print( + f"Cache dir \033[1m{cache_dir}\033[0m exists, " + "skip extraction of hidden states" + ) # bold + + # Train the probes + train(args) + elif args.command == "eval": + raise NotImplementedError + else: + raise ValueError(f"Unknown command {args.command}") + + +def normalize_args_inplace(args): # Default to CUDA iff available if args.device is None: import torch @@ -74,29 +115,16 @@ def run(): # layers = [..., num_layers - 1 - layer_stride, num_layers - 1] args.layers = list(range(num_layers - 1, -1, -args.layer_stride)).reverse() - for key in list(vars(args).keys()): - print("{}: {}".format(key, vars(args)[key])) - # TODO: Implement the rest of the CLI - if args.command == "extract": - run_extraction(args) - elif args.command == "train": - train(args) - elif args.command == "elicit": - args.name = args_to_uuid(args) +def find_missing_layers(args): + missing_layers = [] + for layer in args.layers: cache_dir = elk_cache_dir() / args.name - if not cache_dir.exists(): - run_extraction(args) - else: - print( - f"Cache dir \033[1m{cache_dir}\033[0m exists, " - "skip extraction of hidden states" - ) # bold - train(args) - elif args.command == "eval": - raise NotImplementedError - else: - raise ValueError(f"Unknown command {args.command}") + train_layer_path = get_hiddens_path(cache_dir, "train", layer) + validation_layer_path = get_hiddens_path(cache_dir, "validation", layer) + if not train_layer_path.exists() or not validation_layer_path.exists(): + missing_layers.append(layer) + return missing_layers if __name__ == "__main__": From f5c191bc427d083c846ba1bdb8b598dc57585211 Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 13:54:32 +0000 Subject: [PATCH 07/15] Fix bugs --- elk/__main__.py | 18 ++++++++++++++++-- elk/extraction/extraction_main.py | 10 +--------- elk/files.py | 8 ++++++++ elk/training/preprocessing.py | 2 +- 4 files changed, 26 insertions(+), 12 deletions(-) diff --git a/elk/__main__.py b/elk/__main__.py index af6c576e..ccb5c347 100644 --- a/elk/__main__.py +++ b/elk/__main__.py @@ -1,5 +1,6 @@ -from elk.files import args_to_uuid, elk_cache_dir -from .extraction.extraction_main import get_hiddens_path, run as run_extraction +from typing import Optional +from elk.files import args_to_uuid, elk_cache_dir, get_hiddens_path +from .extraction.extraction_main import run as run_extraction from .extraction.parser import ( add_saveable_args, add_unsaveable_args, @@ -114,6 +115,19 @@ def normalize_args_inplace(args): # the last layer is often the most interesting # layers = [..., num_layers - 1 - layer_stride, num_layers - 1] args.layers = list(range(num_layers - 1, -1, -args.layer_stride)).reverse() + else: + assert ( + args.name is not None + ) # If the model is not provided, it means we are using the name + config = json.load(open(elk_cache_dir() / args.name / "model_config.json", "r")) + num_layers = config.get("num_layers", config.get("num_hidden_layers")) + + args.layers = normalized_layers(args.layers, num_layers) + + +def normalized_layers(layers: Optional[list[int]], num_layers: int) -> list[int]: + layers = layers or list(range(num_layers)) + return [layer if layer >= 0 else num_layers + layer for layer in layers] def find_missing_layers(args): diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index 527e1944..e8d5d1ca 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -1,6 +1,6 @@ from pathlib import Path from .extraction import extract_hiddens, PromptCollator -from ..files import args_to_uuid, elk_cache_dir +from ..files import args_to_uuid, elk_cache_dir, get_hiddens_path, get_labels_path from ..training.preprocessing import silence_datasets_messages from transformers import AutoModel, AutoTokenizer import json @@ -86,11 +86,3 @@ def extract(args, split: str): with open(save_dir / "model_config.json", "w") as f: json.dump(model.config.to_dict(), f) - - -def get_hiddens_path(dir: Path, split: str, layer: int): - return dir / f"{split}_hiddens_l{layer}.pt" - - -def get_labels_path(dir: Path, split: str): - return dir / f"{split}_labels.pt" diff --git a/elk/files.py b/elk/files.py index 0ac1a3fc..82fb7b96 100644 --- a/elk/files.py +++ b/elk/files.py @@ -24,3 +24,11 @@ def elk_cache_dir() -> Path: cache_dir.mkdir(parents=True, exist_ok=True) return cache_dir + + +def get_hiddens_path(dir: Path, split: str, layer: int): + return dir / f"{split}_hiddens_l{layer}.pt" + + +def get_labels_path(dir: Path, split: str): + return dir / f"{split}_labels.pt" diff --git a/elk/training/preprocessing.py b/elk/training/preprocessing.py index 2e644f1d..cf9a96dc 100644 --- a/elk/training/preprocessing.py +++ b/elk/training/preprocessing.py @@ -3,7 +3,7 @@ from typing import Literal import torch -from elk.extraction.extraction_main import get_hiddens_path, get_labels_path +from elk.files import get_hiddens_path, get_labels_path def normalize( From bd34f962e5c8c56dabd49a9a67472ea8a497571e Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 15:52:18 +0000 Subject: [PATCH 08/15] Add .item() where necessary --- elk/training/ccs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/elk/training/ccs.py b/elk/training/ccs.py index 35a629c1..90ee5062 100644 --- a/elk/training/ccs.py +++ b/elk/training/ccs.py @@ -211,9 +211,9 @@ def score( raw_acc = raw_preds.eq(labels.reshape(-1)).float().mean() return EvalResult( - loss=self.loss(logit0, logit1), - acc=torch.max(raw_acc, 1 - raw_acc), - cal_acc=torch.max(cal_acc, 1 - cal_acc), + loss=self.loss(logit0, logit1).item(), + acc=torch.max(raw_acc, 1 - raw_acc).item(), + cal_acc=torch.max(cal_acc, 1 - cal_acc).item(), auroc=max(auroc, 1 - auroc), ) From 522b95acd360dc7345b4a7f1837427f343d5fb6e Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 15:53:40 +0000 Subject: [PATCH 09/15] Use maybe all gather instead of cat --- elk/extraction/extraction_main.py | 8 ++++---- elk/training/ccs.py | 6 +++--- elk/utils.py | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/elk/extraction/extraction_main.py b/elk/extraction/extraction_main.py index 20e585a3..4ed9af79 100644 --- a/elk/extraction/extraction_main.py +++ b/elk/extraction/extraction_main.py @@ -1,7 +1,7 @@ from .extraction import extract_hiddens, PromptCollator from ..files import args_to_uuid, elk_cache_dir from ..training.preprocessing import silence_datasets_messages -from ..utils import maybe_all_cat +from ..utils import maybe_all_gather from transformers import AutoModel, AutoTokenizer import json import torch @@ -47,11 +47,11 @@ def extract(args, split: str): with open(save_dir / f"{split}_hiddens.pt", "wb") as f: hidden_batches, label_batches = zip(*items) - hiddens = maybe_all_cat(torch.cat(hidden_batches)) # type: ignore + hiddens = maybe_all_gather(torch.cat(hidden_batches)) # type: ignore - # Moving labels to GPU just to be able to use maybe_all_cat + # Moving labels to GPU just to be able to use maybe_all_gather labels = torch.tensor(sum(label_batches, []), device=hiddens.device) - labels = maybe_all_cat(labels) # type: ignore + labels = maybe_all_gather(labels) # type: ignore if rank == 0: torch.save((hiddens.cpu(), labels.cpu()), f) diff --git a/elk/training/ccs.py b/elk/training/ccs.py index 90ee5062..a7750eef 100644 --- a/elk/training/ccs.py +++ b/elk/training/ccs.py @@ -1,5 +1,5 @@ from .losses import ccs_squared_loss, js_loss -from ..utils import maybe_ddp_wrap, maybe_all_cat, maybe_all_reduce +from ..utils import maybe_ddp_wrap, maybe_all_gather, maybe_all_reduce from copy import deepcopy from pathlib import Path from sklearn.metrics import roc_auc_score @@ -198,8 +198,8 @@ def score( p0, p1 = logit0.sigmoid(), logit1.sigmoid() pred_probs = 0.5 * (p0 + (1 - p1)) - pred_probs = maybe_all_cat(pred_probs) - labels = maybe_all_cat(labels) + pred_probs = maybe_all_gather(pred_probs) + labels = maybe_all_gather(labels) # Calibrated accuracy cal_thresh = pred_probs.float().quantile(labels.float().mean()) diff --git a/elk/utils.py b/elk/utils.py index 23afc067..29d96ed3 100644 --- a/elk/utils.py +++ b/elk/utils.py @@ -5,7 +5,7 @@ import torch.nn as nn -def maybe_all_cat(x: Tensor) -> Tensor: +def maybe_all_gather(x: Tensor) -> Tensor: if not dist.is_initialized(): return x From 7f1b5af20034a79a5a763e73eef9189898141b99 Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 15:55:04 +0000 Subject: [PATCH 10/15] Shuffle before sharding --- elk/extraction/prompt_collator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 0439dc9d..3735e944 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -53,10 +53,11 @@ def __init__( raise ValueError(f"Dataset {path}/{name} has only one label") if max_examples: self.dataset = self.dataset.select(range(max_examples)) + + self.dataset = self.dataset.shuffle(seed=seed) if dist.is_initialized(): self.dataset = self.dataset.shard(dist.get_world_size(), dist.get_rank()) - self.dataset = self.dataset.shuffle(seed=seed) self.label_column = label_column self.prompter = DatasetTemplates(path, subset_name=name) # type: ignore self.rng = Random(seed) From 73f4fd56d71d93d2b94f4f72bac38062e92e95c1 Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 15:59:36 +0000 Subject: [PATCH 11/15] Fix typing problems --- elk/extraction/extraction.py | 2 +- elk/extraction/prompt_collator.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/elk/extraction/extraction.py b/elk/extraction/extraction.py index f3a3eaab..2309637a 100644 --- a/elk/extraction/extraction.py +++ b/elk/extraction/extraction.py @@ -9,7 +9,7 @@ import torch.distributed as dist -@torch.autocast("cuda", enabled=torch.cuda.is_available()) +@torch.autocast("cuda", enabled=torch.cuda.is_available()) # type: ignore @torch.no_grad() def extract_hiddens( model: PreTrainedModel, diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 3735e944..586faa78 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -1,9 +1,10 @@ from dataclasses import dataclass -from datasets import DatasetDict, load_dataset +from datasets import DatasetDict, load_dataset # type: ignore from promptsource.templates import DatasetTemplates from random import Random from typing import Literal, Optional import numpy as np +from torch.utils.data import Dataset import torch.distributed as dist @@ -19,7 +20,7 @@ def to_string(self, answer_idx: int, sep: str = "\n") -> str: return f"{self.question}{sep}{self.answers[answer_idx]}" -class PromptCollator: +class PromptCollator(Dataset): def __init__( self, path: str, From 21272cf91aa8b0373a2bf8ffdb6d67a146613300 Mon Sep 17 00:00:00 2001 From: Nora Belrose Date: Tue, 14 Feb 2023 16:10:10 +0000 Subject: [PATCH 12/15] Revert Fabien's change to shuffle/shard ordering --- elk/extraction/prompt_collator.py | 3 +-- elk/training/ccs.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 586faa78..58c96a86 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -54,11 +54,10 @@ def __init__( raise ValueError(f"Dataset {path}/{name} has only one label") if max_examples: self.dataset = self.dataset.select(range(max_examples)) - - self.dataset = self.dataset.shuffle(seed=seed) if dist.is_initialized(): self.dataset = self.dataset.shard(dist.get_world_size(), dist.get_rank()) + self.dataset = self.dataset.shuffle(seed=seed) self.label_column = label_column self.prompter = DatasetTemplates(path, subset_name=name) # type: ignore self.rng = Random(seed) diff --git a/elk/training/ccs.py b/elk/training/ccs.py index a7750eef..6ab50b61 100644 --- a/elk/training/ccs.py +++ b/elk/training/ccs.py @@ -206,7 +206,7 @@ def score( cal_preds = pred_probs.gt(cal_thresh).squeeze(1).to(torch.int) raw_preds = pred_probs.gt(0.5).squeeze(1).to(torch.int) - auroc = 0.0 # float(roc_auc_score(labels.cpu(), pred_probs.cpu())) + auroc = float(roc_auc_score(labels.cpu(), pred_probs.cpu())) cal_acc = cal_preds.eq(labels.reshape(-1)).float().mean() raw_acc = raw_preds.eq(labels.reshape(-1)).float().mean() From b1e30fbff55fbd312c59292e565fd698a8e9660f Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 16:25:32 +0000 Subject: [PATCH 13/15] Fix problems --- elk/extraction/prompt_collator.py | 15 +++++++++++---- elk/training/train.py | 7 ++++--- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 7002dd96..9eb58238 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -1,8 +1,8 @@ from dataclasses import dataclass -from datasets import DatasetDict, load_dataset # type: ignore +from datasets import DatasetDict, load_dataset, Dataset as HFDataset # type: ignore from promptsource.templates import DatasetTemplates from random import Random -from typing import Literal, Optional +from typing import Literal, Optional, cast import numpy as np from torch.utils.data import Dataset import torch.distributed as dist @@ -50,10 +50,17 @@ def __init__( print("No validation split found, using test split instead") split = "test" - self.dataset = data[split] + self.dataset: HFDataset = data[split] if balance: - self.dataset = undersample(self.dataset, seed, label_column) + self.dataset = cast( + HFDataset, + undersample( + self.dataset, + seed, + label_column, + ), + ) # type: ignore self.labels, counts = np.unique(self.dataset[label_column], return_counts=True) self.label_fracs = counts / counts.sum() diff --git a/elk/training/train.py b/elk/training/train.py index c88ec1d5..7a76a7a4 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -69,9 +69,6 @@ def train(args): x0, x1 = train_h.to(args.device).float().chunk(2, dim=-1) val_x0, val_x1 = val_h.to(args.device).float().chunk(2, dim=-1) - train_labels_aug = torch.cat([train_labels, 1 - train_labels]) - val_labels_aug = torch.cat([val_labels, 1 - val_labels]) - if pbar: pbar.set_description("Fitting CCS") ccs_model = CCS( @@ -101,6 +98,10 @@ def train(args): if not args.skip_baseline and not dist.is_initialized(): # TODO: Once we implement cross-validation for CCS, we should benchmark # against LogisticRegressionCV here. + + train_labels_aug = torch.cat([train_labels, 1 - train_labels]).cpu() + val_labels_aug = torch.cat([val_labels, 1 - val_labels]).cpu() + if pbar: pbar.set_description("Fitting LR") lr_model = LogisticRegression(max_iter=10_000) From 47a6c769a13a889fef2abd032c1cfa42da30c26f Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 16:27:33 +0000 Subject: [PATCH 14/15] Send labels to cpu before Linear Regression --- elk/training/train.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/elk/training/train.py b/elk/training/train.py index c5f5341a..1979f4a4 100644 --- a/elk/training/train.py +++ b/elk/training/train.py @@ -69,9 +69,6 @@ def train(args): x0, x1 = train_h.to(args.device).float().chunk(2, dim=-1) val_x0, val_x1 = val_h.to(args.device).float().chunk(2, dim=-1) - train_labels_aug = torch.cat([train_labels, 1 - train_labels]) - val_labels_aug = torch.cat([val_labels, 1 - val_labels]) - if pbar: pbar.set_description("Fitting CCS") ccs_model = CCS( @@ -101,6 +98,10 @@ def train(args): if not args.skip_baseline and not dist.is_initialized(): # TODO: Once we implement cross-validation for CCS, we should benchmark # against LogisticRegressionCV here. + + train_labels_aug = torch.cat([train_labels, 1 - train_labels]).cpu() + val_labels_aug = torch.cat([val_labels, 1 - val_labels]).cpu() + if pbar: pbar.set_description("Fitting LR") lr_model = LogisticRegression(max_iter=10_000) From 6d9d92fc8a80bb33a3eeab113e4e3b61a8fb5292 Mon Sep 17 00:00:00 2001 From: Fabien Roger Date: Tue, 14 Feb 2023 16:36:06 +0000 Subject: [PATCH 15/15] Use prompt collator from the merged branch --- elk/extraction/prompt_collator.py | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/elk/extraction/prompt_collator.py b/elk/extraction/prompt_collator.py index 00035c06..58c96a86 100644 --- a/elk/extraction/prompt_collator.py +++ b/elk/extraction/prompt_collator.py @@ -1,14 +1,12 @@ from dataclasses import dataclass -from datasets import DatasetDict, load_dataset, Dataset as HFDataset # type: ignore +from datasets import DatasetDict, load_dataset # type: ignore from promptsource.templates import DatasetTemplates from random import Random -from typing import Literal, Optional, cast +from typing import Literal, Optional import numpy as np from torch.utils.data import Dataset import torch.distributed as dist -from elk.extraction.dataset_preprocessing import undersample - @dataclass class Prompt: @@ -33,7 +31,6 @@ def __init__( max_examples: int = 0, seed: int = 42, strategy: Literal["all", "randomize"] = "randomize", - balance: bool = False, ): data = load_dataset(path, name) assert isinstance(data, DatasetDict) @@ -50,26 +47,9 @@ def __init__( print("No validation split found, using test split instead") split = "test" - self.dataset: HFDataset = data[split] - - if balance: - self.dataset = cast( - HFDataset, - undersample( - self.dataset, - seed, - label_column, - ), - ) # type: ignore - + self.dataset = data[split] self.labels, counts = np.unique(self.dataset[label_column], return_counts=True) self.label_fracs = counts / counts.sum() - - print(f"Class balance '{split}': {[f'{x:.2%}' for x in self.label_fracs]}") - pivot, *rest = self.label_fracs - if not all(x == pivot for x in rest): - print("Use arg --balance to force class balance") - if len(self.labels) < 2: raise ValueError(f"Dataset {path}/{name} has only one label") if max_examples: