Add hacks for batch encoded causal lm tokenization
Signed-off-by: Alex-Brooks <[email protected]>
alex-jw-brooks committed Sep 12, 2023
1 parent 0f1d48d commit 033f314
Showing 1 changed file with 55 additions and 11 deletions.
66 changes: 55 additions & 11 deletions caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -20,6 +20,7 @@
from typing import Callable, Tuple, Union

# Third Party
from transformers import BatchEncoding
from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling
from transformers.models.auto import modeling_auto

@@ -72,8 +73,14 @@ def tokenize_function(
            DataStream[transformers.tokenization_utils_base.BatchEncoding]
                stream of encoded tokenization output corresponding to the input example.
        """
        # Extract the source & target from our provided inputs
        source, target = cls.decompose_example_io(example)
        # Determine if our mapped inputs are in batched mode or not
        batched_mode = isinstance(source, list) and isinstance(target, list)

        # TODO: Handle batched verbalizer stuff!
        if batched_mode and verbalizer is not None:
            raise NotImplementedError(
                "Verbalizer rendering not implemented for batch mode"
            )
        # Only render the verbalizer if one is provided
        source = (
            source if verbalizer is None else render_verbalizer(verbalizer, example)
        )

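        # NOTE: for batched inputs, source/target are parallel lists of strings,
        # so each tokenizer call below returns per-example lists of encodings;
        # for a single example they are plain strings and each call returns a
        # single encoding.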
        source_ids = tokenizer(
@@ -82,15 +89,26 @@
        target_ids = tokenizer(
            target, max_length=max_target_length, truncation=True
        )
        source_ids["input_ids"] = source_ids.input_ids + target_ids.input_ids
        # Here, we need to yield and manipulate the attention mask to attend
        # to the input seq + the tokens we have seen so far...
        num_target_samples = len(target_ids.input_ids)

        if task_ids is not None:
            source_ids["task_ids"] = task_ids

        def generator_func():
        if batched_mode:
            num_target_samples = []
            for idx in range(len(source_ids.input_ids)):
                source_ids["input_ids"][idx] = (
                    source_ids.input_ids[idx] + target_ids.input_ids[idx]
                )
                num_target_samples.append(len(target_ids.input_ids[idx]))
                if task_ids is not None:
                    source_ids["task_ids"][idx] = task_ids
        else:
            source_ids["input_ids"] = source_ids.input_ids + target_ids.input_ids
            # Here, we need to yield and manipulate the attention mask to attend
            # to the input seq + the tokens we have seen so far...
            num_target_samples = len(target_ids.input_ids)

            if task_ids is not None:
                source_ids["task_ids"] = task_ids
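        # NOTE: num_target_samples is a list in batched mode but an int in
        # single-example mode; the two code paths below consume it accordingly.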

        # This is disgusting! TODO:
        # - Consolidate batched [generator] vs. non-batched behavior [batch encoded lists]
        # - Make all attention mask logic common, etc.
        def single_generator_func():
            for idx in range(num_target_samples):
                # This may not actually be needed, but for now we do it, since the underlying
                # data may be referenced in multiple places, and the data will be dynamically
@@ -103,7 +121,33 @@ def generator_func():
                )
                yield s

        return DataStream(generator_func)
        def get_batched_output():
            # Initialize the batch encoding key lists as empty
            batch_encoding = BatchEncoding()
            for k in source_ids:
                batch_encoding[k] = []
            # Flatten the batch and add everything individually...
            for batch_idx in range(len(source_ids.input_ids)):
                # Consider every output text for this entry in the batch
                for idx in range(num_target_samples[batch_idx]):
                    # Create the batch encoding dict directly and populate the keys
                    # from the corresponding entry inside of the batch...
                    for key in source_ids:
                        if key != "attention_mask":
                            batch_encoding[key].append(source_ids[key][batch_idx])
                        else:
                            # Handle the attention mask for this entry...
                            attn_mask = (
                                source_ids["attention_mask"][batch_idx]
                                + [1] * (idx + 1)
                                + [0] * (num_target_samples[batch_idx] - idx - 1)
                            )
                            batch_encoding["attention_mask"].append(attn_mask)
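            # At this point every list in batch_encoding holds one entry per
            # (example, target prefix) pair, i.e. sum(num_target_samples) rows.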
            return batch_encoding

        if batched_mode:
            return get_batched_output()
        return DataStream(single_generator_func)


    def _get_data_collator(self, **kwargs):

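For readers skimming the diff, here is a minimal sketch of the attention-mask fan-out that both the single-example generator and the batched path implement. It is not part of the commit, and the token ids are made up rather than real tokenizer output: each source/target pair yields num_target_samples rows that all share the same concatenated input_ids, while the attention mask exposes one additional target token per row.

source_ids = [101, 7592]       # S = 2 source token ids (toy values)
target_ids = [2088, 999, 102]  # T = 3 target token ids (toy values)
source_mask = [1] * len(source_ids)

input_ids = source_ids + target_ids
num_target_samples = len(target_ids)

rows = []
for idx in range(num_target_samples):
    # Attend to the full source plus the first (idx + 1) target tokens
    attn_mask = (
        source_mask
        + [1] * (idx + 1)
        + [0] * (num_target_samples - idx - 1)
    )
    rows.append({"input_ids": input_ids, "attention_mask": attn_mask})

# rows[0]["attention_mask"] -> [1, 1, 1, 0, 0]
# rows[1]["attention_mask"] -> [1, 1, 1, 1, 0]
# rows[2]["attention_mask"] -> [1, 1, 1, 1, 1]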