Add ELMo Integration #180

Closed
wants to merge 22 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -5,6 +5,8 @@ elk/trained/*
nohup.out
.idea
*.pkl
*.ipynb
.vscode/launch.json

# scripts for experiments in progress
my_*.sh
5 changes: 3 additions & 2 deletions elk/extraction/extraction.py
@@ -22,7 +22,7 @@
)
from simple_parsing import Serializable, field
from torch import Tensor
from transformers import AutoConfig, PreTrainedModel
from transformers import PreTrainedModel
from transformers.modeling_outputs import Seq2SeqLMOutput

from ..promptsource import DatasetTemplates
@@ -31,6 +31,7 @@
float32_to_int16,
infer_label_column,
infer_num_classes,
instantiate_config,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
@@ -298,7 +299,7 @@ def get_splits() -> SplitDict:
dataset_name=available_splits.dataset_name,
)

model_cfg = AutoConfig.from_pretrained(cfg.model)
model_cfg = instantiate_config(cfg.model)

ds_name, _, config_name = cfg.prompts.datasets[0].partition(" ")
info = get_dataset_config_info(ds_name, config_name or None)
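Not part of the diff: the AutoConfig.from_pretrained(cfg.model) call is replaced by instantiate_config(cfg.model) because "elmo" is not a Hugging Face Hub model string, so AutoConfig would raise before extraction starts. A minimal illustrative sketch, assuming the new helper is re-exported from elk.utils as in the later changes:

# Illustrative only: "elmo" is handled locally, everything else still goes to the Hub.
from elk.utils import instantiate_config

elmo_cfg = instantiate_config("elmo")  # returns ElmoConfig, no Hub lookup
print(elmo_cfg.hidden_size, elmo_cfg.num_hidden_layers)  # 1024 3

gpt2_cfg = instantiate_config("gpt2")  # falls back to AutoConfig.from_pretrained
print(gpt2_cfg.architectures)  # ['GPT2LMHeadModel']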
82 changes: 82 additions & 0 deletions elk/rnn/elmo.py
@@ -0,0 +1,82 @@
import tensorflow as tf
import tensorflow_hub as hub
import torch
from transformers import (
    AutoTokenizer,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
)


class ElmoConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = 1024
        self.num_hidden_layers = 3
        self.is_encoder_decoder = False
        self.architectures = ["Elmo"]


class ElmoTokenizer(PreTrainedTokenizer):
    """
    The ELMo tokenizer is a wrapper around the GPT-2 tokenizer since much of the
    extraction pipeline depends on the input being tensors. The ELMo TF
    implementation takes string input, so the tensors are decoded within the
    TfElmoModel instance.
    """

    def __init__(self):
        self.internal_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.model_max_length = self.internal_tokenizer.model_max_length

    def __call__(
        self,
        text=None,
        return_tensors=None,
        truncation=None,
        return_offsets_mapping=None,
        text_target=None,
        add_special_tokens=None,
    ):
        return self.internal_tokenizer(
            text=text,
            add_special_tokens=add_special_tokens,
            return_tensors=return_tensors,
            text_target=text_target,
            truncation=truncation,
        )


class TfElmoModel(PreTrainedModel):
    """A Hugging Face wrapper around the TensorFlow ELMo model."""

    def __init__(self):
        super().__init__(config=ElmoConfig())
        self.internal_tokenizer = AutoTokenizer.from_pretrained("gpt2")
        self.elmo_model = hub.load("https://tfhub.dev/google/elmo/3").signatures[
            "default"
        ]

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        output_hidden_states=None,
    ):
        # Decode the token ids back to strings; the TF-Hub ELMo signature
        # expects raw text rather than token tensors.
        nl_inputs = [
            self.internal_tokenizer.decode(sequence_tensor)
            for sequence_tensor in input_ids
        ]
        embeddings = self.elmo_model(tf.constant(nl_inputs))
        return {
            "hidden_states": [
                torch.tensor(embeddings["word_emb"].numpy()),
                torch.tensor(embeddings["lstm_outputs1"].numpy()),
                torch.tensor(embeddings["lstm_outputs2"].numpy()),
                torch.tensor(embeddings["elmo"].numpy()),
            ]
        }
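A minimal usage sketch (not part of the diff) of how the two wrappers above might be exercised, assuming the tensorflow and tensorflow_hub extras are installed and https://tfhub.dev/google/elmo/3 is reachable:

from elk.rnn.elmo import ElmoTokenizer, TfElmoModel

tokenizer = ElmoTokenizer()
model = TfElmoModel()

# Tokenize with the wrapped GPT-2 tokenizer; TfElmoModel decodes the ids back
# to strings before handing them to the TF-Hub signature.
batch = tokenizer(text=["The cat sat on the mat."], return_tensors="pt")
outputs = model(input_ids=batch["input_ids"])

# Four "layers": the context-independent word embeddings, the two biLSTM
# layers, and the weighted ELMo mixture, each of shape (batch, seq_len, dim).
for layer in outputs["hidden_states"]:
    print(layer.shape)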
7 changes: 6 additions & 1 deletion elk/utils/__init__.py
@@ -9,7 +9,12 @@
select_train_val_splits,
)
from .gpu_utils import select_usable_devices
from .hf_utils import instantiate_model, instantiate_tokenizer, is_autoregressive
from .hf_utils import (
instantiate_config,
instantiate_model,
instantiate_tokenizer,
is_autoregressive,
)
from .math_util import batch_cov, cov_mean_fused, stochastic_round_constrained
from .tree_utils import pytree_map
from .typing import assert_type, float32_to_int16, int16_to_float32
16 changes: 16 additions & 0 deletions elk/utils/hf_utils.py
@@ -8,6 +8,8 @@
PreTrainedTokenizerBase,
)

from ..rnn.elmo import ElmoConfig, ElmoTokenizer, TfElmoModel

# Ordered by preference
_DECODER_ONLY_SUFFIXES = [
"CausalLM",
@@ -19,6 +21,9 @@

def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel:
"""Instantiate a model string with the appropriate `Auto` class."""
if model_str.startswith("elmo"):
return TfElmoModel()

model_cfg = AutoConfig.from_pretrained(model_str)
archs = model_cfg.architectures
if not isinstance(archs, list):
@@ -37,6 +42,9 @@ def instantiate_model(model_str: str, **kwargs) -> PreTrainedModel:

def instantiate_tokenizer(model_str: str, **kwargs) -> PreTrainedTokenizerBase:
"""Instantiate a tokenizer, using the fast one iff it exists."""
if model_str == "elmo":
return ElmoTokenizer()

try:
return AutoTokenizer.from_pretrained(model_str, use_fast=True, **kwargs)
except Exception as e:
@@ -46,6 +54,14 @@ def instantiate_tokenizer(model_str: str, **kwargs) -> PreTrainedTokenizerBase:
return AutoTokenizer.from_pretrained(model_str, use_fast=False, **kwargs)


def instantiate_config(model_str: str, **kwargs) -> PretrainedConfig:
"""Instantiate a config."""
if model_str == "elmo":
return ElmoConfig()

return AutoConfig.from_pretrained(model_str, **kwargs)


def is_autoregressive(model_cfg: PretrainedConfig, include_enc_dec: bool) -> bool:
"""Check if a model config is autoregressive."""
archs = model_cfg.architectures
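Again not part of the diff — a sketch of how the three helpers now dispatch, assuming they are re-exported from elk.utils as in the __init__.py change above: the "elmo" model string short-circuits to the TF wrappers, while every other model string keeps the existing Auto* path.

from elk.utils import instantiate_config, instantiate_model, instantiate_tokenizer

for name in ("elmo", "gpt2"):
    cfg = instantiate_config(name)
    tok = instantiate_tokenizer(name)
    model = instantiate_model(name)
    print(name, type(cfg).__name__, type(tok).__name__, type(model).__name__)
    # Expected, roughly:
    # elmo -> ElmoConfig ElmoTokenizer TfElmoModel
    # gpt2 -> GPT2Config GPT2TokenizerFast GPT2LMHeadModel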
7 changes: 5 additions & 2 deletions pyproject.toml
@@ -25,8 +25,8 @@ dependencies = [
"sentencepiece==0.1.97",
# We upstreamed bugfixes for Literal types in 0.1.1
"simple-parsing>=0.1.1",
# Version 1.11 introduced Fully Sharded Data Parallel, which we plan to use soon
"torch>=1.11.0",
# Version 1.11 introduced Fully Sharded Data Parallel, which we plan to use soon. 1.13 supports A40 cards.
"torch>=1.13.0",
# Doesn't really matter but versions < 4.0 are very very old (pre-2016)
"tqdm>=4.0.0",
# 4.0 introduced the breaking change of using return_dict=True by default
@@ -43,6 +43,9 @@ dev = [
"pytest",
"pyright",
"scikit-learn",
# We use an implementation of ELMo from TensorFlow Hub. These are only required when eliciting from ELMo.
"tensorflow",
"tensorflow_hub"
]

[project.scripts]
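Since the TensorFlow packages live only in the dev extras, the ELMo path needs them installed explicitly; a tiny illustrative check (not part of the diff):

try:
    import tensorflow_hub  # noqa: F401
except ImportError:
    print("tensorflow/tensorflow_hub missing; install the dev extras, e.g. pip install -e '.[dev]'")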