Add length and padding keywords to DistributedSampler #28841

Closed
Add length and padding keywords to DistributedSampler
The current implementation of `DistributedSampler` works well for
distributed training with map-style datasets, as they fit in memory and
have a known size. However, it does not support distributed training
with `IterableDataset` datasets, because those classes do not implement
`__len__`. To fix that, a `length` keyword was added to
`DistributedSampler`; it takes precedence over `len(dataset)` when set.

An extra `padding=True` parameter was also added to give finer control
over whether the returned index list should be padded by the sampler.
This is useful for preventing duplicate reads on `IterableDataset`
datasets that do not fit in memory or whose data reading or
transformation is expensive.

Finally, a `set_rank` method was added, similar to the existing
`set_epoch`, to ease distributed training. When a `DataLoader` is
created with `num_workers > 0` and `dataset` is an instance of
`ChunkDataset`, the copy of `DistributedSampler` on each worker needs
to be configured with its new rank.

This change does not break backward compatibility.
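
A minimal usage sketch of the new keywords (assuming this PR's API; the
replica count and `length=10` below are arbitrary placeholders):

```python
from torch.utils.data.distributed import DistributedSampler

# With `length` set, the sampler never calls len(dataset), so an
# IterableDataset (or even dataset=None, per the new docstring) works.
sampler = DistributedSampler(dataset=None, num_replicas=4, rank=0,
                             shuffle=False, length=10, padding=False)
print(list(sampler))   # [0, 4, 8]: every 4th index, starting at rank 0

# `set_rank` re-targets an existing sampler copy, e.g. inside a
# DataLoader worker process (the exact ChunkDataset wiring is not shown).
sampler.set_rank(1)
print(list(sampler))   # [1, 5, 9]
```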
Thiago Crepaldi committed Oct 29, 2019
commit 44b631ebf0152615c0628f0360ddeb990cfdc405
41 changes: 29 additions & 12 deletions torch/utils/data/distributed.py
@@ -9,21 +9,26 @@ class DistributedSampler(Sampler):

It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
process can pass a `DistributedSampler` instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.

.. note::
Dataset is assumed to be of constant size.

Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
shuffle (optional): If true (default), sampler will shuffle the indices
dataset (Dataset): Dataset used for sampling. It can be `None` if :attr:`length` is specified
num_replicas (int, optional): Number of processes participating in
distributed training (default: `None`)
rank (int, optional): Rank of the current process within :attr:`num_replicas` (default: `None`)
shuffle (bool, optional): If `True` sampler will shuffle the indices (default: `True`)
length (int, optional): length of `dataset`
If `None`, length is calculated as `len(dataset)` (default: `None`)
Must be greater than 0 when :attr:`dataset` is an `IterableDataset`
padding (bool, optional): If `True`, the returned lists will be padded to have the same length
(default: `True`)
"""

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True):
Collaborator (@ssnl) commented:

`IterableDataset` can still implement `__len__`. I think it makes more sense for this sampler to assume that the dataset has `__len__` than to take an explicit input argument.

Collaborator Author (@thiagocrepaldi) replied:

@ssnl if the dataset can implement `__len__`, then it makes more sense to extend `Dataset` rather than `IterableDataset`.

This PR tries to add the ability to do distributed training when the number of samples is unknown. `IterableDataset` allows this concept, but for distributed training the sampler needs a hint about how many chunks this unknown number of samples is split into.
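
For context, a toy contrast of the two cases being discussed (both classes are hypothetical, not part of this PR):

```python
import torch

# An IterableDataset that does implement __len__ (the reviewer's point):
# the sampler could keep relying on len(dataset).
class SizedStream(torch.utils.data.IterableDataset):
    def __init__(self, n):
        self.n = n
    def __iter__(self):
        return iter(range(self.n))
    def __len__(self):
        return self.n

# An IterableDataset whose size is unknown until fully consumed (the
# author's scenario): only an external hint such as the new `length`
# argument can tell the sampler how to split the work.
class UnsizedStream(torch.utils.data.IterableDataset):
    def __init__(self, path):
        self.path = path  # e.g. a large file read lazily
    def __iter__(self):
        with open(self.path) as f:
            for line in f:
                yield line.rstrip("\n")
```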

if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -36,23 +41,30 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))

self._dataset_length = length
if not isinstance(length, int) or length <= 0:
self._dataset_length = len(dataset)

self.num_samples = int(math.ceil(self._dataset_length * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.padding = padding

def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
indices = torch.randperm(self._dataset_length, generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
indices = list(range(self._dataset_length))


# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
if self.padding:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
@@ -63,5 +75,10 @@ def __iter__(self):
def __len__(self):
return self.num_samples

def set_rank(self, rank):
assert rank >= 0, 'rank must be >= 0'
assert rank < self.num_replicas, 'rank must < num_replicas'
self.rank = rank

def set_epoch(self, epoch):
self.epoch = epoch
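
To see how `padding` and the rank stride in `__iter__` interact, here is a standalone re-derivation of the index math (plain Python mirroring the diff above with `shuffle` off, not the module itself):

```python
import math

def split_indices(length, num_replicas, rank, padding=True):
    # Mirrors the sampler's arithmetic for a fixed, unshuffled index list.
    num_samples = int(math.ceil(length / num_replicas))
    total_size = num_samples * num_replicas
    indices = list(range(length))
    if padding:
        # Repeat leading indices so every rank receives num_samples items.
        indices += indices[:total_size - len(indices)]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]

# 10 samples over 3 replicas: padding duplicates indices 0 and 1 so the
# split is even; padding=False leaves the last rank one index short.
print(split_indices(10, 3, 0))                 # [0, 3, 6, 9]
print(split_indices(10, 3, 2))                 # [2, 5, 8, 1]
print(split_indices(10, 3, 2, padding=False))  # [2, 5, 8]
```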
5 changes: 3 additions & 2 deletions torch/utils/data/distributed.pyi
@@ -3,7 +3,8 @@ from . import Sampler, Dataset

T_co = TypeVar('T_co', covariant=True)
class DistributedSampler(Sampler[T_co]):
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ...
def __iter__(self) -> Iterator[int]: ...
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=..., shuffle: Optional[bool]=..., length: Optional[int]=..., padding: Optional[bool]=...): ...
def __iter__(self) -> Iterable[int]: ...
def __len__(self) -> int: ...
def set_rank(self, rank: int) -> None: ...
def set_epoch(self, epoch: int) -> None: ...