
Merge pull request caikit#156 from alex-jw-brooks/experiment_sharding_rebase

Experiment sharding rebase
gkumbhat committed Sep 12, 2023
2 parents 091e271 + 8a0395a commit ad2aca2
Showing 12 changed files with 603 additions and 238 deletions.
3 changes: 3 additions & 0 deletions caikit_nlp/config/config.yml
@@ -23,6 +23,9 @@ source_prompt_base: ""

# Whether or not to purge TGIS prompts on model deletion
unload_tgis_prompt_artifacts: false
# Torchrun elastic launch configuration, e.g., for fine tuning on multiple GPUs
master_addr: localhost
master_port: 29550

runtime:
library: caikit_nlp
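
The two new keys give torchrun-style rendezvous settings for multi-GPU fine tuning. Below is a minimal sketch, assuming a single node and a hypothetical _train_worker entrypoint, of how values like these can be handed to PyTorch's elastic launcher; it is illustrative only and not this module's actual launch code.

# Illustrative only: feeding master_addr / master_port style settings into
# torch's elastic launcher. The entrypoint and process count are assumptions.
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

def _train_worker(model_path: str):
    # Placeholder training entrypoint; each spawned worker process runs this.
    print(f"worker starting for {model_path}")

if __name__ == "__main__":
    launch_config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,                 # e.g., one process per GPU
        rdzv_backend="c10d",
        rdzv_endpoint="localhost:29550",  # master_addr:master_port from the config above
    )
    elastic_launch(launch_config, _train_worker)("/path/to/base/model")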
2 changes: 1 addition & 1 deletion caikit_nlp/modules/text_generation/peft_prompt_tuning.py
@@ -1062,7 +1062,7 @@ def _get_data_loaders_from_stream(
(
tokenize_function,
requires_unwrapping,
) = base_model.build_task_tokenize_function(
) = base_model.build_task_tokenize_closure(
tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
)
mapped_stream = train_stream.map(tokenize_function)
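
For context, the calling pattern behind the rename looks roughly like the sketch below; base_model, tokenizer, train_stream, and the length / verbalizer settings are assumed to be in scope, and the unwrapping step is shown with a plain generator rather than any specific DataStream helper. This is a sketch, not code from this commit.

# Sketch of how the renamed closure builder is consumed.
(
    tokenize_function,
    requires_unwrapping,
) = base_model.build_task_tokenize_closure(
    tokenizer, max_source_length, max_target_length, verbalizer, task_ids=0
)
mapped_stream = train_stream.map(tokenize_function)
if requires_unwrapping:
    # Causal LM tokenization yields a stream of encodings per sample, so flatten
    # it here with a plain generator; the repo may use a DataStream helper instead.
    mapped_stream = (
        encoding for sub_stream in mapped_stream for encoding in sub_stream
    )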
303 changes: 235 additions & 68 deletions caikit_nlp/modules/text_generation/text_generation_local.py

Large diffs are not rendered by default.

78 changes: 63 additions & 15 deletions caikit_nlp/resources/pretrained_model/base.py
@@ -14,6 +14,7 @@

# Standard
from abc import ABC, abstractmethod
from collections.abc import Mapping
from typing import Callable, List, Optional, Tuple, Type, Union
import json
import os
@@ -31,12 +32,13 @@

# First Party
from caikit import get_config
from caikit.core.data_model import DataStream
from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver
from caikit.core.toolkit import error_handler
import alog

# Local
from ...data_model import PromptOutputModelType
from ...data_model import GenerationTrainRecord, PromptOutputModelType
from ...toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype

log = alog.use_channel("HFRBAS")
@@ -50,6 +52,12 @@ class PretrainedModelBase(ABC, ModuleBase):
_MODEL_ARTIFACTS_CONFIG_KEY = "model_artifacts"
_LEFT_PAD_MODEL_TYPES = ("gpt", "opt", "bloom")

@classmethod
@property
def REQUIRES_TOKEN_UNWRAPPING(cls) -> bool:
"""Most models don't need token unwrapping from their tokenizer closures"""
return False
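
A usage sketch, assuming the package layout shown in this diff: the stacked @classmethod/@property lets the flag be read directly off the class, and subclasses can override it with a plain class attribute, as the causal LM resource further down does.

# Usage sketch; assumes the package layout shown in this diff.
from caikit_nlp.resources.pretrained_model.base import PretrainedModelBase
from caikit_nlp.resources.pretrained_model.hf_auto_causal_lm import HFAutoCausalLM

assert PretrainedModelBase.REQUIRES_TOKEN_UNWRAPPING is False  # base default
assert HFAutoCausalLM.REQUIRES_TOKEN_UNWRAPPING is True        # overridden by the subclass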

## Abstract Interface ######################################################

@classmethod
@@ -73,7 +81,6 @@ def SUPPORTED_MODEL_TYPES(cls) -> str:
"""All classes must indicate the model types supported by the resource"""

## Shared Implementation ###################################################

def __init__(
self,
tokenizer: AutoTokenizer,
@@ -125,12 +132,14 @@ def bootstrap(
Args:
model_name (str)
The name/path of the HF sequence classifier model
tokenizer_name (Optional[str])
tokenizer_name (Optional[str]
The name/path of the HF tokenizer model (matches model_name if
not given)
not given) or an instance of a loaded tokenizer.
NOTE: If a loaded tokenizer is provided, and it doesn't have
a pad token ID, the pad token ID will be set to the EOS token ID.
padding_side (Optional[str])
The padding side for the tokenizer. Found by convention if not
given.
given. This value is only used if a tokenizer needs to be loaded.
torch_dtype: (Optional[Union[torch.dtype, str]])
Data type to load the model as; if no value is provided, we pull
torch_dtype from config.
@@ -174,7 +183,10 @@ def bootstrap(
tokenizer_name,
local_files_only=not get_config().allow_downloads,
padding_side=padding_side,
# We can't disable use_fast; otherwise unit tests fail
# use_fast=False,
)

if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

@@ -265,7 +277,6 @@ def get_trainer(
trainer_arguments = {
"train_dataset": train_dataset,
"data_collator": data_collator,
"tokenizer": self._tokenizer,
"optimizers": optimizers,
"eval_dataset": eval_dataset,
}
@@ -302,20 +313,58 @@ def get_num_transformers_submodules(
"""Return number of applicable transformer submodules"""
return 1

@staticmethod
def decompose_example_io(example: Union[GenerationTrainRecord, Mapping]):
"""Given an example, which might be a number of supported types,
extract the input / output texts. Depending on the manner in which
the sample is being leveraged, this might be a raw data model object,
a dict, or some other mappable, e.g., a HF dataset LazyRow.
args:
example: Union[GenerationTrainRecord, Mapping]
Objects whose input / output we want to retrieve.
Returns:
Tuple[str, str]
Input & Output strings.
"""
if isinstance(example, GenerationTrainRecord):
return example.input, example.output
# TODO: probably a good idea to add some error handling here;
# For now, we don't since situations in which we call this
# internally should generally enforce this as true, e.g.,
# hf datasets created out of data model objects.
return example["input"], example["output"]

@classmethod
def build_task_tokenize_closure(cls, *args, **kwargs) -> Tuple[Callable, bool]:
"""Builds tokenizer closure which can be mapped over train streams to process
data which can then be easily passed to a DataLoader for different model types.
This is largely for convenience if we want a closure that can be applied
without having to carry around other parameters.
"""

def tokenize_wrapper(example: GenerationTrainRecord):
return cls.tokenize_function(example, *args, **kwargs)

return (tokenize_wrapper, cls.REQUIRES_TOKEN_UNWRAPPING)

@classmethod
@abstractmethod
def build_task_tokenize_function(
def tokenize_function(
cls,
example: Union[GenerationTrainRecord, Mapping],
tokenizer: "AutoTokenizer",
max_source_length: int,
max_target_length: int,
verbalizer: str,
verbalizer: Union[None, str] = None,
task_ids: Union[None, int] = None,
) -> Tuple[Callable, bool]:
"""Builds tokenizer functions which can be mapped over train streams to process
data which can then be easily passed to a DataLoader for different model types.
) -> Union["BatchEncoding", DataStream["BatchEncoding"]]:
"""Tokenizes a generation training record.
Args:
example: Union[GenerationTrainRecord, Mapping]
Example data model object / mapping to be tokenized.
tokenizer: AutoTokenizer
Model tokenizer to be used in preprocessing, i.e., when we iterate over our data.
max_source_length: int
@@ -325,14 +374,13 @@ def build_task_tokenize_function(
verbalizer: Union[None, str]
Verbalizer template to be used for formatting data. This template may use brackets
to indicate where fields from the data model GenerationTrainRecord must be rendered.
If no verbalizer is provided, the source text is used as the rendered result.
task_ids: Union[None, int]
Task id corresponding to a particular task for multi-task prompt tuning.
NOTE: Only required for MPT (Multi-task prompt tuning)
Default: None
Returns:
Tuple(Callable, bool)
Mappable tokenize function to be applied to a training stream and bool indicating
whether or not the stream needs to be unwrapped, i.e., each sample yields a stream
of 1+ samples.
BatchEncoding | DataStream[BatchEncoding]
encoded tokenization output corresponding to the input example.
"""
160 changes: 93 additions & 67 deletions caikit_nlp/resources/pretrained_model/hf_auto_causal_lm.py
@@ -15,11 +16,16 @@
Huggingface auto causal LM resource type
"""
# Standard
from collections.abc import Mapping
from copy import deepcopy
from typing import Callable, Tuple, Union
from typing import Union

# Third Party
from transformers import AutoModelForCausalLM, DataCollatorForLanguageModeling
from transformers import (
AutoModelForCausalLM,
BatchEncoding,
DataCollatorForLanguageModeling,
)
from transformers.models.auto import modeling_auto

# First Party
@@ -45,66 +50,60 @@ class HFAutoCausalLM(PretrainedModelBase):
TASK_TYPE = "CAUSAL_LM"
PROMPT_OUTPUT_TYPES = [PromptOutputModelType.DECODER]
MAX_NUM_TRANSFORMERS = 1
REQUIRES_TOKEN_UNWRAPPING = True

@staticmethod
def build_task_tokenize_function(
@classmethod
def tokenize_function(
cls,
example: Union[GenerationTrainRecord, Mapping],
tokenizer: "AutoTokenizer",
max_source_length: int,
max_target_length: int,
verbalizer: str,
verbalizer: Union[None, str] = None,
task_ids: Union[None, int] = None,
) -> Tuple[Callable, bool]:
"""Builds tokenizer functions which can be mapped over train streams to process
data which can then be easily passed to a DataLoader for CausalLM models.
) -> DataStream["BatchEncoding"]:
"""Tokenization function to be used for causallm training; this function consumes a
GenerationTrainRecord object and applies the verbalizer to it followed by
the model tokenizer. Due to the nature of our training data with src/target seqs,
each sample yields one example per token in the target sequence.
Args:
tokenizer: AutoTokenizer
Model tokenizer to be used in preprocessing, i.e., when we iterate over our data.
max_source_length: int
Max length of sequences being considered.
max_target_length: int
Max length of target sequences being predicted.
verbalizer: str
Verbalizer template to be used for formatting data. This template may use brackets
to indicate where fields from the data model TrainGenerationRecord must be rendered.
task_ids: Union[None, int]
Task id corresponding particular task for multi-task prompt tuning.
NOTE: Only required for MPT (Multi-task prompt tuning)
Default: None
example: GenerationTrainRecord | Mapping
Training data model object to convert to a form we can learn on, or a Mapping
that has keys input/output.
Returns:
Tuple(Callable, bool)
Mappable tokenize function to be applied to a training stream and bool indicating
whether or not the stream needs to be unwrapped, i.e., each sample yields a stream
of 1+ samples.
DataStream[transformers.tokenization_utils_base.BatchEncoding]
stream of encoded tokenization output corresponding to the input example.
"""

def tokenize_function_language_model(
example: GenerationTrainRecord,
) -> "BatchEncoding":
"""Tokenization function to be used for causallm training; this function consumes a
GenerationTrainRecord object and applies the verbalizer to it followed by
the model tokenizer. Due to the nature of our training data with src/target seqs,
each sample yields one example per token in the target sequence.
Args:
example: GenerationTrainRecord
Training data model object to convert a form we can learn on.
Returns:
transformers.tokenization_utils_base.BatchEncoding
encoded tokenization output corresponding to the input example.
"""

# Render the verbalizer template with the attributes of this data model example
source = render_verbalizer(verbalizer, example)

source_ids = tokenizer(
source, max_length=max_source_length, truncation=True
)
target_ids = tokenizer(
example.output, max_length=max_target_length, truncation=True
# Extract the source & target from our provided inputs
source, target = cls.decompose_example_io(example)
# Determine if our mapped inputs are in batched mode or not
batched_mode = isinstance(source, list) and isinstance(target, list)

# TODO: Handle batched verbalizer stuff!
if batched_mode and verbalizer is not None:
raise NotImplementedError(
"Verbalizer rendering not implemented for batch mode"
)
source = (
source if verbalizer is None else render_verbalizer(verbalizer, example)
)

# HACK: We shouldn't have to pad here, but the causal LM data collator dynamic padding
# does not appear to be playing nicely with the Huggingface trainer / torch fsdp...
source_ids = tokenizer(source, max_length=max_source_length, truncation=True)
target_ids = tokenizer(target, max_length=max_target_length, truncation=True)
if batched_mode:
num_target_samples = []
for idx, _ in enumerate(source_ids.input_ids):
source_ids["input_ids"][idx] = (
source_ids.input_ids[idx] + target_ids.input_ids[idx]
)
num_target_samples.append(len(target_ids.input_ids[idx]))
if task_ids is not None:
source_ids["task_ids"][idx] = task_ids
else:
source_ids["input_ids"] = source_ids.input_ids + target_ids.input_ids
# Here, we need to yield and manipulate the attention mask to attend
# to the input seq + the tokens we have seen so far...
@@ -113,22 +112,49 @@ def tokenize_function_language_model(
if task_ids is not None:
source_ids["task_ids"] = task_ids

def generator_func():
for idx in range(num_target_samples):
# This may not actually be needed, but for now we do it, since the underlying
# data may be referenced in multiple places, and the data will be dynamically
# padded by the LM collator
s = deepcopy(source_ids)
s["attention_mask"] = (
s["attention_mask"]
+ [1] * (idx + 1)
+ [0] * (num_target_samples - idx - 1)
)
yield s

return DataStream(generator_func)

return (tokenize_function_language_model, True)
# This is disgusting! TODO:
# - Consolidate batched [generator] vs. non-batched behavior [batch encoded lists]
# - Make all attention mask logic common, etc.
def single_generator_func():
for idx in range(num_target_samples):
# This may not actually be needed, but for now we do it, since the underlying
# data may be referenced in multiple places, and the data will be dynamically
# padded by the LM collator
s = deepcopy(source_ids)
s["attention_mask"] = (
s["attention_mask"]
+ [1] * (idx + 1)
+ [0] * (num_target_samples - idx - 1)
)
yield s

def get_batched_output():
# Initialize the batch encoding key lists as empty
batch_encoding = BatchEncoding()
for k in source_ids:
batch_encoding[k] = []
# Flatten the batch and add everything individually...
for batch_idx in range(len(source_ids.input_ids)):
# Consider every output text for this entry in the batch
for idx in range(num_target_samples[batch_idx]):
# Create the batch encoding dict directly and populate the keys
# from the corresponding entry inside of the batch...
for key in source_ids:
if key != "attention_mask":
batch_encoding[key].append(source_ids[key][batch_idx])
else:
# Handle the attention mask for this entry...
attn_mask = (
source_ids["attention_mask"][batch_idx]
+ [1] * (idx + 1)
+ [0] * (num_target_samples[batch_idx] - idx - 1)
)
batch_encoding["attention_mask"].append(attn_mask)
return batch_encoding

if batched_mode:
return get_batched_output()
return DataStream(single_generator_func)
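
A toy walkthrough of the mask expansion above, assuming a 2-token source and a 3-token target: input_ids always hold the full source + target, while each yielded encoding's attention mask reveals one more target token.

# Toy illustration of the attention-mask expansion in single_generator_func;
# token counts here are assumptions for the example only.
source_mask = [1, 1]          # attention over the 2 source tokens
num_target_samples = 3        # 3 target tokens -> 3 yielded encodings
masks = [
    source_mask + [1] * (idx + 1) + [0] * (num_target_samples - idx - 1)
    for idx in range(num_target_samples)
]
assert masks == [
    [1, 1, 1, 0, 0],
    [1, 1, 1, 1, 0],
    [1, 1, 1, 1, 1],
]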

def _get_data_collator(self, **kwargs):
"""Function to return appropriate data collator based on resource.