Add IterableDataset (#19228)
Summary:
This is a modified version of pytorch/pytorch#14705, since the commit structure of that PR is quite messy.

1. Add `IterableDataset`.
2. So we now have two data-loading modes, `Iterable` and `Map` (see the sketch after this list):

    1. `Iterable` if the `dataset` is an instance of `IterableDataset`.
    2. `Map` otherwise.

3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful for things like bulk loading.
4. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
5. Add `torch.utils.data.get_worker_info`, which returns worker information in a worker process (e.g., worker id, dataset object copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
6. Add `ChainDataset`, the analog of `ConcatDataset` for `IterableDataset`.
7. Import `torch.utils.data` in `torch/__init__.py`.
8. Add data loader examples and documentation.
9. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`.
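
To illustrate the pieces above, a minimal sketch of the new API, loosely following the examples this PR adds to the docs (the `RangeDataset` name and the concrete ranges are ours for illustration):

    from torch.utils.data import (
        ChainDataset, DataLoader, IterableDataset, get_worker_info,
    )

    class RangeDataset(IterableDataset):
        """Streams integers in [start, end)."""

        def __init__(self, start, end):
            self.start = start
            self.end = end

        def __iter__(self):
            info = get_worker_info()
            if info is None:
                # Main process (num_workers == 0): yield the whole range.
                lo, hi = self.start, self.end
            else:
                # Worker process: use the worker id to shard the range so
                # that workers do not emit duplicate samples.
                per_worker = (self.end - self.start + info.num_workers - 1) // info.num_workers
                lo = self.start + info.id * per_worker
                hi = min(lo + per_worker, self.end)
            return iter(range(lo, hi))

    # batch_size=None selects the new non-batch ("bulk") loading path:
    # the loader yields individual samples, not collated batches.
    loader = DataLoader(RangeDataset(0, 6), num_workers=2, batch_size=None)
    print(sorted(loader))  # [0, 1, 2, 3, 4, 5]; worker outputs may interleave

    # ChainDataset concatenates IterableDatasets lazily, the way
    # ConcatDataset concatenates map-style datasets.
    chained = ChainDataset([RangeDataset(0, 3), RangeDataset(3, 6)])
    print(sorted(DataLoader(chained, batch_size=None)))  # [0, 1, 2, 3, 4, 5]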

Closes pytorch/pytorch#17909, pytorch/pytorch#18096, pytorch/pytorch#19946, and some of pytorch/pytorch#13023
Pull Request resolved: pytorch/pytorch#19228

Reviewed By: bddppq

Differential Revision: D15058152

fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
ssnl authored and facebook-github-bot committed Jun 21, 2019
1 parent 87d3519 commit 00e078b
Showing 1 changed file with 7 additions and 7 deletions.
ml/rl/readers/data_streamer.py (14 changes: 7 additions & 7 deletions)
@@ -105,7 +105,7 @@ def _worker_loop(
 
 def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id):
     """
-    This is copied from dataloader. It uses a different `pin_memory_batch()`.
+    This is copied from dataloader. It uses a different `pin_memory()`.
     It'd probably be best to merge.
     """
     if pin_memory:
@@ -126,14 +126,14 @@ def _pin_memory_loop(in_queue, out_queue, done_event, pin_memory, device_id):
         idx, batch = r
         try:
             if pin_memory:
-                batch = pin_memory_batch(batch)
+                batch = pin_memory(batch)
         except Exception:
             out_queue.put((idx, ExceptionWrapper(sys.exc_info())))
         else:
             out_queue.put((idx, batch))
 
 
-def pin_memory_batch(batch):
+def pin_memory(batch):
     """
     This is ripped off from dataloader. The only difference is that it preserves
     the type of Mapping so that the OrderedDict is maintained.
@@ -144,13 +144,13 @@ def pin_memory_batch(batch):
         return batch
     elif isinstance(batch, NamedTuple) or hasattr(batch, "_asdict"):
         return type(batch)(
-            **{name: pin_memory_batch(value) for name, value in batch._asdict().items()}
+            **{name: pin_memory(value) for name, value in batch._asdict().items()}
        )
     elif isinstance(batch, collections.Mapping):
         # NB: preserving OrderedDict
-        return type(batch)((k, pin_memory_batch(sample)) for k, sample in batch.items())
+        return type(batch)((k, pin_memory(sample)) for k, sample in batch.items())
     elif isinstance(batch, collections.Sequence):
-        return [pin_memory_batch(sample) for sample in batch]
+        return [pin_memory(sample) for sample in batch]
     else:
         return batch
 
@@ -249,7 +249,7 @@ def __next__(self):
         if self.num_workers == 0:  # same-process loading
             batch = next(self.data_reader_iter)  # May raise StopIteration
             if self.pin_memory:
-                batch = pin_memory_batch(batch)
+                batch = pin_memory(batch)
             return batch
 
         if self.batches_outstanding == 0:
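
For context on the renamed helper, a hedged usage sketch. It assumes the elided top of `pin_memory` pins tensors via `Tensor.pin_memory()` as the stock dataloader helper does (which requires a CUDA build); the key names are made up:

    import collections

    import torch

    batch = collections.OrderedDict(
        [("state", torch.zeros(4, 3)), ("action", torch.zeros(4, 1))]
    )
    if torch.cuda.is_available():  # page-locked memory needs CUDA support
        pinned = pin_memory(batch)  # helper from the diff above
        assert type(pinned) is collections.OrderedDict  # Mapping type preserved
        assert pinned["state"].is_pinned()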
