bugfixes

AdamScherlis committed Jun 3, 2024
1 parent a3051ce commit 57d1c6e
Showing 5 changed files with 111 additions and 39 deletions.
88 changes: 70 additions & 18 deletions w2s/model.py
@@ -1,16 +1,16 @@
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Union

from peft import LoraConfig, get_peft_model
import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    LlamaForSequenceClassification,
    MistralForSequenceClassification,
    Qwen2ForSequenceClassification,
)

from w2s.utils import assert_type

# Works for Llama, Mistral, and Qwen architectures
DEFAULT_LORA_MODULES = [
"gate_proj",
@@ -21,15 +21,17 @@
"v_proj",
"o_proj",
]
DEFAULT_ARCHS = [
    LlamaForSequenceClassification,
    MistralForSequenceClassification,
    Qwen2ForSequenceClassification,
]


@dataclass
class ModelConfig:
class PredictorConfig(ABC):
    @abstractmethod
    def to_dict(self) -> dict:
        ...


@dataclass
class ModelConfig(PredictorConfig):
    name: str
    enable_lora: bool
    lora_modules: Optional[List[str]] = None
@@ -38,17 +40,30 @@ def to_dict(self):
        return vars(self)


class AutoCastingScore(torch.nn.Module):
    def __init__(
        self, score: torch.nn.Linear, output_dtype: torch.dtype = torch.bfloat16
    ):
        super().__init__()
        # make a leaf tensor with the same data as score
        self.weight = torch.nn.Parameter(score.weight.to(torch.float32).data)
        self.output_dtype = output_dtype

    def forward(self, hiddens):
        return torch.nn.functional.linear(
            hiddens.to(self.weight.dtype), self.weight, None
        ).to(self.output_dtype)


def init_model_and_tokenizer(cfg: ModelConfig):
    model = AutoModelForSequenceClassification.from_pretrained(
        cfg.name, torch_dtype="auto", device_map={"": "cuda"}
    )

    if cfg.lora_modules is None and cfg.enable_lora:
        cfg.lora_modules = DEFAULT_LORA_MODULES
        if not any(isinstance(model, arch) for arch in DEFAULT_ARCHS):
            warnings.warn(
                "Using default LORA modules for an architecture that is not Llama, Mistral, or Qwen"
            )
        cfg.lora_modules = MODEL_REGISTRY.get(cfg.name, {}).get(
            "lora_modules", DEFAULT_LORA_MODULES
        )

    tokenizer = AutoTokenizer.from_pretrained(cfg.name)
    if tokenizer.pad_token_id is None:
@@ -59,7 +74,22 @@ def init_model_and_tokenizer(cfg: ModelConfig):
    model.config.problem_type = "single_label_classification"

    if cfg.enable_lora:
        lora_cfg = LoraConfig(target_modules=cfg.lora_modules)
        lora_cfg = LoraConfig(
            target_modules=cfg.lora_modules, task_type=TaskType.SEQ_CLS
        )

        # NOTE: adding task_type causes dtype errors, but is necessary for proper module saving
        # and for making the lm head trainable, so we need to wrap it in an AutoCastingScore
        for attr in ["score", "classifier"]:
            if hasattr(model, attr):
                setattr(
                    model,
                    attr,
                    AutoCastingScore(getattr(model, attr), output_dtype=model.dtype),
                )
                break
        else:
            raise ValueError("Could not find classifier head in model.")
        model = get_peft_model(model, lora_cfg)

    # put all the trainable (e.g. LoRA) parameters in float32
@@ -68,3 +98,25 @@ def init_model_and_tokenizer(cfg: ModelConfig):
            p.data = p.data.float()

    return model, tokenizer


# TODO: make a legitimate model registry
# for now we just have a map from model name to learning rate and lora modules
MODEL_REGISTRY = {
"meta-llama/Meta-Llama-3-8B": {
"lr": 8e-5,
"lora_modules": DEFAULT_LORA_MODULES,
},
"mistralai/Mistral-7B-v0.1": {
"lr": 8e-5,
"lora_modules": DEFAULT_LORA_MODULES,
},
"gemma/gemma-7b": {
"lr": 8e-5,
"lora_modules": DEFAULT_LORA_MODULES,
},
"Qwen/Qwen1.5-0.5B": {
"lr": 5e-4,
"lora_modules": DEFAULT_LORA_MODULES,
},
}
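For context on the NOTE in init_model_and_tokenizer above, here is a minimal standalone sketch (PyTorch only; the tensor shapes and the direction of the mismatch are illustrative, not a reproduction of the exact PEFT behavior) of the dtype clash a plain classification head runs into once its weights and the hidden states end up in different precisions, and how the AutoCastingScore wrapper sidesteps it.

import torch

from w2s.model import AutoCastingScore

# a head that ended up in float32 (e.g. after upcasting trainable params) ...
head = torch.nn.Linear(8, 2, bias=False, dtype=torch.float32)
# ... fed bf16 hidden states from a half-precision base model
hiddens = torch.randn(2, 8, dtype=torch.bfloat16)

try:
    head(hiddens)
except RuntimeError as err:
    print("plain head:", err)  # dtype mismatch between input and weight

# the wrapper keeps a float32 copy of the weight, upcasts the hiddens,
# and casts the logits back down to the requested output dtype
wrapped = AutoCastingScore(head, output_dtype=torch.bfloat16)
print(wrapped(hiddens).dtype)  # torch.bfloat16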
15 changes: 8 additions & 7 deletions w2s/probe.py
@@ -36,12 +36,6 @@ class TopoProbeConfig(ProbeConfig):
"topo": TopoProbeConfig,
}

PROBES = {
"knn": KnnProbe,
"logreg": LogisticProbe,
"topo": TopoProbe,
}


class Probe:
    def __init__(self, config: ProbeConfig):
@@ -113,4 +107,11 @@ def filter(self, acts, labels, contamination):
        if not self.config.modified:
            return topofilter(acts, labels, contamination, k_cc=self.k_cc, k_zeta=self.k_zeta)
        else:
            return super().filter(acts, labels, contamination)
            return super().filter(acts, labels, contamination)


PROBES = {
"knn": KnnProbe,
"logreg": LogisticProbe,
"topo": TopoProbe,
}
18 changes: 9 additions & 9 deletions w2s/sft_config.py
@@ -9,16 +9,16 @@


@dataclass
class SFTConfig(Serializable): # TODO: what is this for??
class SFTConfig(Serializable):
    # name of the model to train
    weak_model_name: str
    strong_model_name: str
    weak_model_name: str = "Qwen/Qwen1.5-0.5B"
    strong_model_name: str = "meta-llama/Meta-Llama-3-8B"
    # name of the dataset to use
    dataset: str
    n_epochs: float = 2
    n_train: int = 20_000
    n_val: int = 500
    n_test: int = 1_000
    dataset: str = "boolq"
    n_epochs: float = 3
    n_train: int = 10_000
    n_val: int = 1_000
    n_test: int = 5_000
    # when "train", it uses the training set to generate predictions
    # otherwise it uses n_predict held out examples
    n_predict: Union[literal("train"), int] = 0
@@ -33,7 +33,7 @@ class SFTConfig(Serializable): # TODO: what is this for??
    n_warmup_steps: int = 40  # 2 / (1 - 0.95) = 40
    eval_every: int = 100  # steps
    save_every: int = 100  # steps
    save_total_limit: Optional[int] = None
    save_total_limit: Optional[int] = 1
    weight_decay: float = 0.1
    weak_lr: float = 5e-4
    strong_lr: float = 8e-5
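With the commonly used fields now carrying defaults, the config can be driven entirely from the command line via simple_parsing. A rough sketch of how such a Serializable dataclass is typically parsed; MiniConfig and its handful of fields are illustrative stand-ins, not the repo's actual entry point.

from dataclasses import dataclass

from simple_parsing import ArgumentParser, Serializable

@dataclass
class MiniConfig(Serializable):
    # a few fields mirroring SFTConfig's new defaults
    weak_model_name: str = "Qwen/Qwen1.5-0.5B"
    strong_model_name: str = "meta-llama/Meta-Llama-3-8B"
    dataset: str = "boolq"
    n_epochs: float = 3

parser = ArgumentParser()
parser.add_arguments(MiniConfig, dest="cfg")
args = parser.parse_args(["--dataset", "sciq"])  # override one default
print(args.cfg.to_dict())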
25 changes: 22 additions & 3 deletions w2s/sft_utils.py
@@ -1,6 +1,5 @@
import gc
from pathlib import Path
from enum import StrEnum

import pynvml
import torch
@@ -11,9 +10,29 @@
from w2s.utils import assert_type


# simple_parsing doesn't like typing.Literal so I rolled my own
# simple_parsing doesn't like typing.Literal (pre-3.12) so I rolled my own
# note: parens, not brackets
literal = lambda *args: StrEnum("option", args)

# Python 3.11 version:
# literal = lambda *args: StrEnum("option", args)

# Python 3.10 version:
def literal(s: str):
    return type(f'LiteralString_{s}', (LiteralString,), {"value": s})


class LiteralString():
value = ""

def __init__(self, value):
if value != self.value:
raise ValueError(f"Invalid value {value!r} is not literally {self.value!r}")

def __str__(self):
return self.value

def __eq__(self, other):
return self.value == other


@torch.no_grad()
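A small usage sketch of the Python 3.10 literal() helper above, assuming w2s.sft_utils is importable; the Train name is only for illustration. This is how it stands in for typing.Literal in annotations like n_predict: Union[literal("train"), int].

from w2s.sft_utils import literal

Train = literal("train")  # a LiteralString subclass whose only valid value is "train"

print(Train("train") == "train")  # True, via LiteralString.__eq__
print(str(Train("train")))        # "train"

try:
    Train("val")  # any other string is rejected in __init__
except ValueError as err:
    print(err)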
4 changes: 2 additions & 2 deletions w2s/utils.py
@@ -1,5 +1,5 @@
from typing import Any, Type, TypeVar, cast
from w2s.sft_config import LossConfig
from simple_parsing import Serializable

T = TypeVar("T")

@@ -48,7 +48,7 @@ def flatten_dict(d: dict, parent_key: str = "", sep: str = "_") -> dict:
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, dict):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        elif isinstance(v, LossConfig):
        elif isinstance(v, Serializable):  # can't use LossConfig, etc to avoid circular import
            items.extend(flatten_dict(v.to_dict(), new_key, sep=sep).items())
        else:
            items.append((new_key, v))
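A quick worked example of what flatten_dict is expected to produce with the key-joining rule shown above (default sep="_"); the nested dict here is illustrative.

from w2s.utils import flatten_dict

nested = {"loss": {"name": "xent", "weight": 1.0}, "n_epochs": 3}
print(flatten_dict(nested))
# expected: {'loss_name': 'xent', 'loss_weight': 1.0, 'n_epochs': 3}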
