
Add length and padding keywords to DistributedSampler #28841

Conversation

@thiagocrepaldi (Collaborator) commented Oct 29, 2019

The current implementation of `DistributedSampler` is ideal for distributed
training with map-style datasets, which fit in memory and have a known size.
However, it does not support distributed training with `IterableDataset`
datasets, because these classes do not implement `__len__`.
To fix that, a `length` keyword was added to `DistributedSampler`, which
takes precedence when set.

An extra `padding=True` parameter was also added to give finer control
over whether the (returned) index list should be padded by the sampler.
This is useful for preventing duplicate reads on `IterableDataset`
datasets that do not fit in memory or whose data reading or transformation
is expensive.

Finally, a `set_rank` method was added, similar to the existing `set_epoch`, to ease distributed training. It allows `DataLoader` worker processes to register themselves on `DistributedSampler` instances through the `worker_init_fn` hook: when a `DataLoader` is created with `num_workers > 0` and `dataset` is an instance of `ChunkDataset`, the copy of `DistributedSampler` inside each worker needs to be configured with its new rank. This is useful when worker processes want to change the sampling behavior based not only on the process rank, but also on their worker ID. The `IterableDataset` documentation mentions this issue, but its examples only handle datasets with a known size, which is not always the case for an `IterableDataset`.

There is no impact on backward compatibility from this change.
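A rough usage sketch, not part of the PR diff: the `StreamDataset` stand-in, the `estimated_length` hint, the worker count, and the way each worker reaches its sampler copy are all assumptions made for illustration only.

import torch.distributed as dist
from torch.utils.data import IterableDataset
from torch.utils.data.distributed import DistributedSampler

class StreamDataset(IterableDataset):
    """Stand-in for a dataset whose true size is unknown up front."""
    def __iter__(self):
        return iter(range(1_000_000))

# Assumes the default process group is already initialized.
estimated_length = 1_000_000  # user-supplied hint, since len() is unavailable
sampler = DistributedSampler(
    StreamDataset(),
    num_replicas=dist.get_world_size(),
    rank=dist.get_rank(),
    shuffle=False,
    length=estimated_length,  # proposed keyword: takes precedence over len(dataset)
    padding=False,            # proposed keyword: skip padding, avoiding duplicate reads
)

num_workers = 4
base_rank = dist.get_rank()

def worker_init_fn(worker_id):
    # Proposed set_rank: each DataLoader worker re-registers itself on its copy
    # of the sampler so sampling can depend on (process rank, worker id).
    sampler.set_rank(base_rank * num_workers + worker_id)

In an actual run, `worker_init_fn` would be passed to the `DataLoader` so that each worker's sampler copy is re-ranked before iteration begins.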

@vincentqb (Contributor)

CC @cpuhrsch for #26547.

@vincentqb added the module: dataloader label Oct 29, 2019
@thiagocrepaldi (Collaborator, Author)

Out of curiosity, are the *.pyi files automatically generated? If so, which tool is used?

@vincentqb (Contributor)

> Out of curiosity, are the *.pyi files automatically generated? If so, which tool is used?

Manually :)

@cpuhrsch (Contributor)

Thanks for splitting this out!

@thiagocrepaldi (Collaborator, Author)

@cpuhrsch Did you have a chance to check my last replies regarding using `dataset` for both `int`/`Dataset` types (or adding a new `length` keyword), and why `set_rank` is needed?

@thiagocrepaldi (Collaborator, Author)

ping @fmassa :)

@fmassa (Member) left a comment

Hi,

Sorry for the delay in reviewing.

I don't think that DistributedSampler should necessarily work with IterableDataset. The indices returned by the DistributedSampler are by definition useless to the general IterableDataset, and trying to use it to fit a particular use-case would add unnecessary constraints to IterableDataset.

My understanding is that this PR is trying to accomplish two things:

  1. add some meta-information to the DistributedSampler to work on IterableDatasets
  2. (potentially) simplify the use of IterableDatasets in distributed mode (not present in this PR).

The concept of a sampler is not valid for IterableDataset in general -- we have no guarantees on the order of the examples that will be returned. I believe we should keep this as is in general.
Does this mean that no IterableDataset can know ahead of time of its iteration order? No, they can, but this is a special case, and should be handled by the application. I don't think this should live in PyTorch.

But then, how to make it easier for users to write their own IterableDataset that works on distributed?

What are the things we need to keep in mind in DDP in this case?

  • sampler.set_epoch, so that we can split the dataset between different machines
  • in particular, we need to handle the idx of the worker ourselves.

How can we accomplish both, without having to change the APIs of DistributedSampler nor DataLoader?

Here is one example.

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset

class MyIterableDataset(IterableDataset):
    def __init__(self):
        # each dataset, when constructed, knows its rank
        # and they are constructed once per (GPU) process
        self.rank = dist.get_rank()
        # how many are we?
        self.world_size = dist.get_world_size()
        # have a counter on how many epochs have passed
        # can be incremented by get_chunk_iterator
        # if the user wants a different shuffle per epoch
        self.epoch = 0

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        chunk_id = self.rank
        total_chunks = self.world_size
        if worker_info is not None:
            chunk_id = self.rank * worker_info.num_workers + worker_info.id
            total_chunks = self.world_size * worker_info.num_workers

        # now return the chunks in the dataset accordingly
        return iter(self.get_chunk_iterator(chunk_id, total_chunks))

    def get_chunk_iterator(self, chunk_id, total_chunks):
        # user implements this
        pass

Then, all the logic specific to your application stays restricted to your Dataset implementation, via the get_chunk_iterator, which can handle buffering / shuffling / etc, and can also increment the epoch counter if the user wants.

I might be missing something, but let me know if the above implementation doesn't address your use-cases.
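For concreteness, here is a sketch of how this pattern might be exercised end to end. The `ToyChunkDataset` subclass, its modulo sharding rule, and the single-process "gloo" group (with a free local port) are assumptions made only so the snippet can run on its own, alongside the `MyIterableDataset` class above in the same script.

import torch.distributed as dist
from torch.utils.data import DataLoader

class ToyChunkDataset(MyIterableDataset):
    def __init__(self, data):
        super().__init__()      # picks up rank / world_size from the process group
        self.data = list(data)

    def get_chunk_iterator(self, chunk_id, total_chunks):
        # toy sharding: every total_chunks-th element, starting at chunk_id
        return (x for i, x in enumerate(self.data) if i % total_chunks == chunk_id)

if __name__ == "__main__":
    # single-process group just so dist.get_rank()/get_world_size() work here
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=0, world_size=1)
    loader = DataLoader(ToyChunkDataset(range(10)), batch_size=2, num_workers=2)
    for batch in loader:
        print(batch)  # each worker contributes only its own shard of the stream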

@cpuhrsch (Contributor) commented Dec 4, 2019

@fmassa - samplers are definitely relevant to IterableDatasets - there are various sampling techniques that apply to streams such as https://en.wikipedia.org/wiki/Reservoir_sampling

@fmassa (Member) commented Dec 4, 2019

@cpuhrsch I believe this should be an implementation detail of the `IterableDataset` specialization; this doesn't mean that the concept of sampler that we have should support `IterableDataset`.

@thiagocrepaldi (Collaborator, Author) commented Dec 4, 2019


Thanks for the comment, Francisco!

I agree that snippet works in some cases, but note the amount of boilerplate that was needed just to enable distributed training. More code would be needed for cases where `total_chunks != self.world_size * worker_info.num_workers`, and more still to support non-distributed scenarios. Reservoir sampling and other techniques to distribute load across servers wouldn't be easily supported either. Also, distributed training with regular datasets doesn't need all that boilerplate, because the sampler coordinates sample selection. That is where this PR tries to fit in: achieving feature parity between both types of dataset and decreasing the amount of boilerplate code related to distributed training infrastructure.

As @cpuhrsch mentioned, there are several use cases for sampling within `IterableDataset`, including but not limited to reservoir sampling, which was requested in #28743. A reservoir sampler could be implemented on top of the new `DistributedSampler`, and it would be interesting because it allows sampling a dataset of unknown size without duplicates in an efficient way (that sampler might be a separate PR, but let's drop that idea for now :))
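As an aside, here is a minimal sketch of the kind of stream sampling being discussed: plain reservoir sampling over an iterable of unknown length, independent of this PR's API.

import random

def reservoir_sample(stream, k, seed=0):
    # Keeps a uniform random sample of k items from an iterable of unknown length.
    rng = random.Random(seed)
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            j = rng.randint(0, i)  # inclusive on both ends
            if j < k:
                reservoir[j] = item
    return reservoir

# e.g. reservoir_sample(iter(range(1_000_000)), k=10)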

@fmassa (Member) left a comment

We can move on with this PR, although I would have loved a review from either @apaszke or @ssnl

@cpuhrsch (Contributor) left a comment

Let's wait on @apaszke or @ssnl

@ssnl (Collaborator) left a comment

Thanks for the PR!

After reading through the thread, I am still a bit confused about how this helps the distributed sampler work with IterableDatasets. Is this supposed to work directly already, or will there be follow-up patches?

The current PyTorch Sampler is different from the "sampling" in methods like reservoir sampling, in the sense that it only specifies the "indices" to sample, rather than whether a sample should be kept. Due to this, in DataLoader, we currently disallow using IterableDataset with a custom sampler or batch_sampler. An example snippet would be really useful!

Finally, this needs a test to get in.
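To illustrate the distinction drawn above: a PyTorch Sampler only produces indices into a map-style dataset and never decides whether a sample is kept. The toy class below is not from the PR; it is only a sketch of that contract.

from torch.utils.data import Sampler

class EvenIndexSampler(Sampler):
    # Yields every other *index* into a map-style dataset; it never touches the data.
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(0, len(self.data_source), 2))

    def __len__(self):
        return (len(self.data_source) + 1) // 2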

"""

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IterableDataset can still implement __len__. I think it makes more sense for this sampler to assume that the dataset has __len__ than having an explicit input argument.

@thiagocrepaldi (Collaborator, Author)

@ssnl if the dataset can implement `__len__`, then it makes more sense to extend `Dataset` as opposed to `IterableDataset`.

This PR tries to add the ability to do distributed training when the number of samples is unknown. `IterableDataset` allows this concept, but for distributed training the sampler needs some hints about how many chunks this unknown number of samples is split into.

@zou3519 added the triaged label Jan 9, 2020
@thiagocrepaldi (Collaborator, Author)

Closing this for now

@rabeehkarimimahabadi

Hi,
I have multiple large-scale datasets and I know their lengths in advance. I need to train the models on TPUs and need to write a distributed sampler for this case to be able to use PyTorch XLA. I would really appreciate some examples of how I can implement a distributed sampler for iterable datasets.
In the non-iterative case this is implemented as follows:

def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
    if xm.xrt_world_size() <= 1:
        return RandomSampler(dataset)
    return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())

Could you assist me with how I can implement it for my case?
Thanks, I really appreciate any help on this.
Best,
Rabeeh
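For reference only, one possible direction is to move the sharding into the dataset itself, along the lines of the pattern from fmassa's comment earlier in this thread, using the same xm calls as the snippet above. This is a hedged sketch, not an endorsed API: the `build_stream` callable and the modulo sharding rule are assumptions, and it requires torch_xla to be installed.

import torch_xla.core.xla_model as xm
from torch.utils.data import IterableDataset, get_worker_info

class ShardedIterableDataset(IterableDataset):
    # Shards a stream across XLA processes and DataLoader workers,
    # mirroring the rank/world_size pattern shown earlier in this thread.
    def __init__(self, build_stream):
        self.build_stream = build_stream      # callable returning a fresh iterator
        self.rank = xm.get_ordinal()
        self.world_size = xm.xrt_world_size()

    def __iter__(self):
        info = get_worker_info()
        shard_id, num_shards = self.rank, self.world_size
        if info is not None:
            shard_id = self.rank * info.num_workers + info.id
            num_shards = self.world_size * info.num_workers
        # toy sharding rule: keep every num_shards-th example, offset by shard_id
        return (x for i, x in enumerate(self.build_stream()) if i % num_shards == shard_id)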
