-
Notifications
You must be signed in to change notification settings - Fork 21.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ChunkDataset API proposal #26547
ChunkDataset API proposal #26547
Conversation
Ping.. |
Thanks for the PR! Could you add some tests and usage examples as part of the docstrings? EDIT: We still need to take a closer look at the design and API. |
Linking to #24915 |
ChunkDataset is also useful when the dataset is out-of-memory. |
(second level of sampling) (default: `True`) | ||
""" | ||
|
||
def __init__(self, chunk_sampler, chunk_reader, shuffle_cache=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cpuhrsch We still need to think about the API design. Not sure it's conventional to have an Sampler
as the argument in the constructor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @zhangguanheng66 and thanks for the feedback!
Your concern makes total sense. The decision for adding the sampler at the ChunkDataset was due to some factors:
- When using
IterableDataset
, PytorchDataLoader
doesnt use any user-provided sampler. Instead, it creates a hidden_InfiniteConstantSampler
that always returnNone
. With that approach, we dont have support fromDataLoader
to help coordinating dataset fragmentation between workers. It is up to theIterableDataset
, this is where ourDistributedChunkSampler
played its role. From theIterableDataset
docstring, we have the following:
each worker process will have a different copy of the dataset object, so it is often desired to configure each copy independently to avoid having duplicate data returned from the workers
It also suggests using
dataset's :meth:
__iter__
method or the :class:~torch.utils.data.DataLoader
's :attr:worker_init_fn
option to modify each copy's behavior
but this methos on their own don't coordination between workers.
- Also, as you mentioned before,
ChunkDataset
can be used to prevent out-of-memory issues by splitting dataset in several smaller pieces and caching only a few at a time. As shown by the backing article, good randomization is kept and it is very close to randomizing the whole dataset and just returning a batch from it. The advantage on this approach is that different sampling can be implemented and used onChunkDataset
without changing the API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @thiagocrepaldi that a sampler is needed in here. ChunkDataset
is effectively a transformer on the chunk_reader
(which should really be a regular dataset and not a dedicated subclass). Also, iterable datasets should never have a reset
method! At each call to __iter__
you should create a completely fresh class with a new instance of the sampler (possibly shuffled), and with a new iteration state.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree we could move the self._chunk_sampler_iter = iter(self.chunk_sampler)
from reset to def __iter__
. However, as the sampler
is needed inside the ChunkDataset
, we need a way to inject epoch
into the dataset's sampler (users may wish to set a new seed every epoch). Today this is done through reset
, but we can rename it to set_epoch
and move the iterator part to def __iter__()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does this need information about epoch
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be easier to pass that order as an order (i.e. as a list or such) before construction?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@cpuhrsch That is a good suggestion, but this requires the user know about the number of epochs before hand, which is not a bad thing either. Typically, in training the entire dataset is consumed in an epoch, but for large datasets an epoch could be only a part of the whole dataset (e.g. one few gigabyte file is an epoch and the dataset is many such files). In either case, I think this is a good suggestion and perhaps we can get rid of the reset method.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Although it works, it seems a bit odd because none of Pytorch's Sampler
s or Dataset
s know how many epochs they are going to be used for during initialization. That is more of a DataLoader
domain.
Resetting sampler state after each epoch, on the other hand, is something that we already do during training.
By renaming reset
to set_epoch
, the training loop would change from a typical (before this PR)
for epoch in range(num_epochs):
train_sampler.set_epoch(epoch)
to
for epoch in range(num_epochs):
train_dataset.set_epoch(epoch)
which should look familiar to the user
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, you can use worker_init_fn for DataLoader to set the seed for each worker initially.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
worker_init_fn
runs just once, so it wouldnt solve the issue of changing seed for every epoch
Yes, that is one of the good features. We can handle huge datasets without loading everything in memory, but keeping good randomization anyway |
Thanks for the feedback. The main reason that I didnt add unit tests and examples in the docstring was the fact that API changes would cause many changes just to keep examples/unit test working. But I agree an usage example is very useful for understanding the API, so I will first add as a comment here and a full example in the repo next. |
@cpuhrsch Below there is an example on how # Loading the training data
num_replicas = size * max(1, args.num_workers) # num_replicas must be >= 1
train_reader = MNISTCSVChunkDataReader(chunk_files=train_csv_files)
train_chunk_sampler = torch.utils.data.DistributedChunkSampler(rank=rank,
num_replicas=num_replicas,
num_chunks=len(train_csv_files),
shuffle=args.shuffle)
train_dataset = torch.utils.data.ChunkDataset(chunk_sampler=train_chunk_sampler,
chunk_reader=train_reader,
shuffle_cache=args.shuffle)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=args.train_batch_size,
collate_fn=collate_fn,
num_workers=args.num_workers,
**kwargs)
# Loading the test data
test_reader = MNISTCSVChunkDataReader(chunk_files=test_csv_files)
test_chunk_sampler = torch.utils.data.DistributedChunkSampler(rank=rank,
num_replicas=num_replicas,
num_chunks=len(test_csv_files),
shuffle=args.shuffle)
test_dataset = torch.utils.data.ChunkDataset(chunk_sampler=test_chunk_sampler,
chunk_reader=test_reader,
shuffle_cache=args.shuffle)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=args.test_batch_size,
collate_fn=collate_fn,
num_workers=args.num_workers,
**kwargs)
for epoch in range(1, args.epochs + 1):
# Reset before each epoch
train_dataset.reset(epoch)
train(args, model, device, train_loader, optimizer, epoch, ...)
test(args, model, device, test_loader, ...) By using pandas, we could implement class MNISTCSVChunkDataReader(torch.utils.data.ChunkDataReader):
r"""Reads chunk of MNIST CSV data for the specified chunk index."""
def __init__(self, chunk_files):
super(MNISTCSVChunkDataReader, self).__init__()
assert isinstance(chunk_files, list), 'chunk_files must be a `list`'
assert len(chunk_files) > 0, 'chunk_files must contain at least one chunk'
self.chunk_files = chunk_files
def __call__(self, index):
r"""
Returns a `tuple(data, target)` or `None`, where
`data` and `target` are `numpy.ndarray((batch_size, actual_data))`
"""
assert isinstance(index, int), 'index must be a `int`'
assert index < len(
self.chunk_files), 'index must be < `len(chunk_files)`'
csv = pd.read_csv(self.chunk_files[index])
data = csv.loc[:, csv.columns != "label"].values.astype(np.uint8).reshape(-1, 28, 28)
target = csv.label.values.astype(np.uint8)
return list(zip(data, target)) |
Ping |
@@ -65,3 +65,112 @@ def __len__(self): | |||
|
|||
def set_epoch(self, epoch): | |||
self.epoch = epoch | |||
|
|||
class ChunkDataReader(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We really shouldn't have ChunkDataReader
s. This class is the same thing as an iterable dataset.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This implementation hides (iterable) dataset implementation from the user along with coordinating data loading amongst workers (including multiple nodes). This PRimplements ChunkDataset
so that future users forget about it. All the user will have to do is implementing a data format specific function (ChunkDataReader
) that can handle chunks
, wherever chunks are in their scenario. Current IterableDataset
works perfectly for non-distributed scenarios, but requires lots of boiler plate to accomplish the distributed version. That is the extension we are proposing: makes IterableDataset
easier/more powerful for distributed training with less coding. This comment tries to clarify how the API would be used.
In the scenario that you are describing (removing chunk data reader) and using iterable dataset, how do you propose each worker would read, load and shuffle only the slices of dataset they strictly need?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You said that IterableDataset
would be harder to deal with, but I don't fully understand why? What's so special about this class? How do you handle replication of the state in a distributed fashion that you couldn't reproduce with a regular dataset?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh sorry, I didn't notice that it's actually stateless. In that case why couldn't it be a regular map-style dataset?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The original design changed a bit - or at least will change as soon as #28841 lands.
ChunkDataReader
class will be replaced by a new method on ChunkDataset(IterableDataset)
class. This method will be responsible for fetching data for a specific chunk index.
DistributedChunkSampler
is also being eliminated after #28841.
There are some subtle changes that can be found in the discussion below, but the core is the same: creating a ChunkDataset
(which is IterableDataset
) that can be used in distributed training using a distributed sampler and which also saves users from boiler plate code to handle dataset (with unknown size) reading in each worker.
torch/utils/data/distributed.py
Outdated
r"""This sampler introduces distributed sampling without padding and dataset dependency. | ||
|
||
This sampler is very similar to the `DistributedSampler`, however without padding and | ||
the dependency on the size of the dataset. With two levels of sampling, the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How is this independent of the dataset size? You're still asking for the number of chunks. I don't see why would we need this sampler instead of only having DistributedSampler
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Existing DistributedSampler
depends on len(dataset)
, which is not available on IterableDataset
- when dataset size is known, MapDataset
can be used instead. DistributedSampler
also pads the index list, which is undesired for chunking, as it would result in different workers reading the portion of dataset again.
In the first prototype, we modified DistributedSampler
instead of creating a new sampler, but the final design wasn't good. One of the reasons was that DistributedSampler
would need to have the existing dataset
changed to optional and a new optional argument num_chunks
would be needed to replace len(dataset)
. That leads to a constructor with all parameters being optional, with the undesired behavior of an invalid instance if the user didnt specify either dataset
or num_chunks
. The second issue was that lots of conditionals to disable padding and switching between len(dataset)
and num_chunks
, not only doubling the size the class, but also changing a specialized behavior to a generic one that could confuse the user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This still seems to be dependent on the length of the dataset through num_chunks
, no?
- Is
num_chunks
the length of theIterableDataset
? If so, then the latter can be patched to have that as a length, or just changed to aDataset
. - What if
DistributedSampler
simply took thelen(dataset)
as parameter directly? That's the only way the dataset is used withinDistributedSampler
, right? (That would be BC-breaking of course, unless we do something like detecting when a dataset is passed, and save the length instead.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I might be missing here, but the original idea of
IterableDataset
was to represent data 1) without direct indexing of examples and 2) without usinglen(dataset)
, which can cover lots of scenarios where computing length can be expensive (database) or impossible (a stream dataset).IterableDataset
guarantees thatDataLoader
can iterate over it for data and when it is exhausted, an exception will be raised, signalingDataloader
to stop. This aspect is interesting because it lets space for future extensions, such asInfiniteDataset
that would never be able to return its length but still would be compatible withIterableDataset
.
By following this definition, we disabled __len__
for ChunkDataset(IterableDataset)
:
def __len__(self):
# `IterableDataset` classes have unknown dataset size
raise NotImplementedError
In this PR, num_chunks
means how many parts the dataset was divided into, but it doesn't mean how many examples there are in the dataset. num_chunks
is only used by DistributedChunkSampler
while ChunkDataset
keeps blind regarding its length. For MNIST training dataset, for example, there could be 12 chunks (files) with 4615 samples each and a final chunk with 4620 examples.
- If breaking BC is not an issue,
len(dataset)
would definately work. The only other minor issue we would have to address is adding an optional boolean flag that disables padding to make the list of indices evenly divisible. ChangingDistributedSampler
ctor fromdef __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
todef __init__(self, size, num_replicas=None, rank=None, shuffle=True, padding=True):
does the trick. What do you think?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- I might be missing here, but the original idea of
IterableDataset
was to represent data 1) without direct indexing of examples and 2) without usinglen(dataset)
That is correct. I just wanted to make sure the meaning and use of num_chunks
was clear :)
- If breaking BC is not an issue,
len(dataset)
would definately work. The only other minor issue we would have to address is adding an optional boolean flag that disables padding to make the list of indices evenly divisible. ChangingDistributedSampler
ctor fromdef __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
todef __init__(self, size, num_replicas=None, rank=None, shuffle=True, padding=True):
does the trick. What do you think?
Which line would you disable in distributed.py?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The new flag would disable from
pytorch/torch/utils/data/distributed.py
Line 53 in 9705d60
# add extra samples to make it evenly divisible |
pytorch/torch/utils/data/distributed.py
Line 55 in 9705d60
assert len(indices) == self.total_size |
when the user toggles it. Distributed data loading doesnt require dataset chunks to be equally divided between workers, especially if that incurs in reading and applying heavy transformations on duplicate chunks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just to be very clear here: By "distributed" do you mean multiple machines or multiple processes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ChunkDataset is flexible to work with any combination among single node + single process, single node + multiple processes and multiple nodes + multiple processes. The latter is the most challenging and the one we wanted to make easier. In fact, the whole idea was that users could start with single node + single process during development and when it is ready for production, with just minor changes, it would scale to the multiple nodes + multiple processes
(second level of sampling) (default: `True`) | ||
""" | ||
|
||
def __init__(self, chunk_sampler, chunk_reader, shuffle_cache=True): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with @thiagocrepaldi that a sampler is needed in here. ChunkDataset
is effectively a transformer on the chunk_reader
(which should really be a regular dataset and not a dedicated subclass). Also, iterable datasets should never have a reset
method! At each call to __iter__
you should create a completely fresh class with a new instance of the sampler (possibly shuffled), and with a new iteration state.
In this example, we do know the length of the dataset through these csv files, no? I'm assuming this was just a toy example? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This snippet assumes
train_csv_files
andtest_csv_files
are list of files with portions of MNIST in CSV format for training and testIn this example, we do know the length of the dataset through these csv files, no? I'm assuming this was just a toy example?
In this toy example, we do know the size of the dataset, but IterableDataset
can be used in many ways where calculating the size is either impossible or too expensive.
Just to reiterate, breaking a dataset into chunks as proposed in this PR aims to easily distribute data loading in several workers without redundancies and keeping satisfactory randomness instead of 1) loading the entire dataset on each worker and using only a small part of it or 2) letting the user the burden of writing clever scripts to do this themselves.
@thiagocrepaldi - How would you implement this in the simplest way possible without using the Dataset abstraction? From what I understand (based on Algorithm 3 in Section 1 of [1]) that would be (for text files that fit into memory)
chunk_list here could be an IterableDataset and you want to draw from these chunks in a distributed fashion? I assume the main motivation here is to parallelize "process_data" because it is expensive? This might also include the "next(chunk)" call above, if the underlying Iterator has to do some kind of expensive operation to get the next datapoint. Of course there are two types of expensive: 1) IO heavy 2) Compute heavy. I typically expect drawing a new line of text to be IO heavy, hence can easily benefit from multithreading since the underlying system call (or equivalent) suspends execution and waits on some kind of response. I then expect process_data to be compute heavy. It might decode a list of images, apply some transformations (normalization, rotations etc.) and then concatenate them all together to form a batch. Does this capture your situation? [1] http:https://martin.zinkevich.org/publications/nips2010.pdf |
@cpuhrsch, the text example is an overly simplified version. While, your code segment captures the basic idea, the key challenge @thiagocrepaldi is trying to solve is the data loading to support data parallel(DP) training. There are two aspects to consider with any datasets. (1) preloading, (2) transforms. Preloading hides any storage latencies, and having parallelism for transforms is a must for expensive transforms. Current PyTorch dataloader architecture has most of these features and Iterable concept is a good one, but still it is not trivial to use for DP training. |
@jaliyae - By "preloading" I assume you mean readahead using some kind of buffer? Where in this PR is that implemented? This sounds generic enough that we could it split it out via sort of Buffer class that wraps an InterableDataset and uses multithreading to read ahead to fill up a buffer and anticipate future "next"s. |
CC @mrshenli |
@cpuhrsch, yes preloading is buffering or caching. The ChunkDataset has this cache (shuffle_cache) |
As a first measure of simplification: What is the downside of instantiating a list of Datasets instead of having a ChunkDataReader and passing that to ChunkDataset? This would already follow the ideas ChainDataset more closely. The difference between ChunkDataset and ChainDataset then is that we're keeping track of which chunk a datapoint originates from (for purposes of shuffling) as opposed to ChainDataset which would throw them all together linearly. If this is ok, we can get rid of ChunkDataReader. |
If I understood your suggestion right, in this new scenario, users are responsible for doing data reading scheduling for distributed training. They would have to use rank IDs, world_size and worker ID to calculate the proper dataset slice/chunk for each worker and then inject that into ChunkDataset. That is simple for single node with single worker training, but when we scale this to several nodes with several nodes each, things get a bit messy. This PR hides all distributed data reading, by requiring a generic chunk data reader that can return data based on an arbitrary chunk id. |
Sure, no problem! The stateful dataloader feature discussed in that issue shouldnt affect or be affected by this PR. There would be issues if we turned to the other approach that required DataLoader to be destroyed/instantiated every epoch. |
@thiagocrepaldi - One thing that worries me here is that, even though these components are separate, they can't be used in other contents except in this particular combination. That means they're highly coupled. This is something we'd usually want to avoid. In particular, as far as I can tell, this sampler is unique in that it's based to the Dataset and not to the Dataloader. |
@cpuhrsch That makes sense and maybe we can work this out with some changes:
In this new solution, we would introduce a single class (aka What do you think? |
@thiagocrepaldi - that sounds good! We should talk more about your proposed extension of DistributedSampler however. Could you expand a bit on your ideas around expanding the DistributedSample, i.e. the specifications of pad_indices and size? |
Reusing
|
@thiagocrepaldi - to maintain API consistency I'd then add a length keyword. That could also be used for map-style datasets. I'd then also require the user to pass the IterableDataset (of course they could just pass None and nothing would happen) just in case in the future we want to make use of additional properties of the dataset this sampler is used for. The padding keyword also seems like a natural extension. So I'd agree with adding these two flags. Can you create a separate PR that does this and add your reasoning to it in context of this PR? It seems like a pretty small change. I'm assuming once you have those two keywords this PR would significantly reduce in length? |
I can definitely create a separate PR for the |
Current implementation of `DistributedSampler` is ideal for distributed training using map datasets, as they fit in memory and have known size. However, it doesn't support distributed training using `IterableDataset` datasets, as these classes do not implement `__len__`. To fix that, a `length` keyword was added to `DistributedSampler`, which has precedence when set. An extra `padding=True` parameter was also added was give finer control on whether the (returned) index list should be padded by the sampler. This is useful for preventing duplicate reading on `IterableDataset` datasets that do not fit in memory or which data reading or transformation are expensive. Finally, set_rank method was added, similarly the existing `set_epoch`, to ease distributed training. When `DataLoader` is created with `num_workers` > 0 and `dataset` is an instance of `ChunkDataset`, a copy of `DistributedSampler` on each worker needs to be configured with their new rank. There is no back compatibility with this change.
@cpuhrsch The new PR is #28841. Let me know what you think! I can create a new version of this PR with all changes we discussed as soon as the sampler extension is merged |
@fmassa This is an example on how to use ChunkDataset. It might help you understand how (...)
# Loading the training data
num_replicas = size * max(1, args.num_workers)
train_reader = MNISTCSVChunkDataReader(chunk_files=train_csv_files)
train_chunk_sampler = torch.utils.data.DistributedSampler(dataset=None,
rank=rank,
num_replicas=num_replicas,
length=len(train_csv_files),
shuffle=args.shuffle)
train_dataset = torch.utils.data.ChunkDataset(chunk_sampler=train_chunk_sampler,
chunk_reader=train_reader,
shuffle_cache=args.shuffle)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=args.train_batch_size,
collate_fn=collate_fn,
num_workers=args.num_workers,
**kwargs)
# Loading the test data
test_reader = MNISTCSVChunkDataReader(chunk_files=test_csv_files)
test_chunk_sampler = torch.utils.data.DistributedSampler(dataset=None,
rank=rank,
num_replicas=num_replicas,
length=len(test_csv_files),
shuffle=args.shuffle)
test_dataset = torch.utils.data.ChunkDataset(chunk_sampler=test_chunk_sampler,
chunk_reader=test_reader,
shuffle_cache=args.shuffle)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=args.test_batch_size,
collate_fn=collate_fn,
num_workers=args.num_workers,
**kwargs)
# Model and optimizer
(...)
# Train and testing
for epoch in range(1, args.epochs + 1):
# Reset before each epoch
train_dataset.reset(epoch)
test_dataset.reset(epoch)
train(model, train_loader, ...)
test(model, test_loader, ...) By using pandas, we could implement class MNISTCSVChunkDataReader(torch.utils.data.ChunkDataReader):
r"""Reads chunk of MNIST CSV data for the specified chunk index."""
def __init__(self, chunk_files):
super(MNISTCSVChunkDataReader, self).__init__()
assert isinstance(chunk_files, list), 'chunk_files must be a `list`'
assert len(chunk_files) > 0, 'chunk_files must contain at least one chunk'
self.chunk_files = chunk_files
def __call__(self, index):
r"""
Returns a `tuple(data, target)` or `None`, where
`data` and `target` are `numpy.ndarray((batch_size, actual_data))`
"""
assert isinstance(index, int), 'index must be a `int`'
assert index < len(
self.chunk_files), 'index must be < `len(chunk_files)`'
csv = pd.read_csv(self.chunk_files[index])
data = csv.loc[:, csv.columns != "label"].values.astype(np.uint8).reshape(-1, 28, 28)
target = csv.label.values.astype(np.uint8)
return list(zip(data, target)) |
8362a13
to
203f43e
Compare
Sorry for bumping this. What's the status of this PR? Blocked on #28841? One concern I have is that in the multi-GPU training setup, the sampler is allowed to assign chunks to ranks non-evenly, and besides, number of examples can vary from chunk to chunk (fair to expect and I believe the opposite is never asserted). This means some ranks will have more examples to process than the others. But during training, batches are drawn by all ranks synchronously at the same rate, meaning some ranks will exhaust their shards earlier than the others, reset, and continue to train on the next epoch, while the others are still training on the current one. So, undesirably, elements from different epochs will be "mixed", and moreover, given fixed number of epochs, the "fast" ranks eventually finish training and exit their processes; the "slow" ranks will be stuck because they still have data to process but allreduce is now impossible. The current |
Any update? |
Closing this for now |
@thiagocrepaldi Is there any further plan to reopen this issue later? Maybe ChunkDataset is a good counterpart of shared tfrecords. |
I wonder is there any other good solution for handling iterable datasets in the DDP scenario? |
@summelon maybe u can check this RFC. Currenly u can use webdataset as an alternative. |
ChunkDataset API proposal
Problem to be solved
A typical data loading in PyTorch assumes all the data is accessible to every participating process. Randomization is performed by the sampler with the knowledge of total length of the dataset. While this approach is simpler and natural to scenarios such as a directory full of images, it does not map well to situations where a large dataset with unknown size is available in collection of files or a single large file. The global randomization incurs many disk seeks and the user needs to carefully partition data to support distributed training. Manually splitting the data, distribute amongst computing units without duplicates and performing efficient shuffling are not strictly related to training models, but are still important. We often implement similar boiler plate code in different projects, leading to increase in development time.
Proposed solution
The proposed
ChunkDataset
is a stateful dataset that supports hierarchical sampling and efficient reading through chunks. Achunk
, in this context, could be a file, such as audio or image, section of a file in the case of a large text-file, a folder, a URL, or any other abstraction that allows data to be segmented roughly the same size.Unlike regular datasets,
ChunkDataset
implements two levels of sampling, i.e. hierarchical sampling, to operate. In the first level, achunk
is selected based on a sampling strategy and second, a sample is selected from thechunk
using another or similar sampling strategy. The hierarchical sampling approach adopted here provides satisfactory randomness and is inspired by the following paper.By using ChunkDataset API, tasks such as splitting data between computing units with proper randomization become trivial. All user has to do is to provide a
ChunkDataReader
implementation that reads a chunk, instantiate aDistributedChunkSampler
with the desired shuffling strategy and finally putting all together in aChunkDataSet
instance. Once this dataset is passed to PyTorchDataLoader
, every worker will learn its correct rank, reads their pieces of data and continue on the regularDataloader
flow.Brief discussion on API
ChunkDataReader class
In order to perform reading of a particular chunk chosen by
DistributedChunkSampler
, the user has to implement a reader class that extendsChunkDataReader
:DistributedChunkSampler class
DistributedChunkSampler
is already implemented and the user only needs to instantiate it and inject intoChunkDataset
.Similarly to
DistributedSampler
,DistributedChunkSampler
takes :attr:num_replicas, :attr:rank
and :attr:shuffle
on its constructor to specify the number of processes participating in the distributed training, the current rank of a process and the shuffling strategy. One main difference between two samplers is that becauseDistributedChunkSampler
operates onIterableDataset
with unknown size, it takes :attr:num_chunks
as input to draw indices as opposed toDistributedSampler
:attr:dataset
parameter. Another important difference between both samplers is thatDistributedSampler
performs padding on its generated indices, which can't be done for chunks to prevent duplicate reading on different workers.The
DistributedChunkSampler
public API is:ChunkDataset class
ChunkDataset
is already implemented and the user only needs to instantiate it and inject into PyTorchDataLoader
.As mentioned before,
ChunkDataset
is anIterableDataset
implementation, which focus on representing a dataset with unknown size. Once it is passed in to PyTorchDataLoader
, it iterates over the dataset until it is exhausted. At this point, an exception is raised and reading is gracefully finished.ChunkDataset
must bereset
after each epoch to reset the internal state of the sampler and to optionally improve shuffling by injectingepoch
.The
ChunkDataset
public API is:This PR builds on the original C++ ChunkDataset API and IterableDataset