Add length and padding keywords to DistributedSampler #28841
Conversation
The current implementation of `DistributedSampler` is ideal for distributed training with map-style datasets, as they fit in memory and have a known size. However, it doesn't support distributed training with `IterableDataset` datasets, as these classes do not implement `__len__`. To fix that, a `length` keyword was added to `DistributedSampler`, which takes precedence when set. An extra `padding=True` parameter was also added to give finer control over whether the returned index list should be padded by the sampler. This is useful for preventing duplicate reads on `IterableDataset` datasets that do not fit in memory or whose data reading or transformation is expensive. Finally, a `set_rank` method was added, similar to the existing `set_epoch`, to ease distributed training. When a `DataLoader` is created with `num_workers > 0` and `dataset` is an instance of `ChunkDataset`, the copy of `DistributedSampler` in each worker needs to be configured with its new rank. This change does not break backward compatibility.
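The `length` / `padding` behavior described above can be sketched in a few lines. The following is a hypothetical, simplified stand-in (only the `length` and `padding` keyword names come from the PR; the class name and everything else is illustrative), not the actual patch:

```python
import math


class SimpleDistributedSampler:
    """Illustrative sketch of the proposed `length` / `padding` keywords.

    `length`, when set, takes precedence over `len(dataset)`, so the
    sampler can be used with datasets that do not implement `__len__`.
    `padding=False` skips replicating leading indices, so replicas may
    receive index lists of different sizes but never duplicates.
    """

    def __init__(self, dataset, num_replicas, rank, length=None, padding=True):
        self.num_replicas = num_replicas
        self.rank = rank
        # explicit `length` wins over len(dataset)
        self.length = length if length is not None else len(dataset)
        self.padding = padding

    def __iter__(self):
        indices = list(range(self.length))
        if self.padding:
            # pad so the index list divides evenly among replicas
            total = math.ceil(self.length / self.num_replicas) * self.num_replicas
            indices += indices[: total - self.length]
        # each replica takes a strided slice of the (padded) index list
        return iter(indices[self.rank :: self.num_replicas])
```

For example, with `length=10` and `num_replicas=4`, padding extends the index list to 12 entries so every replica sees exactly 3 indices; with `padding=False`, the last replicas simply see fewer.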
Out of curiosity, are the
Manually :)
Thanks for splitting this out!
@cpuhrsch Did you have a chance to check my last replies regarding using
ping @fmassa :)
Hi,
Sorry for the delay in reviewing.
I don't think that `DistributedSampler` should necessarily work with `IterableDataset`. The indices returned by the `DistributedSampler` are by definition useless to the general `IterableDataset`, and trying to use it to fit a particular use-case would add unnecessary constraints to `IterableDataset`.
My understanding is that this PR is trying to accomplish two things:
- add some meta-information to the `DistributedSampler` to work on `IterableDataset`s
- (potentially) simplify the use of `IterableDataset`s in distributed mode (not present in this PR)
The concept of a sampler is not valid for `IterableDataset` in general -- we have no guarantees on the order of the examples that will be returned. I believe we should keep this as is in general.
Does this mean that no `IterableDataset` can know its iteration order ahead of time? No, they can, but this is a special case, and should be handled by the application. I don't think this should live in PyTorch.
But then, how do we make it easier for users to write their own `IterableDataset` that works in distributed mode?
What are the things we need to keep in mind in DDP in this case?
- `sampler.set_epoch`, so that we can split the dataset between different machines
- in particular, we need to handle the `idx` of the worker ourselves
How can we accomplish both, without having to change the APIs of `DistributedSampler` or `DataLoader`?
Here is one example.

```python
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset


class MyIterableDataset(IterableDataset):
    def __init__(self):
        # each dataset, when constructed, knows its rank,
        # and they are constructed once per (GPU) process
        self.rank = dist.get_rank()
        # how many are we?
        self.world_size = dist.get_world_size()
        # have a counter on how many epochs have passed;
        # can be incremented by get_chunk_iterator
        # if the user wants a different shuffle per epoch
        self.epoch = 0

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        chunk_id = self.rank
        total_chunks = self.world_size
        if worker_info is not None:
            chunk_id = self.rank * worker_info.num_workers + worker_info.id
            total_chunks = self.world_size * worker_info.num_workers
        # now return the chunks in the dataset accordingly
        return iter(self.get_chunk_iterator(chunk_id, total_chunks))

    def get_chunk_iterator(self, chunk_id, total_chunks):
        # user implements this
        pass
```
Then, all the logic specific to your application stays restricted to your Dataset implementation, via `get_chunk_iterator`, which can handle buffering / shuffling / etc., and can also increment the epoch counter if the user wants.
I might be missing something, but let me know if the above implementation doesn't address your use-cases.
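The chunk-assignment arithmetic in the snippet above can be exercised without `torch` or a process group. A small standalone sketch (the function name is illustrative, not part of any API):

```python
def assign_chunk(rank, world_size, worker_id=None, num_workers=0):
    """Compute which chunk of the stream a (process, worker) pair
    should read, mirroring the arithmetic in __iter__ above.

    With no DataLoader workers there is one chunk per process; with
    workers, each (rank, worker_id) pair gets a distinct chunk id out
    of world_size * num_workers total chunks.
    """
    if worker_id is None:
        return rank, world_size
    chunk_id = rank * num_workers + worker_id
    total_chunks = world_size * num_workers
    return chunk_id, total_chunks
```

Because every (rank, worker) pair maps to a distinct `chunk_id`, no two readers touch the same chunk, which is exactly the property that avoids duplicate reads.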
@fmassa - samplers are definitely relevant to `IterableDataset`s - there are various sampling techniques that apply to streams, such as https://en.wikipedia.org/wiki/Reservoir_sampling
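Reservoir sampling, linked above, is the textbook example of sampling from a stream of unknown length: it keeps a uniform random sample of `k` items in a single pass. A minimal sketch (Algorithm R):

```python
import random


def reservoir_sample(stream, k, rng=None):
    """Keep a uniform random sample of up to k items from an iterable
    of unknown length, using a single pass (Algorithm R)."""
    rng = rng or random.Random()
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            # fill the reservoir with the first k items
            reservoir.append(item)
        else:
            # keep item i with probability k / (i + 1),
            # evicting a uniformly chosen current element
            j = rng.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir
```

This illustrates @cpuhrsch's point: the decision here is whether to *keep* each element as it streams by, not which *index* to fetch, which is why it doesn't fit the index-producing `Sampler` abstraction.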
@cpuhrsch I believe this should be an implementation detail of the specialization of the
Thanks for the comment, Francisco! I agree that snippet works in some cases, but note the amount of boilerplate that was needed just to enable distributed training. More code would be needed for cases where … As @cpuhrsch mentioned, there are several use cases for sampling within …
Thanks for the PR!
After reading through the thread, I am still a bit confused how this helps the distributed sampler for `IterableDataset`s. Is this supposed to directly work already, or will there be follow-up patches?
The current PyTorch `Sampler` is different from the "sampling" in methods like reservoir sampling, in the sense that it only specifies the "indices" to sample, rather than whether a sample should be kept. Due to this, in DataLoader we currently disallow using `IterableDataset` with a custom `sampler` or `batch_sampler`. An example snippet would be really useful!
Finally, this needs a test to get in.
""" | ||
|
||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): | ||
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True): |
`IterableDataset` can still implement `__len__`. I think it makes more sense for this sampler to assume that the dataset has `__len__` than having an explicit input argument.
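The point above is that an iterable-style dataset can still know its size up front (e.g. a record count stored in a file header) even though its items are produced sequentially rather than indexed. A minimal plain-Python stand-in (no `torch` dependency; in real code this would subclass `torch.utils.data.IterableDataset`):

```python
class CountingIterableDataset:
    """Stand-in for an IterableDataset whose total size is known in
    advance, so it can implement __len__ even though items are only
    produced sequentially via __iter__ (here: the first n squares)."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # items are generated on the fly, never fetched by index
        return (i * i for i in range(self.n))

    def __len__(self):
        return self.n
```

With such a dataset, a sampler could simply call `len(dataset)` instead of requiring a separate `length` argument, which is the trade-off being debated here.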
@ssnl if the dataset can implement `__len__`, then it makes more sense to extend `Dataset` as opposed to `IterableDataset`.
This PR tries to add the ability to do distributed training when the number of samples is unknown. `IterableDataset` allows this concept, but for distributed training the sampler needs some hints on how many chunks this unknown number of samples is split into.
Closing this for now
Hi, regarding `def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):`, could you assist me with how I can implement it for my case?
Current implementation of `DistributedSampler` is ideal for distributed training using map datasets, as they fit in memory and have known size. However, it doesn't support distributed training using `IterableDataset` datasets, as these classes do not implement `__len__`. To fix that, a `length` keyword was added to `DistributedSampler`, which takes precedence when set. An extra `padding=True` parameter was also added to give finer control over whether the returned index list should be padded by the sampler. This is useful for preventing duplicate reads on `IterableDataset` datasets that do not fit in memory or whose data reading or transformation is expensive. Finally, a `set_rank` method was added to ease distributed training, allowing `DataLoader`'s worker processes to register themselves on `DistributedSampler` instances through the `worker_init_fn` hook. This is useful when worker processes want to change the sampling behavior based not only on the process rank, but also on their worker ID. The `IterableDataset` documentation mentions this issue, but its examples only handle datasets with known size, which is not always the case for `IterableDataset`. This change does not break backward compatibility.
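The worker registration described above could look roughly like the following. This is a hypothetical sketch: `set_rank` is the method proposed by this PR, and `ProposedSampler` / `make_worker_init_fn` are illustrative stand-ins; in a real program the returned closure would be passed as `DataLoader(..., worker_init_fn=...)` and invoked once in each worker process.

```python
class ProposedSampler:
    """Minimal stand-in for a DistributedSampler carrying the set_rank
    method proposed in this PR (not the real API)."""

    def __init__(self, rank):
        self.rank = rank

    def set_rank(self, rank):
        # proposed here, mirroring the existing set_epoch
        self.rank = rank


def make_worker_init_fn(sampler, process_rank, num_workers):
    """Build a worker_init_fn that gives each DataLoader worker a
    distinct effective rank: process_rank * num_workers + worker_id."""

    def worker_init_fn(worker_id):
        # each worker's copy of the sampler re-registers with its own rank
        sampler.set_rank(process_rank * num_workers + worker_id)

    return worker_init_fn
```

Since `DataLoader` workers each receive a copy of the sampler, calling `set_rank` inside `worker_init_fn` lets every copy draw a disjoint share of the chunks.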