Batched Dataloader #26957
An implementation of:

```python
import abc
import math
from collections import namedtuple
from itertools import islice

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler


# BatchedSample and BatchedData below are wrappers that we use to enforce
# various invariants throughout the code and to shield against unwanted
# PyTorch behaviors. We want to smuggle a batch of values as a single value
# and remain in control of unwrapping. If we were passing "naked"
# collections, we would risk stepping into the logic in PyTorch's
# DataLoader class that handles batch sampling but fetches values from a
# dataset one at a time. The end result of unwrapping too early (by the
# DataLoader class) is a crash due to tensor dimensions not lining up.
# Yes, this is a delicate point, but we avoid all the traps by hiding the
# fact that we're passing batches of indices/values by wrapping them.
BatchedSample = namedtuple('BatchedSample', ['payload'])
BatchedData = namedtuple('BatchedData', ['payload'])


class BatchedDataset(Dataset, metaclass=abc.ABCMeta):
    """
    A specialized variant of a Dataset that expects a batch of indices
    passed in instead of a single numerical index.

    It has one abstract method `_get_batch` that receives unwrapped indices
    and is expected to be implemented by concrete subclasses of this class
    that return specific data in batches.
    """

    def __init__(self, batch_size, unbatched_length):
        self.batch_size = batch_size
        n_frac = unbatched_length / float(batch_size)
        self.batched_length = int(math.ceil(n_frac))

    def __len__(self):
        return self.batched_length

    def __getitem__(self, batched_indices):
        if not isinstance(batched_indices, BatchedSample):
            raise TypeError('Parameter `batched_indices` must be of type '
                            'BatchedSample, got %s' % type(batched_indices))
        raw_indices = batched_indices.payload
        batch = self._get_batch(raw_indices)
        return BatchedData(batch)

    @abc.abstractmethod
    def _get_batch(self, indices):
        pass


def split_every(n, iterable):
    i = iter(iterable)
    piece = list(islice(i, n))
    while piece:
        yield piece
        piece = list(islice(i, n))


class BatchingSamplerWrapper(Sampler):
    """
    A class that takes a regular data sampler and turns it into one that
    returns sampled indices in batches. This class is useful for testing
    and quick prototyping.

    For a real sampler implementation, you should implement batched
    sampling directly to avoid creating a performance bottleneck in the
    sampling codepath.
    """

    def __init__(self, orig_sampler, batch_size):
        self.orig_sampler = orig_sampler
        self.batch_size = batch_size

    def __iter__(self):
        raw_batched_samples = split_every(self.batch_size, iter(self.orig_sampler))
        for batch in raw_batched_samples:
            yield BatchedSample(batch)

    def __len__(self):
        n_frac = len(self.orig_sampler) / float(self.batch_size)
        return int(math.ceil(n_frac))


def dataloader_for_batched_dataset(batched_dataset, pin_memory=True, sampler=None,
                                   batch_sampler=None):
    """
    This function ties three pieces together:

    1. A sampler that returns batches of indices wrapped in BatchedSample
    2. A BatchedDataset that unwraps BatchedSample, fetches the requested
       pieces of data and returns them wrapped in BatchedData
    3. The `unwrap_batch_collate_fn` collate function (defined below) that
       unwraps the data from BatchedData and returns it

    The process is orchestrated by PyTorch's original DataLoader class.
    As a result, we smuggle batches of data as single values through
    PyTorch APIs that were designed for single values instead of batches.

    Parameters
    ----------
    batched_dataset: BatchedDataset
    pin_memory: boolean, default True
        Parameter passed to torch.utils.data.DataLoader
    sampler: Sampler
        An instance of Sampler (regular, unbatched)
    batch_sampler: Sampler
        An instance of Sampler that returns batches of indices wrapped
        in BatchedSample
    """
    def unwrap_batch_collate_fn(wrapped_batch):
        # gkk: where is my pattern matching?? :sob:
        if not isinstance(wrapped_batch, list):
            raise TypeError('Parameter `wrapped_batch` must be of type '
                            'list, got %s' % type(wrapped_batch))
        if len(wrapped_batch) != 1:
            raise ValueError('`wrapped_batch` should have length 1 but has '
                             '%d' % len(wrapped_batch))
        wrapped_batch = wrapped_batch[0]
        if not isinstance(wrapped_batch, BatchedData):
            raise TypeError('Parameter `wrapped_batch` must be of type '
                            'BatchedData, got %s' % type(wrapped_batch))
        return wrapped_batch.payload

    if not isinstance(batched_dataset, BatchedDataset):
        raise TypeError('Parameter `batched_dataset` must be of type '
                        'BatchedDataset, got %s' % type(batched_dataset))
    assert sampler or batch_sampler, 'Either `sampler` or `batch_sampler` must be specified'
    assert not (sampler and batch_sampler), \
        "sampler and batch_sampler can't be specified at the same time"

    batch_size = batched_dataset.batch_size
    # not sure if `downstream` is the best word but couldn't think of a better one
    if sampler:
        downstream_batch_sampler = BatchingSamplerWrapper(sampler, batch_size)
    else:
        downstream_batch_sampler = batch_sampler
    return DataLoader(batched_dataset, sampler=downstream_batch_sampler,
                      collate_fn=unwrap_batch_collate_fn, pin_memory=pin_memory)
```

A simple unit test for the above:

```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SequentialSampler


class ConstDataset(Dataset):
    """Serves the passed data as is."""

    def __init__(self, data):
        super(ConstDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def test_batching_wrappers():
    """
    Checks that the triple of batched dataset, batched sampler and the
    logic in dataloader_for_batched_dataset that ties them together ends
    up loading the same data as regular (unbatched) data loading.
    """
    batch_size = 3
    descending = list(reversed(range(10)))
    # wrap map() in list() so the dataset is indexable under Python 3
    dataset = ConstDataset(list(map(torch.tensor, descending)))

    class ConstBatchDataset(BatchedDataset):
        def __init__(self, data, batch_size):
            unbatched_length = len(data)
            super(ConstBatchDataset, self).__init__(batch_size, unbatched_length)
            self.data = data

        def _get_batch(self, indices):
            print('indices are %s' % indices)
            batch = torch.take(self.data, torch.tensor(indices))
            print('batch is %s' % batch)
            return batch

    batch_dataset = ConstBatchDataset(torch.tensor(descending), batch_size)
    sampler = SequentialSampler(dataset)
    loader = dataloader_for_batched_dataset(batch_dataset, sampler=sampler)
    sampled_data = list(iter(loader))

    regular_loader = DataLoader(dataset, batch_size=batch_size)
    sampled_data2 = list(iter(regular_loader))
    print(sampled_data)
    print(sampled_data2)
    # PyTorch doesn't support standard equality comparison for tensors, so
    # we iterate over the elements of both lists and call torch.equal on
    # each pair
    assert len(sampled_data) == len(sampled_data2)
    for x, y in zip(sampled_data, sampled_data2):
        assert torch.equal(x, y)
```
Actually batch loading is supported if you set …
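A minimal sketch of what this looks like in current PyTorch (torch >= 1.2), assuming the suggestion is to disable automatic batching with `batch_size=None` and let the sampler yield whole batches of indices; the dataset class here is my own illustration:

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import BatchSampler, SequentialSampler


class RangeDataset(Dataset):
    """Hypothetical example dataset whose __getitem__ takes a list of indices."""

    def __init__(self, n):
        self.data = torch.arange(n)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indices):
        # With automatic batching disabled, whatever the sampler yields is
        # passed through unchanged, so a whole list of indices arrives here.
        return self.data[torch.tensor(indices)]


ds = RangeDataset(10)
batch_sampler = BatchSampler(SequentialSampler(ds), batch_size=3, drop_last=False)
# batch_size=None disables automatic batching: each list of indices coming
# out of the sampler is handed to __getitem__ as-is.
loader = DataLoader(ds, sampler=batch_sampler, batch_size=None)
for batch in loader:
    print(batch)  # tensor([0, 1, 2]), tensor([3, 4, 5]), ...
```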
@ssnl yes, the IterableDataset helps, thanks for implementing it. I ran into this issue around the PyTorch 1.0 release, when this abstraction didn't exist, and only now have I gotten around to filing a ticket. After filing the ticket I saw your PR and dropped this comment:
Do you see "falling off the utilities cliff" and being on your own when switching to IterableDataset as a problem?
Sorry for not being clear. I actually did not mean …
@ssnl, ah, I misunderstood you. I wrongly assumed you were talking about #19228 and …
@gkossakowski - were your issues resolved? I'm currently evaluating whether we need wider support for more complex data pipelines via some set of new low-level abstractions.
@ssnl, how would I disable automatic batching but use a …
Use-case: I'm using a …
Using the ideas described by @ssnl above, I managed to put together a batched dataloader very simply; no custom classes or "smuggling" of values required. Hopefully this is useful to those who end up here looking for a ready-made solution. https://gist.github.com/GCBallesteros/05ca61456f7cf0a319d5736c553cde19
Disclaimer: Considering how simple this is compared to what is proposed at the top, I may be missing something fundamental (please point it out if that is the case), but it's working for my needs 🤷🏽‍♂️. Please test more thoroughly than my two asserts before putting it into prod.
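The gist is not reproduced here, but judging from the comments it presumably boils down to a small helper along these lines (the helper's name and signature are my own illustration, not the gist's actual code):

```python
from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler


def batched_loader(dataset, batch_size, shuffle=False):
    # Hypothetical helper: relies on dataset.__getitem__ accepting a whole
    # list of indices, as in the sketch earlier in the thread.
    base = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)
    batch_sampler = BatchSampler(base, batch_size=batch_size, drop_last=False)
    return DataLoader(dataset, sampler=batch_sampler, batch_size=None)
```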
@GCBallesteros without looking too closely at your code, I think you're benefiting from API extensions added after this issue was filed. The fact that this is now easy to do, and that the solution composes with the rest of PyTorch's data utilities, shows this issue has been addressed and can be closed.
🚀 Feature
Add a mode to `Dataset` that enables fetching data in batches instead of item-by-item.

Motivation
If model training takes relatively small individual examples as input, as in the case of training on tabular data, the Python interpreter overhead of fetching data becomes so large that it hinders training performance. In other words, training becomes CPU-bound (even with multiprocessing enabled).
This came up in a real scenario with the StarSpace model from FAIR.

Pitch
Add an optional `__getbatch__` method to the `Dataset` that's analogous to `__getitem__` but takes a collection of indices as an input. Make the `Dataloader` aware of `BatchedDataset`. Once the `Dataloader` recognizes that `__getbatch__` is present, that method is used for fetching data, one batch at a time. As a result, the user gains the ability to pass data in batches end-to-end and avoids the high (per byte read) cost of the Python interpreter.
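To make the pitch concrete, here is a sketch of what the proposed protocol could look like from the user's side; the `__getbatch__` name comes from the pitch above, while the dataset class and the dispatch pseudocode in the comments are only illustrations, not an existing PyTorch API:

```python
import torch
from torch.utils.data import Dataset


class TabularDataset(Dataset):
    """Hypothetical dataset implementing the proposed batched protocol."""

    def __init__(self, rows):
        self.rows = rows  # e.g. a 2-D tensor of tabular examples

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]

    def __getbatch__(self, indices):
        # One vectorized fetch instead of len(indices) Python-level calls.
        return self.rows[torch.tensor(indices)]


# A protocol-aware DataLoader would conceptually do:
#     if hasattr(dataset, '__getbatch__'):
#         batch = dataset.__getbatch__(indices)
#     else:
#         batch = collate_fn([dataset[i] for i in indices])
```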
I implemented a variant of batch loading for the aforementioned StarSpace model and got training down from 5.5 days to under 24 hours. The person who originally implemented it used the standard PyTorch data loading abstractions and fell into the low-performance trap.
This is the type of issue anybody working on e.g. tabular data will run into. Unfortunately, there's no natural way out given the current PyTorch abstractions.
Alternatives
Implement this on top of the existing abstractions by "smuggling" batches of values wrapped as a single value and unwrapping them in a custom collate function. The code I provided above is fairly subtle and a bit hacky (it abuses the current abstractions). It is fully functional and used in production, though.
Edit: I also found this: #19228, which is a different way of implementing what I need. The downside of IterableDataset is that it essentially throws the nice decomposition into Dataset, Sampler and Dataloader out the window. Suddenly, you're responsible for implementing all of that logic. Having said that, it is a big improvement over the rather hacky solution I posted above.
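For reference, a minimal sketch of the IterableDataset route (assumes torch >= 1.2); ordering and batching become the dataset's own responsibility, which is the "utilities cliff" mentioned above:

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class BatchedIterable(IterableDataset):
    """Hypothetical iterable dataset that yields ready-made batches."""

    def __init__(self, data, batch_size):
        self.data = data  # e.g. a tensor of examples
        self.batch_size = batch_size

    def __iter__(self):
        # No Sampler involved: the dataset itself decides ordering and
        # batching. With num_workers > 0, each worker would replicate this
        # stream unless you also shard using get_worker_info().
        for start in range(0, len(self.data), self.batch_size):
            yield self.data[start:start + self.batch_size]


loader = DataLoader(BatchedIterable(torch.arange(10), 3), batch_size=None)
# yields tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]), tensor([9])
```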
cc @ssnl