Add length and padding keywords to DistributedSampler #28841

Closed
Changes from 1 commit
Address comments
Thiago Crepaldi committed Oct 29, 2019
commit 77c2046853e9d3bd7b9e4240845b18359b9d8abf
20 changes: 14 additions & 6 deletions torch/utils/data/distributed.py
@@ -25,7 +25,11 @@ class DistributedSampler(Sampler):
If `None`, length is calculated as `len(dataset)` (default: `None`)
Must be greater than 0 when :attr:`dataset` is an `IterableDataset`
padding (bool, optional): If `True`, the returned lists will be padded to have the same length
(default: `True`)
Padding is done by adding (duplicate) indices from the beginning of the index list
into the end of it in a circular fashion. (default: `True`)

.. note::
Regardless of the :attr:`padding` value, :meth:`__len__` will return `ceil(dataset length / num_replicas)`
"""

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True):
Collaborator:

IterableDataset can still implement __len__. I think it makes more sense for this sampler to assume that the dataset has __len__ than having an explicit input argument.
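
(For context on this point, an `IterableDataset` that also defines `__len__` could look like the following sketch; the class name and contents are hypothetical, not from the PR:)

```python
from torch.utils.data import IterableDataset

class CountingStream(IterableDataset):
    """Hypothetical IterableDataset that also knows how many items it yields."""
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        return iter(range(self.n))

    def __len__(self):  # an IterableDataset may still define __len__
        return self.n
```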

Collaborator Author:

@ssnl If the dataset can implement `__len__`, then it makes more sense to extend `Dataset` as opposed to `IterableDataset`.

This PR tries to add the ability to do distributed training when the number of samples is unknown. `IterableDataset` allows this concept, but for distributed training the sampler needs a hint on how many chunks this unknown number of samples should be split into.
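
A rough usage sketch of the API proposed in this PR (argument names taken from the diff above; this is not yet a released interface, and the numbers are made up for illustration):

```python
from torch.utils.data.distributed import DistributedSampler

# Sketch of the proposed API: when the number of samples is unknown
# (e.g. an IterableDataset without __len__), pass an explicit `length` hint
# instead of a dataset, so the sampler knows how many index chunks to build.
sampler = DistributedSampler(
    dataset=None,      # no map-style dataset to take len() of
    num_replicas=4,
    rank=0,
    length=1000,       # hint: assumed upper bound on the sample count
    padding=True,      # duplicate indices so every rank gets the same count
)
print(len(sampler))    # ceil(1000 / 4) == 250
```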

@@ -42,9 +46,15 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True):
self.rank = rank
self.epoch = 0

self._dataset_length = length
if not isinstance(length, int) or length <= 0:
if dataset is None and length is None:
raise RuntimeError("Either :attr:`dataset` or :attr:`length` must be specified (not `None`)")

if isinstance(length, int) and length > 0:
self._dataset_length = length
elif length is None:
self._dataset_length = len(dataset)
else:
raise RuntimeError("When specified, :attr:`length` must be a strictly positive integer")

self.num_samples = int(math.ceil(self._dataset_length * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
@@ -73,9 +83,7 @@ def __iter__(self):
return iter(indices)

def __len__(self):
if self.padding:
return self.num_samples
return self._dataset_length
return self.num_samples

def set_rank(self, rank):
assert rank >= 0, 'rank must be >= 0'