Chunk Dataset API #26545

Closed

thiagocrepaldi
Collaborator

ChunkDataset API proposal

Problem to be solved

A typical data loading pipeline in PyTorch assumes all the data is accessible to every participating process. Randomization is performed by the sampler with knowledge of the total length of the dataset. While this approach is simple and natural for scenarios such as a directory full of images, it does not map well to situations where a large dataset of unknown size is spread across a collection of files or stored in a single large file. Global randomization incurs many disk seeks, and the user needs to carefully partition the data to support distributed training. Manually splitting the data, distributing it amongst computing units without duplicates, and performing efficient shuffling are not strictly related to training models, but are still important. We often implement similar boilerplate code in different projects, leading to increased development time.

Proposed solution

The proposed ChunkDataset is a stateful dataset that supports hierarchical sampling and efficient reading through chunks. A chunk, in this context, could be a file (such as an audio or image file), a section of a file in the case of a large text file, a folder, a URL, or any other abstraction that allows data to be segmented into pieces of roughly the same size.

Unlike regular datasets, ChunkDataset implements two levels of sampling, i.e. hierarchical sampling, to operate. At the first level, a chunk is selected based on a sampling strategy; at the second, a sample is selected from within that chunk using another (possibly the same) sampling strategy. The hierarchical sampling approach adopted here provides satisfactory randomness and is inspired by the following paper.
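
To illustrate the idea, here is a minimal sketch of two-level sampling in plain Python; the data and names are illustrative only and not part of the proposed API:

import random

# Illustrative sketch of hierarchical sampling, not part of the proposed API.
chunks = [["a1", "a2", "a3"], ["b1", "b2"], ["c1", "c2", "c3", "c4"]]

chunk_order = list(range(len(chunks)))
random.shuffle(chunk_order)               # level 1: shuffle which chunk to read next

for chunk_index in chunk_order:
    samples = list(chunks[chunk_index])   # read the whole chunk sequentially
    random.shuffle(samples)               # level 2: shuffle samples within the chunk
    for sample in samples:
        pass                              # consume the sample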

By using the ChunkDataset API, tasks such as splitting data between computing units with proper randomization become trivial. All the user has to do is provide a ChunkDataReader implementation that reads a chunk, instantiate a DistributedChunkSampler with the desired shuffling strategy, and finally put everything together in a ChunkDataset instance. Once this dataset is passed to the PyTorch DataLoader, every worker learns its correct rank, reads its piece of data, and continues with the regular DataLoader flow (see the end-to-end sketch at the end of this section).

Brief discussion of the API

ChunkDataReader class

To read a particular chunk chosen by the DistributedChunkSampler, the user has to implement a reader class that extends ChunkDataReader:

class ChunkDataReader(object):
    def __init__(self):
        r"""The reader is initialized here"""

    def __call__(self, index):
        r"""Returns `list(samples)` for a given :attr:`index"""

DistributedChunkSampler class

DistributedChunkSampler is already implemented; the user only needs to instantiate it and inject it into ChunkDataset.

Similarly to DistributedSampler, DistributedChunkSampler takes :attr:`num_replicas`, :attr:`rank`, and :attr:`shuffle` in its constructor to specify the number of processes participating in distributed training, the current rank of the process, and the shuffling strategy. One main difference between the two samplers is that, because DistributedChunkSampler operates on an IterableDataset of unknown size, it takes :attr:`num_chunks` as input to draw indices from, as opposed to DistributedSampler's :attr:`dataset` parameter. Another important difference is that DistributedSampler pads its generated indices, which can't be done for chunks because it would cause duplicate reading on different workers.

The DistributedChunkSampler public API is:

class DistributedChunkSampler(Sampler):
    def __init__(self, num_replicas, rank=0, num_chunks=0, shuffle=False):
        r"""Returns a new DistributedChunkSampler instance"""

    def set_rank(self, rank):
        r"""Set rank for the current sampler instance"""

    def set_epoch(self, epoch):
        r"""Set epoch for the current sampler instance"""

ChunkDataset class

ChunkDataset is already implemented; the user only needs to instantiate it and inject it into the PyTorch DataLoader.

As mentioned before, ChunkDataset is an IterableDataset implementation, which focuses on representing a dataset of unknown size. Once it is passed to the PyTorch DataLoader, it iterates over the dataset until the data is exhausted. At that point, an exception is raised and reading finishes gracefully.

ChunkDataset must be reset after each epoch to reset the internal state of the sampler and, optionally, to improve shuffling by injecting the epoch.

The ChunkDataset public API is:

class ChunkDataset(IterableDataset):
    r"""Dataset which uses hierarchical sampling"""

    def __init__(self, chunk_sampler, chunk_reader, shuffle_cache=True):
        r"""Returns a new ChunkDataset instance"""

    def __iter__(self):
        r"""Returns an Iterator for batching"""

    def __next__(self):
        r"""Returns the next value in the Iterator"""

    def reset(self, epoch=None):
        r"""Resets internal state of ChunkDataset"""

P.S.: This PR builds on IterableDataset and the original C++ implementation of the ChunkDataset API.
