
Add length and padding keywords to DistributedSampler #28841

Closed
51 changes: 39 additions & 12 deletions torch/utils/data/distributed.py
@@ -9,21 +9,30 @@ class DistributedSampler(Sampler):

It is especially useful in conjunction with
:class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
process can pass a DistributedSampler instance as a DataLoader sampler,
process can pass a `DistributedSampler` instance as a DataLoader sampler,
and load a subset of the original dataset that is exclusive to it.

.. note::
Dataset is assumed to be of constant size.

Arguments:
dataset: Dataset used for sampling.
num_replicas (optional): Number of processes participating in
distributed training.
rank (optional): Rank of the current process within num_replicas.
shuffle (optional): If true (default), sampler will shuffle the indices
dataset (Dataset): Dataset used for sampling. It can be `None` if :attr:`length` is specified.
num_replicas (int, optional): Number of processes participating in
distributed training (default: `None`)
rank (int, optional): Rank of the current process within :attr:`num_replicas` (default: `None`)
shuffle (bool, optional): If `True`, sampler will shuffle the indices (default: `True`)
length (int, optional): Length of `dataset`.
If `None`, length is calculated as `len(dataset)` (default: `None`).
Must be greater than 0 when :attr:`dataset` is an `IterableDataset`.
padding (bool, optional): If `True`, the returned lists will be padded to have the same length.
Padding is done by adding (duplicate) indices from the beginning of the index list
into the end of it in a circular fashion. (default: `True`)

.. note::
Regardless of the value of :attr:`padding`, :meth:`__len__` will return `ceil(dataset length / num_replicas)`.
"""

def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True, length=None, padding=True):
Collaborator: `IterableDataset` can still implement `__len__`. I think it makes more sense for this sampler to assume that the dataset has `__len__` than having an explicit input argument.
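
(To make this point concrete, here is a minimal sketch, not part of the PR, of an `IterableDataset` that also defines `__len__`, so the existing `len(dataset)`-based sizing would keep working; the class name is illustrative.)

```python
from torch.utils.data import IterableDataset

class CountedStream(IterableDataset):
    """A streaming dataset that nevertheless knows its size up front."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        # Items are produced lazily, as with any IterableDataset.
        return iter(range(self.n))

    def __len__(self):
        # Optional for IterableDataset, but enough for len(dataset)-based samplers.
        return self.n
```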

Collaborator (Author): @ssnl, if the dataset can implement `__len__`, then it makes more sense to extend `Dataset` as opposed to `IterableDataset`.

This PR tries to add the ability to do distributed training when the number of samples is unknown. `IterableDataset` allows this concept, but for distributed training the sampler needs a hint about how this unknown number of samples is split into chunks.
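
(A rough usage sketch of what the new keywords are meant to enable, assuming the `length`, `padding`, and `dataset=None` behaviour exactly as written in this diff; these arguments do not exist in released PyTorch, so this only runs against the PR branch.)

```python
from torch.utils.data.distributed import DistributedSampler

# The real data source is an unsized stream, so no Dataset object is handed to the
# sampler; instead the per-epoch sample count is declared explicitly via `length`.
sampler = DistributedSampler(dataset=None, num_replicas=4, rank=0,
                             shuffle=True, length=1000, padding=True)

print(len(sampler))            # ceil(1000 / 4) == 250, regardless of `padding`
indices = list(iter(sampler))  # the 250 indices this rank should consume from the stream
```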

if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
@@ -36,23 +45,36 @@ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))

if dataset is None and length is None:
raise RuntimeError("Either :attr:`dataset` or :attr:`length` must be specified (not `None`)")

if isinstance(length, int) and length > 0:
self._dataset_length = length
elif length is None:
self._dataset_length = len(dataset)
else:
raise RuntimeError("When specified, :attr:`length` must be a strictly positive integer")

self.num_samples = int(math.ceil(self._dataset_length * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
self.padding = padding

def __iter__(self):
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
indices = torch.randperm(self._dataset_length, generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
indices = list(range(self._dataset_length))


# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
if self.padding:
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size

# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
@@ -63,5 +85,10 @@ def __iter__(self):
def __len__(self):
return self.num_samples

def set_rank(self, rank):
assert rank >= 0, 'rank must be >= 0'
assert rank < self.num_replicas, 'rank must be < num_replicas'
self.rank = rank

def set_epoch(self, epoch):
self.epoch = epoch
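
(For reference, a self-contained sketch of the index arithmetic this diff implements: pad the index list to a multiple of `num_replicas`, then take a strided slice per rank. The helper name is made up and shuffling is omitted for clarity.)

```python
import math

def shard_indices(dataset_length, num_replicas, rank, padding=True):
    # Mirrors the __iter__ logic above, without shuffling, for illustration.
    num_samples = int(math.ceil(dataset_length / num_replicas))
    total_size = num_samples * num_replicas
    indices = list(range(dataset_length))
    if padding:
        # Wrap around: repeat indices from the start until the list divides evenly.
        indices += indices[:(total_size - len(indices))]
    # Every rank takes a strided slice, so the shards are disjoint.
    return indices[rank:total_size:num_replicas]

# 10 samples across 3 replicas: total_size is 12, so two indices get duplicated.
print(shard_indices(10, 3, 0))                 # [0, 3, 6, 9]
print(shard_indices(10, 3, 1))                 # [1, 4, 7, 0]
print(shard_indices(10, 3, 2))                 # [2, 5, 8, 1]
# With padding=False the later ranks simply receive shorter lists.
print(shard_indices(10, 3, 2, padding=False))  # [2, 5, 8]
```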
5 changes: 3 additions & 2 deletions torch/utils/data/distributed.pyi
@@ -3,7 +3,8 @@ from . import Sampler, Dataset

T_co = TypeVar('T_co', covariant=True)
class DistributedSampler(Sampler[T_co]):
def __init__(self, dataset: Dataset, num_replicas: Optional[int]=..., rank: Optional[int]=...): ...
def __iter__(self) -> Iterator[int]: ...
def __init__(self, dataset: Optional[Dataset], num_replicas: Optional[int]=..., rank: Optional[int]=..., shuffle: Optional[bool]=..., length: Optional[int]=..., padding: Optional[bool]=...): ...
def __iter__(self) -> Iterator[int]: ...
def __len__(self) -> int: ...
def set_rank(self, rank: int) -> None: ...
def set_epoch(self, epoch: int) -> None: ...