
Batched Dataloader #26957

Open
gkossakowski opened this issue Sep 27, 2019 · 9 comments

gkossakowski commented Sep 27, 2019

🚀 Feature

Add a mode to Dataset that enables fetching data in batches instead of item-by-item.

Motivation

If model training takes relatively small individual examples as input, as in training on tabular data, the Python interpreter overhead of fetching data becomes so large that it hinders training performance. In other words, training becomes CPU-bound (even with multiprocessing enabled).

This came up in a real scenario of the StarSpace model from FAIR.

Pitch

Add an optional __getbatch__ method to Dataset that's analogous to __getitem__ but takes a collection of indices as input. Make the DataLoader aware of such a BatchedDataset: once the DataLoader recognizes that __getbatch__ is present, it uses that method to fetch data one batch at a time.

As a result, the user gains the ability to pass data in batches end to end and avoid the high per-byte cost of the Python interpreter.

I implemented a variant of batch loading for the aforementioned StarSpace model and got training down from 5.5 days to under 24 hours. The person who originally implemented it used the standard PyTorch data loading abstractions and fell into the trap of low performance.

This is the type of issue that anybody working on, e.g., tabular data will run into. Unfortunately, there's no natural way out given the current PyTorch abstractions.
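
To make the pitch concrete, here is a rough sketch of what a dataset using the proposed hook could look like. This is only an illustration of the proposal: __getbatch__ is the hook being proposed here, not an existing PyTorch API, and the class and attribute names are made up.

import torch
from torch.utils.data import Dataset

class TabularDataset(Dataset):
    """Hypothetical dataset illustrating the proposed batched protocol."""

    def __init__(self, features, labels):
        self.features = features  # e.g. one large feature tensor
        self.labels = labels      # e.g. one large label tensor

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # existing per-item protocol, kept for backwards compatibility
        return self.features[idx], self.labels[idx]

    def __getbatch__(self, indices):
        # proposed: fetch the whole batch with one vectorized read,
        # avoiding a Python-level call per example
        idx = torch.as_tensor(indices)
        return self.features[idx], self.labels[idx]

A DataLoader aware of this protocol would then call __getbatch__ once per batch instead of making len(indices) separate __getitem__ calls.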

Alternatives

Implement this on top of the existing abstractions by "smuggling" batches of values wrapped as a single value and unwrapping them in a custom collate function. The code I provide below is fairly subtle and a bit hacky (it abuses the current abstractions). It is fully functional and used in production, though.

Edit: I also found #19228, which is a different way of implementing what I need. The downside of IterableDataset is that it essentially throws the nice decomposition into Dataset, Sampler, and DataLoader out the window. Suddenly, you're responsible for implementing all of the logic yourself. Having said that, it's a big improvement over the rather hacky solution I posted below.

cc @ssnl

gkossakowski commented Sep 27, 2019

An implementation of BatchedDataset and the surrounding classes that smuggle batches through regular PyTorch abstractions:

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
from collections import namedtuple
from itertools import islice
import abc
import math

# BatchedSample and BatchedData below are wrappers that we use to enforce
# various invariants throughout the code and to shield against unwanted
# PyTorch behaviors. We want to smuggle a batch of values as a single value
# and remain in control of unwrapping. If we passed "naked" collections, we
# would risk stepping into the logic in PyTorch's DataLoader class that
# handles batch sampling but singleton fetching of values from a dataset.
# The end result of unwrapping too early (by the DataLoader class) is a
# crash due to tensor dimensions not lining up. Yes, this is a delicate
# point, but we avoid all the traps by wrapping the batches of
# indices/values, hiding the fact that we're passing batches at all.

BatchedSample = namedtuple('BatchedSample', ['payload'])

BatchedData = namedtuple('BatchedData', ['payload'])


class BatchedDataset(Dataset, abc.ABC):
    """
    A specialized variant of a Dataset that expects a batch of indices to be
    passed in instead of a single numerical index.

    It has one abstract method `_get_batch` that receives the unwrapped
    indices and is expected to be implemented by concrete subclasses of this
    class that return their specific data in batches.
    """

    def __init__(self, batch_size, unbatched_length):
        self.batch_size = batch_size
        n_frac = unbatched_length / float(batch_size)
        batched_length = int(math.ceil(n_frac))
        self.batched_length = batched_length

    def __len__(self):
        return self.batched_length

    def __getitem__(self, batched_indices):
        if not isinstance(batched_indices, BatchedSample):
            raise TypeError('Parameter `batched_indices` must be of type '
                            'BatchedSample, got %s' % (type(batched_indices)))
        raw_indices = batched_indices.payload
        batch = self._get_batch(raw_indices)
        return BatchedData(batch)

    @abc.abstractmethod
    def _get_batch(self, indices):
        pass


def split_every(n, iterable):
    """Yield successive chunks of `n` items from `iterable`; the last chunk may be shorter."""
    i = iter(iterable)
    piece = list(islice(i, n))
    while piece:
        yield piece
        piece = list(islice(i, n))


class BatchingSamplerWrapper(Sampler):
    """
    A class that takes a regular data sampler and turns it into
    one that returns sampled indices in batches. This class is
    useful for testing and quick prototyping.

    For the actual sampler implementation, you should implement
    batched sampling directly to avoid creating a performance
    bottleneck in sampling codepath.
    """

    def __init__(self, orig_sampler, batch_size):
        self.orig_sampler = orig_sampler
        self.batch_size = batch_size

    def __iter__(self):
        raw_batched_samples = split_every(self.batch_size, iter(self.orig_sampler))
        for batch in raw_batched_samples:
            yield BatchedSample(batch)

    def __len__(self):
        n_frac = len(self.orig_sampler) / float(self.batch_size)
        return int(math.ceil(n_frac))


def dataloader_for_batched_dataset(batched_dataset, pin_memory=True, sampler=None,
                                   batch_sampler=None):
    """
    This method ties three pieces together:
      1. Sampler that is returning a batch of indices wrapped in BatchedSample
      2. BatchedDataset that unwraps BatchedSample, fetches requested pieces of
         data and returns them wrapped in BatchedData
      3. `unwrap_batch_collate_fn` collate function (defined below) that unwraps
         data from `BatchedData` and returns it

    This process is orchestrated by PyTorch's original Dataloader class.
    As a result, we smuggle batches of data as single values through PyTorch's
    APIs that are designed for single values instead of batches.

    Parameters
    ----------
    batched_dataset: BatchedDataset
    pin_memory: boolean, default True
        Parameter passed to torch.utils.data.DataLoader
    sampler: Sampler
        An instance of Sampler (regular, unbatched)
    batch_sampler: Sampler
        An instance of Sampler that returns batches of indices wrapped in BatchedSample
    """

    def unwrap_batch_collate_fn(wrapped_batch):
        # gkk: where is my pattern matching?? :sob:
        if not isinstance(wrapped_batch, list):
            raise TypeError('Parameter `wrapped_batch` must be of type '
                            'list, got %s' % (type(wrapped_batch)))
        if len(wrapped_batch) != 1:
            raise ValueError('`wrapped_batch` should have length 1 but has '
                             '%d' % (len(wrapped_batch)))
        wrapped_batch = wrapped_batch[0]
        if not isinstance(wrapped_batch, BatchedData):
            raise TypeError('Parameter `wrapped_batch` must be of type '
                            'BatchedData, got %s' % (type(wrapped_batch)))
        collated = wrapped_batch.payload
        return collated

    if not isinstance(batched_dataset, BatchedDataset):
        raise TypeError('Parameter `batched_dataset` must be of type '
                        'BatchedDataset, got %s' % (type(batched_dataset)))

    assert sampler or batch_sampler, 'Either `sampler` or `batch_sampler` must be specified'
    assert not (sampler and batch_sampler), \
        "sampler and batch_sampler can't be specified at the same time"

    batch_size = batched_dataset.batch_size
    # not sure if `downstream` is the best word but couldn't think of a better one
    if sampler:
        downstream_batch_sampler = BatchingSamplerWrapper(sampler, batch_size)
    else:
        downstream_batch_sampler = batch_sampler

    loader = DataLoader(batched_dataset, sampler=downstream_batch_sampler,
                        collate_fn=unwrap_batch_collate_fn, pin_memory=pin_memory)
    return loader

A simple unit test for the above:

import torch
from torch.utils.data.sampler import SequentialSampler

class ConstDataset(Dataset):
    """Serves passed data as is."""

    def __init__(self, data):
        super(ConstDataset, self).__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def test_batching_wrappers():
    """
    Checks that the triple of batched dataset, batched sampler, and the logic
    in dataloader_for_batched_dataset that ties them together ends up loading
    the same data as regular (unbatched) data loading.
    """
    batch_size = 3
    descending = list(reversed(range(10)))
    dataset = ConstDataset(list(map(torch.tensor, descending)))

    class ConstBatchDataset(BatchedDataset):
        def __init__(self, data, batch_size):
            unbatched_length = len(data)
            super(ConstBatchDataset, self).__init__(batch_size, unbatched_length)
            self.data = data

        def _get_batch(self, indices):
            print('indices are %s' % indices)
            batch = torch.take(self.data, torch.tensor(indices))
            print('batch is %s' % batch)
            return batch

    batch_dataset = ConstBatchDataset(torch.tensor(descending), batch_size)

    sampler = SequentialSampler(dataset)

    loader = dataloader_for_batched_dataset(batch_dataset, sampler=sampler)

    sampled_data = list(iter(loader))

    regular_loader = DataLoader(dataset, batch_size=batch_size)
    sampled_data2 = list(iter(regular_loader))
    print(sampled_data)
    print(sampled_data2)

    # pytorch doesn't support standard comparison for tensors so we have
    # to iterate over elements of a list individually and call torch.equal
    # on each one
    assert len(sampled_data) == len(sampled_data2)
    for x, y in zip(sampled_data, sampled_data2):
        assert torch.equal(x, y)

@pbelevich added the module: dataloader, feature, and triaged labels on Sep 27, 2019

ssnl commented Sep 27, 2019

Actually batch loading is supported if you set batch_size=None. See https://pytorch.org/docs/stable/data.html#disable-automatic-batching for details.

gkossakowski commented Sep 27, 2019

@ssnl yes, IterableDataset helps, thanks for implementing it. I ran into this issue around the PyTorch 1.0 release, when this abstraction didn't exist, and only now got around to filing a ticket.

After filing the ticket I saw your PR and dropped this comment:

Edit: I also found #19228, which is a different way of implementing what I need. The downside of IterableDataset is that it essentially throws the nice decomposition into Dataset, Sampler, and DataLoader out the window. Suddenly, you're responsible for implementing all of the logic yourself. Having said that, it's a big improvement over the rather hacky solution I posted below.

Do you see "falling off the utilities cliff" and being on your own when switching to IterableDataset as a problem?

ssnl commented Oct 1, 2019

Sorry for not being clear. I actually did not mean IterableDataset, but setting batch_size=None and using a sampler that yields a collection of indices at a time. That way, your dataset.__getitem__ will receive a collection of indices, and the collate_fn will only convert NumPy arrays to tensors (it won't collate anything into batches). This behavior seems to align with what you are proposing. The link https://pytorch.org/docs/stable/data.html#disable-automatic-batching has more details.
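
For reference, a minimal sketch of that pattern, assuming a toy dataset whose __getitem__ accepts a list of indices (the class and variable names here are made up for illustration):

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import BatchSampler, RandomSampler

class BatchReadingDataset(Dataset):
    """Toy dataset whose __getitem__ accepts a whole list of indices."""

    def __init__(self, data):
        self.data = data  # a single tensor holding all examples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, indices):
        # receives the entire list yielded by the sampler in one call
        return self.data[torch.as_tensor(indices)]

dataset = BatchReadingDataset(torch.arange(10))

# batch_size=None disables automatic batching: each element yielded by
# `sampler` (here a list of indices coming from BatchSampler) is passed
# straight to dataset.__getitem__, and the default collate_fn only
# converts arrays to tensors instead of stacking items into batches.
loader = DataLoader(
    dataset,
    sampler=BatchSampler(RandomSampler(dataset), batch_size=3, drop_last=False),
    batch_size=None,
)

for batch in loader:
    print(batch)  # each batch is a tensor of up to 3 elements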

@gkossakowski

@ssnl, ah, I misunderstood you. I wrongly assumed you were talking about #19228 and IterableDataset. The batch_size=None plus a sampler yielding collections of indices is an interesting option. Let me check whether that works for me later this week and I'll get back to you. Aside: the docs make this particular combination of batch_size and sampler settings easy to miss.

@cpuhrsch

@gkossakowski - were your issues resolved? I'm currently evaluating whether we need wider support for more complex data pipelines via some set of new low-level abstractions.

@sampepose

@ssnl, how would I disable automatic batching but use a BatchSampler?

Use-case:

  • Infinite stream of data for training, shuffled
  • Training in minibatches of size N
  • Minibatch data is grouped (for example 2 groups for images: portrait and landscape)
  • Need to prefetch an entire minibatch of data from DB backend at once (this is a minibatch that has ALREADY been grouped)

I'm using a BatchSampler now to handle the grouping (portrait vs landscape). It just keeps iterating and caching indices in-memory until it can return a full group. See https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/samplers/grouped_batch_sampler.py
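
For the grouping part, here is a simplified sketch of such a sampler. This is only an illustration, not the detectron2 implementation linked above, and the names are made up:

from torch.utils.data.sampler import Sampler

class SimpleGroupedBatchSampler(Sampler):
    """Buffers indices per group id and yields a batch as soon as any group
    has batch_size items, so every yielded batch is group-homogeneous."""

    def __init__(self, sampler, group_ids, batch_size):
        self.sampler = sampler        # yields individual dataset indices
        self.group_ids = group_ids    # group id for each dataset index, e.g. 0/1
        self.batch_size = batch_size

    def __iter__(self):
        buffers = {}
        for idx in self.sampler:
            group = self.group_ids[idx]
            buffers.setdefault(group, []).append(idx)
            if len(buffers[group]) == self.batch_size:
                yield buffers.pop(group)

A sampler like this could then be passed as sampler= together with batch_size=None (the combination described above), so that the whole pre-grouped list of indices reaches __getitem__ (or the DB prefetch) in a single call.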

@GCBallesteros

Using the ideas described by @ssnl above, I managed to very simply put together a batched dataloader, no custom classes or "smuggling" of values required. Hopefully this is useful to those who end up here looking for a ready-made solution.

https://gist.github.com/GCBallesteros/05ca61456f7cf0a319d5736c553cde19

Disclaimer: considering how simple this is compared to what is proposed at the top, I may be missing something fundamental (please point it out if that is the case), but it's working for my needs 🤷🏽‍♂️. Please test more thoroughly than my two asserts before putting it into prod.

@gkossakowski

@GCBallesteros without looking too closely at your code, I think you're benefiting from API extensions added after this issue was filed. The fact that this is now easy to do, and that the solution composes with the rest of PyTorch's data utilities, shows this issue has been addressed and can be closed.
