[Draft] RWKV LM #207

Open · wants to merge 22 commits into main

Changes from 1 commit
Reset extraction
Kyle1668 committed Apr 22, 2023
commit 69194eedfae517c5571de5f59748d740ed6048c3
203 changes: 107 additions & 96 deletions elk/extraction/extraction.py
@@ -5,12 +5,14 @@
from dataclasses import InitVar, dataclass
from itertools import islice
from typing import Any, Iterable, Literal
from warnings import filterwarnings

import torch
from datasets import (
Array2D,
Array3D,
ClassLabel,
DatasetDict,
DownloadMode,
Features,
Sequence,
SplitDict,
@@ -20,23 +22,23 @@
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, AutoTokenizer, GPT2TokenizerFast
from transformers import AutoConfig, PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from ..promptsource import DatasetTemplates
from ..utils import (
assert_type,
convert_span,
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
select_train_val_splits,
select_usable_devices,
)
from .balanced_sampler import BalancedSampler
from .generator import _GeneratorBuilder
from .prompt_loading import PromptConfig, load_prompts
from ..rwkv_lm.rwkv_hf import RWKVConfig


@dataclass
@@ -58,6 +60,7 @@ class Extract(Serializable):
layers: tuple[int, ...] = ()
layer_stride: InitVar[int] = 1
token_loc: Literal["first", "last", "mean"] = "last"
use_encoder_states: bool = False

def __post_init__(self, layer_stride: int):
if self.layers and layer_stride > 1:
@@ -85,7 +88,7 @@ def explode(self) -> list["Extract"]:
return copies


@torch.no_grad()
@torch.inference_mode()
def extract_hiddens(
cfg: "Extract",
*,
@@ -99,135 +102,135 @@ def extract_hiddens(

# Silence datasets logging messages from all but the first process
if rank != 0:
filterwarnings("ignore")
logging.disable(logging.CRITICAL)

ds_names = cfg.prompts.datasets
p_cfg = cfg.prompts
ds_names = p_cfg.datasets
assert len(ds_names) == 1, "Can only extract hiddens from one dataset at a time."

prompt_ds = load_prompts(
ds_names[0],
split_type=split_type,
stream=cfg.prompts.stream,
rank=rank,
world_size=world_size,
) # this dataset is already sharded, but hasn't been truncated to max_examples

model = instantiate_model(
cfg.model, torch_dtype="auto" if device != "cpu" else torch.float32
).to(device)
tokenizer = None
tokenizer = instantiate_tokenizer(
cfg.model, truncation_side="left", verbose=rank == 0
)

if cfg.model.startswith("RWKV"):
tokenizer = GPT2TokenizerFast(tokenizer_file='/home/kyle/repos/elk/elk/rwkv_lm/20B_tokenizer.json')
else:
tokenizer = AutoTokenizer.from_pretrained(
cfg.model, truncation_side="left", verbose=False
)
is_enc_dec = model.config.is_encoder_decoder
if is_enc_dec and cfg.use_encoder_states:
assert hasattr(model, "get_encoder") and callable(model.get_encoder)
model = assert_type(PreTrainedModel, model.get_encoder())
is_enc_dec = False

has_lm_preds = is_autoregressive(model.config)
has_lm_preds = is_autoregressive(model.config, not cfg.use_encoder_states)
if has_lm_preds and rank == 0:
print("Model has language model head, will store predictions.")

prompt_ds = load_prompts(
ds_names[0],
label_column=p_cfg.label_columns[0] if p_cfg.label_columns else None,
num_classes=p_cfg.num_classes,
split_type=split_type,
stream=p_cfg.stream,
rank=rank,
world_size=world_size,
)

# Iterating over questions
layer_indices = cfg.layers or tuple(range(model.config.num_hidden_layers))

global_max_examples = cfg.prompts.max_examples[0 if split_type == "train" else 1]
global_max_examples = p_cfg.max_examples[0 if split_type == "train" else 1]
# break `max_examples` among the processes roughly equally
max_examples = global_max_examples // world_size
# the last process gets the remainder (which is usually small)
if rank == world_size - 1:
max_examples += global_max_examples % world_size
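# Worked example of the split above (illustrative numbers, not from this PR):
# with global_max_examples = 1000 and world_size = 3, ranks 0 and 1 each take
# 1000 // 3 = 333 examples, and rank 2 takes 333 + (1000 % 3) = 334.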

for example in islice(BalancedSampler(prompt_ds), max_examples):
for example in islice(prompt_ds, max_examples):
num_variants = len(example["prompts"])
num_choices = len(example["prompts"][0])

hidden_dict = {
f"hidden_{layer_idx}": torch.empty(
num_variants,
2, # contrast pair
num_choices,
model.config.hidden_size,
device=device,
dtype=torch.int16,
)
for layer_idx in layer_indices
}
lm_preds = torch.empty(
lm_logits = torch.empty(
num_variants,
2, # contrast pair
num_choices,
device=device,
dtype=torch.float32,
)
text_inputs = []
text_questions = []

# Iterate over variants
for i, record in enumerate(example["prompts"]):
variant_inputs = []
variant_questions = []

# Iterate over answers
for j, choice in enumerate(record):
text = choice["text"]

# TODO: Do something smarter than "rindex" here. Really we want to
# get the span of the answer directly from Jinja, but that doesn't
# seem possible. This approach may fail for complex templates.
answer_start = text.rindex(choice["answer"])
text = choice["question"]

# Only feed question, not the answer, to the encoder for enc-dec models
if model.config.is_encoder_decoder:
# TODO: Maybe make this more generic for complex templates?
text = text[:answer_start].rstrip()
target = choice["answer"]
else:
target = None

# Record the EXACT string we fed to the model
variant_inputs.append(text)
# inputs = None
# if cfg.model.startswith("RWKV"):
# inputs = tokenizer(
# text,
# return_offsets_mapping=True,
# text_target=target, # type: ignore[arg-type]
# truncation=True,
# )
# else:
inputs = tokenizer(
target = choice["answer"] if is_enc_dec else None

# Record the EXACT question we fed to the model
variant_questions.append(text)
encoding = tokenizer(
text,
return_offsets_mapping=True,
add_special_tokens=False,
return_tensors="pt",
text_target=target, # type: ignore[arg-type]
truncation=True,
)
).to(device)
input_ids = assert_type(Tensor, encoding.input_ids)

if is_enc_dec:
answer = assert_type(Tensor, encoding.labels)
else:
encoding2 = tokenizer(
choice["answer"],
add_special_tokens=False,
return_tensors="pt",
).to(device)
answer = assert_type(Tensor, encoding2.input_ids)

# The offset_mapping is a sorted list of (start, end) tuples. We locate
# the start of the answer in the tokenized sequence with binary search.
offsets = inputs.pop("offset_mapping") if cfg.model.startswith("RWKV") else inputs.pop("offset_mapping").squeeze().tolist()
inputs = inputs if cfg.model.startswith("RWKV") else inputs.to(device)
input_ids = torch.cat([input_ids, answer], dim=-1)
if max_len := tokenizer.model_max_length:
input_ids = input_ids[..., -max_len:]
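# Slicing off the front keeps only the final `max_len` tokens, so the answer
# tokens just concatenated onto the end of the sequence are never truncated.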

# Run the forward pass
outputs = model(**inputs) if cfg.model.startswith("RWKV") else model(**inputs, output_hidden_states=True)
# Make sure we only pass the arguments that the model expects
inputs = dict(input_ids=input_ids)
if is_enc_dec:
inputs["labels"] = answer

with torch.autocast("cuda", enabled=torch.cuda.is_available()):
outputs = model(**inputs, output_hidden_states=True)

# Compute the log probability of the answer tokens if available
if has_lm_preds:
start, end = convert_span(
offsets, (answer_start, answer_start + len(choice["answer"]))
)
log_p = outputs.logits[..., start - 1 : end - 1, :].log_softmax(
dim=-1
)
tokens = inputs.input_ids[..., start:end, None]
lm_preds[i, j] = log_p.gather(-1, tokens).sum()
answer_len = answer.shape[-1]

log_p = outputs.logits[..., -answer_len:, :].log_softmax(dim=-1)
tokens = answer[..., None]
lm_logits[i, j] = log_p.gather(-1, tokens).sum()
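# The gather picks out, at each answer position, the log-probability the model
# assigned to the gold answer token; summing gives the log-probability of the
# whole answer string under the model.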

elif isinstance(outputs, Seq2SeqLMOutput):
# The cross entropy loss is averaged over tokens, so we need to
# multiply by the length to get the total log probability.
length = inputs.labels.shape[-1]
lm_preds[i, j] = -assert_type(Tensor, outputs.loss) * length
length = encoding.labels.shape[-1]
lm_logits[i, j] = -assert_type(Tensor, outputs.loss) * length
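# Illustrative example (numbers assumed, not from this PR): a mean
# cross-entropy of 2.0 over a 5-token answer corresponds to a total
# log-probability of -2.0 * 5 = -10.0.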

hiddens = outputs if cfg.model.startswith("RWKV") else (
hiddens = (
outputs.get("decoder_hidden_states") or outputs["hidden_states"]
)
# First element of list is the input embeddings
hiddens = hiddens if cfg.model.startswith("RWKV") else hiddens[1:]
hiddens = hiddens[1:]

# Throw out layers we don't care about
hiddens = [hiddens[i] for i in layer_indices]
@@ -245,17 +248,16 @@ def extract_hiddens(
for layer_idx, hidden in zip(layer_indices, hiddens):
hidden_dict[f"hidden_{layer_idx}"][i, j] = float32_to_int16(hidden)

text_inputs.append(variant_inputs)
text_questions.append(variant_questions)

out_record: dict[str, Any] = dict(
label=example["label"],
variant_ids=example["template_names"],
text_inputs=text_inputs,
text_questions=text_questions,
**hidden_dict,
)
if has_lm_preds:
# We only need the probability of the positive example since this is binary
out_record["model_preds"] = lm_preds.softmax(dim=-1)[..., 1]
out_record["model_logits"] = lm_logits

yield out_record

@@ -266,7 +268,11 @@ def _extraction_worker(**kwargs):


def extract(
cfg: "Extract", num_gpus: int = -1, min_gpu_mem: int | None = None
cfg: "Extract",
*,
disable_cache: bool = False,
num_gpus: int = -1,
min_gpu_mem: int | None = None,
) -> DatasetDict:
"""Extract hidden states from a model and return a `DatasetDict` containing them."""

@@ -292,15 +298,18 @@ def get_splits() -> SplitDict:
dataset_name=available_splits.dataset_name,
)

model_cfg = None
if cfg.model.startswith("RWKV"):
model_cfg = RWKVConfig()
else:
model_cfg = AutoConfig.from_pretrained(cfg.model)
model_cfg = AutoConfig.from_pretrained(cfg.model)

ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)

ds_features = assert_type(Features, info.features)
label_col = (
cfg.prompts.label_columns[0]
if cfg.prompts.label_columns
else infer_label_column(ds_features)
)
num_classes = cfg.prompts.num_classes or infer_num_classes(ds_features[label_col])
num_variants = cfg.prompts.num_variants
if num_variants < 0:
prompter = DatasetTemplates(ds_name, config_name)
@@ -309,7 +318,7 @@ def get_splits() -> SplitDict:
layer_cols = {
f"hidden_{layer}": Array3D(
dtype="int16",
shape=(num_variants, 2, model_cfg.hidden_size),
shape=(num_variants, num_classes, model_cfg.hidden_size),
)
for layer in cfg.layers or range(model_cfg.num_hidden_layers)
}
@@ -318,21 +327,20 @@ def get_splits() -> SplitDict:
Value(dtype="string"),
length=num_variants,
),
"label": ClassLabel(names=["neg", "pos"]),
"text_inputs": Sequence(
"label": Value(dtype="int64"),
"text_questions": Sequence(
Sequence(
Value(dtype="string"),
length=2,
),
length=num_variants,
),
}

# Only add model_preds if the model is an autoregressive model
if is_autoregressive(model_cfg):
other_cols["model_preds"] = Sequence(
Value(dtype="float32"),
length=num_variants,
# Only add model_logits if the model is an autoregressive model
if is_autoregressive(model_cfg, not cfg.use_encoder_states):
other_cols["model_logits"] = Array2D(
shape=(num_variants, num_classes),
dtype="float32",
)

devices = select_usable_devices(num_gpus, min_memory=min_gpu_mem)
@@ -361,7 +369,10 @@ def get_splits() -> SplitDict:

ds = dict()
for split, builder in builders.items():
builder.download_and_prepare(num_proc=len(devices))
builder.download_and_prepare(
download_mode=DownloadMode.FORCE_REDOWNLOAD if disable_cache else None,
num_proc=len(devices),
)
ds[split] = builder.as_dataset(split=split)

return DatasetDict(ds)
return DatasetDict(ds)
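
A minimal usage sketch of the updated `extract` entry point. The field names here are inferred from how `cfg.model` and `cfg.prompts.datasets` are used in this file; the actual `Extract` and `PromptConfig` constructors may take different arguments.

from elk.extraction.extraction import Extract, extract
from elk.extraction.prompt_loading import PromptConfig

# Assumed constructor fields (not shown in this diff): `model` on Extract,
# `datasets` on PromptConfig.
cfg = Extract(
    model="EleutherAI/pythia-410m",          # any HF causal or seq2seq model
    prompts=PromptConfig(datasets=("imdb",)),
)

# Returns a DatasetDict with one dataset per split, each row holding int16
# hidden states per layer plus, for LM models, per-class answer logits.
ds = extract(cfg, num_gpus=1, disable_cache=False)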