[Train] Ensure local HF Datasets are split (ray-project#34581)
Signed-off-by: Antoni Baum <[email protected]>
Yard1 committed Apr 25, 2023
1 parent fb58279 commit b04154c
Showing 2 changed files with 18 additions and 10 deletions.
15 changes: 11 additions & 4 deletions python/ray/train/huggingface/_huggingface_utils.py
@@ -44,7 +44,7 @@ def get_train_dataloader(self):
         data_loader = super().get_train_dataloader()
         if isinstance(
             data_loader.dataset, transformers.trainer.IterableDatasetShard
-        ):
+        ) and getattr(data_loader.dataset.dataset, "_do_not_split", False):
             # Default Trainer.get_train_dataloader will wrap the dataset in
             # IterableDatasetShard, which will perform additional sharding on top
             # of the already sharded dataset. By setting those two attributes,
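
The hunk is collapsed before the two attribute assignments its comment refers to. A minimal sketch of the full wrapped method, assuming (an assumption, since the hunk is cut off) that the two attributes are `IterableDatasetShard`'s `num_processes` and `process_index`:

# Hedged sketch, not the repository's verbatim code: a plausible full body
# of the wrapped method produced by wrap_transformers_trainer after this
# change. The attribute names num_processes and process_index are
# assumptions; the hunk above ends before the actual assignments.
import transformers


class _WrappedTrainer(transformers.Trainer):
    def get_train_dataloader(self):
        data_loader = super().get_train_dataloader()
        if isinstance(
            data_loader.dataset, transformers.trainer.IterableDatasetShard
        ) and getattr(data_loader.dataset.dataset, "_do_not_split", False):
            # The inner dataset was already sharded by Ray, so turn the
            # IterableDatasetShard wrapper into a no-op rather than letting
            # it drop rows a second time.
            data_loader.dataset.num_processes = 1
            data_loader.dataset.process_index = 0
        return data_loader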
@@ -75,8 +75,10 @@ def __iter__(self):
         yield (0, {k: v for k, v in row.as_pydict().items()})


-def process_dataset_for_hf(dataset: DataIterator) -> "IterableDataset":
-    """Converts a Datastream into a HF IterableDataset."""
+def process_dataset_for_hf(
+    dataset: DataIterator, disable_transformers_splitting: bool = False
+) -> "IterableDataset":
+    """Converts a Ray Dataset into a HF IterableDataset."""
     hf_iterable = RayDatasetHFIterable(dataset)

     iterable_dataset = datasets.iterable_dataset.IterableDataset(
@@ -90,6 +92,9 @@ def process_dataset_for_hf(dataset: DataIterator) -> "IterableDataset":
         dataset_length = None

     iterable_dataset = maybe_add_length(iterable_dataset, dataset_length)
+    # Trigger logic in `wrap_transformers_trainer` to disable built-in
+    # HuggingFace splitting, as we have already split the dataset ourselves.
+    iterable_dataset._do_not_split = disable_transformers_splitting
     return iterable_dataset


@@ -99,7 +104,9 @@ def process_datasets(
 ) -> Tuple["IterableDataset", "IterableDataset"]:
     """Convert Ray train and validation to HF IterableDatasets."""
     if train_dataset:
-        train_torch_dataset = process_dataset_for_hf(train_dataset)
+        train_torch_dataset = process_dataset_for_hf(
+            train_dataset, disable_transformers_splitting=True
+        )
     else:
         train_torch_dataset = None
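
Taken together with the previous hunks, only the train split gets the marker. A hedged illustration of the flag's flow (`train_iter` and `eval_iter` are hypothetical Ray DataIterators, and the validation branch is assumed to use the default flag since it is not visible in this diff):

# Hypothetical usage sketch of process_datasets after this change;
# train_iter and eval_iter stand in for Ray DataIterators.
train_hf, eval_hf = process_datasets(train_iter, eval_iter)

# The train split is marked, so the wrapped get_train_dataloader skips
# the extra IterableDatasetShard sharding on top of Ray's own split.
assert train_hf._do_not_split is True

# The validation split keeps the default False, so Transformers'
# built-in splitting still applies to it.
assert eval_hf._do_not_split is False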

13 changes: 7 additions & 6 deletions python/ray/train/huggingface/huggingface_trainer.py
@@ -92,17 +92,18 @@ class HuggingFaceTrainer(TorchTrainer):
         shards, with each Actor training on a single shard.
         All the other datasets will not be split.
+        Please note that if you use a custom ``transformers.Trainer`` subclass,
+        the ``get_train_dataloader`` method will be wrapped around to disable
+        sharding by ``transformers.IterableDatasetShard``, as the dataset will
+        already be sharded on the Ray AIR side.
         You can also provide ``datasets.Dataset`` object or other dataset objects
         allowed by ``transformers.Trainer`` directly in the ``trainer_init_per_worker``
         function, without specifying the ``datasets`` dict. It is recommended to initialize
         those objects inside the function, as otherwise they will be serialized and passed
         to the function, which may lead to long runtime and memory issues with large
-        amounts of data.
-        Please note that if you use a custom ``transformers.Trainer`` subclass,
-        the ``get_train_dataloader`` method will be wrapped around to disable
-        sharding by ``transformers.IterableDatasetShard``, as the dataset will
-        already be sharded on the Ray AIR side.
+        amounts of data. In this case, the training dataset will be split
+        automatically by Transformers.
         HuggingFace loggers will be automatically disabled, and the ``local_rank``
         argument in ``TrainingArguments`` will be automatically set. Please note
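
To make the documented split behavior concrete, a hedged end-to-end sketch; the checkpoint, the pre-tokenized placeholder rows, and the hyperparameters are illustrative assumptions, not taken from this commit:

# Hedged usage sketch. Datasets passed via the `datasets` dict are split by
# Ray AIR (and marked _do_not_split); a datasets.Dataset created inside
# trainer_init_per_worker is instead split by Transformers itself, which is
# the behavior this commit ensures.
import ray.data
import transformers
from ray.air.config import ScalingConfig
from ray.train.huggingface import HuggingFaceTrainer


def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):
    model = transformers.AutoModelForSequenceClassification.from_pretrained(
        "distilbert-base-uncased"  # placeholder checkpoint
    )
    args = transformers.TrainingArguments(
        output_dir="/tmp/hf_out", max_steps=10, per_device_train_batch_size=2
    )
    return transformers.Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )


trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    scaling_config=ScalingConfig(num_workers=2),
    # The "train" entry is split into one shard per worker; placeholder
    # pre-tokenized rows keep the sketch self-contained.
    datasets={
        "train": ray.data.from_items(
            [{"input_ids": [101, 102], "attention_mask": [1, 1], "labels": 0}] * 16
        )
    },
)
result = trainer.fit()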
