[wip, test CI] Add IterableDataset #14705
Conversation
Can you please elaborate on the final design of iterable datasets? As I've mentioned in some PRs to the C++ Dataset API, I think that our current DataLoader story is completely mismatched with the desire to have mutable datasets (including iterable datasets). In particular, you can't simply replicate the dataset and have all workers produce the same samples, because that would be silly. Similarly, the Sampler is completely useless in this case. I really think we should work on alternative solutions for this API.
def __init__(self, sizes_for_all_workers):
    self.sizes_for_all_workers = sizes_for_all_workers

def __iter__(self):
@apaszke Basically you can configure each dataset replica differently in two different ways using `torch.utils.data.get_worker_info()`, which returns the worker's `id`, `seed`, and `dataset` replica:
- In `__iter__` for an iterable dataset (e.g., this example).
- In `worker_init_fn`.
This requires users to write their dataset code with multiprocessing data loading in mind. But I think this is a reasonable requirement because (1) there is no general way to split an iterator across multiple processes, and (2) if people want to do sharding / bulk loading, how to split work among workers should be part of the design.
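To make the per-worker configuration concrete, here is a minimal sketch (plain Python, no torch dependency) of the sharding arithmetic an `IterableDataset.__iter__` could perform. The `worker_id` and `num_workers` parameters stand in for the fields that `torch.utils.data.get_worker_info()` would provide inside a worker process; `worker_shard` is a hypothetical helper, not part of the API.

```python
def worker_shard(start, end, worker_id, num_workers):
    """Return the half-open [lo, hi) slice of [start, end) that this
    worker should iterate, so workers cover disjoint parts of the stream."""
    per_worker = -(-(end - start) // num_workers)  # ceil division
    lo = start + worker_id * per_worker
    hi = min(lo + per_worker, end)
    return lo, hi

# With 3 workers over range(10), the shards partition the stream:
shards = [worker_shard(0, 10, w, 3) for w in range(3)]
samples = [i for lo, hi in shards for i in range(lo, hi)]
assert samples == list(range(10))  # every sample produced exactly once
```

Inside a real `__iter__`, the dataset would call `get_worker_info()`, fall back to iterating everything when it returns `None` (single-process loading), and otherwise iterate only its own shard.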
Haven't read the new multiprocess iterator, but I had a few questions
@@ -70,47 +79,61 @@ class DataLoader(object):
     __initialized = False

     def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
-                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
+                 batch_sampler=None, num_workers=0,
+                 convert_fn=_utils.collate.default_convert,
Why can't the conversion be handled by `collate_fn`?
With `collate_fn`, the input is assumed to be a list of many data samples, and you want to collate each field in the data structure of each element. E.g., `[([1, 2], [3, 4]), ([5, 6], [7, 8])]` should become `(tensor([[1, 2], [5, 6]]), tensor([[3, 4], [7, 8]]))`. It does both elementwise conversion and collation.
However, when the dataset is an iterable, we want to only do conversion, treating the entire input as a single data sample. E.g., `[([1, 2], [3, 4]), ([5, 6], [7, 8])]` should become a single `tensor([[[1, 2], [5, 6]], [[3, 4], [7, 8]]])`.
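To illustrate the distinction without a torch dependency, here is a toy sketch using lists in place of tensors. `toy_collate` and `toy_convert` are illustrative stand-ins I made up for this thread, not the real `default_collate` / `default_convert` implementations (which additionally turn the leaves into tensors).

```python
def toy_collate(batch):
    """Stack corresponding fields across the samples in a batch:
    a list of tuples becomes a tuple of stacked fields."""
    return tuple(list(field) for field in zip(*batch))

def toy_convert(data):
    """Treat the entire input as one sample: no per-field stacking.
    (The real default_convert would also convert leaves to tensors.)"""
    return list(data)

batch = [([1, 2], [3, 4]), ([5, 6], [7, 8])]
# Collation regroups by field across samples...
assert toy_collate(batch) == ([[1, 2], [5, 6]], [[3, 4], [7, 8]])
# ...while conversion leaves the per-sample structure alone.
assert toy_convert(batch) == [([1, 2], [3, 4]), ([5, 6], [7, 8])]
```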
     else:
-        return batch
+        return data
Why the rename?
Because the data loader can be used to load individual samples as well, and this PR provides better support for that use case. `batch` seems to assume that the data is always batched.
torch/utils/data/dataloader.py
Outdated
if self.mode == DataLoaderMode.Map:
    data = self.convert_fn(self.dataset[index])
else:
    # mode == DataLoaderMode.MapWithBatchedRead:
Why do we need both `Map` and `MapWithBatchedRead`, when we only had a single mode before?
Because previously there was no easy way to do unbatched loading from a map-style dataset. One had to use `batch_size=1` and provide a custom `collate_fn` that is just `lambda x: x[0]`. That feels counterintuitive to me. With all the other changes in this PR, this was a simple addition, so I made it.
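To see why the old workaround was needed, here is a plain-Python sketch of the loading loop. `fake_loader` is a hypothetical stand-in for the DataLoader's inner fetch-and-collate step, not the real implementation: with `batch_size=1`, `collate_fn` always receives a one-element list, so `lambda x: x[0]` exists purely to unwrap it.

```python
def fake_loader(dataset, batch_size, collate_fn):
    """Toy stand-in for a DataLoader loop: slice the dataset into
    batches and pass each batch through collate_fn."""
    for i in range(0, len(dataset), batch_size):
        yield collate_fn(dataset[i:i + batch_size])

data = ["a", "b", "c"]
# Old workaround: batch_size=1 plus an unwrapping collate_fn
# recovers individual samples.
assert list(fake_loader(data, 1, lambda x: x[0])) == ["a", "b", "c"]
# Without the custom collate_fn, every "sample" arrives wrapped
# in a singleton list.
assert list(fake_loader(data, 1, lambda x: x)) == [["a"], ["b"], ["c"]]
```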
Another reason is to provide better support for the case where the sampler provides a list of indices, and the dataset does bulk loading.
@apaszke There is no new multiprocessing iterator design. I just refactored it into two classes, each handling single- or multi-process loading. The logic is exactly the same.
@pytorchbot retest this please
Add `ChainDataset` (the analog of `ConcatDataset` but for `IterableDataset`) and tests. Add doc entries. Make `torch.utils.data.*` usable without another import after `import torch`. Make `IterableDataset` return `NotImplemented` in `__len__` so fallbacks of some functions work.
close in favor of #19228 |
Summary: This is a modified version of pytorch/pytorch#14705 since the commit structure of that PR is quite messy.
1. Add `IterableDataset`.
2. So we have two data loader modes: `Iterable` and `Map`.
   - `Iterable` if the `dataset` is an instance of `IterableDataset`.
   - `Map` otherwise.
3. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful for doing things like bulk loading.
4. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
5. Add `torch.utils.data.get_worker_info`, which returns worker information in a worker process (e.g., worker id, dataset object copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
6. Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`.
7. Import `torch.utils.data` in `torch/__init__.py`.
8. Add data loader examples and documentation.
9. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`.
Closes pytorch/pytorch#17909, pytorch/pytorch#18096, pytorch/pytorch#19946, and some of pytorch/pytorch#13023.
Pull Request resolved: pytorch/pytorch#19228
Reviewed By: bddppq
Differential Revision: D15058152
fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
- Add `IterableDataset`.
- Support non-batched loading from traditional map-style datasets. This is useful for doing bulk loading.
- So we have three data loader modes: `Iterable` (newly added), `Map` (newly added), and `MapWithBatchedRead` (old).
  - `Iterable` if the `dataset` is an instance of `IterableDataset`.
  - `Map` if `batch_size` is `None`.
  - `MapWithBatchedRead` otherwise.
- Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`. Rename some methods to be more generic, e.g., `get_batch` -> `get_data`.
- Add `torch.utils.data.get_worker_info`, which returns worker information in a worker process (e.g., worker id, dataset object copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` to do per-worker configuration.
- Add `ChainDataset`, which is the analog of `ConcatDataset` for `IterableDataset`.
- Add `convert_fn`, which is meant to convert each loaded sample into tensors. In the `MapWithBatchedRead` mode, fetched data is first converted using `convert_fn` and then collated. This shouldn't be much slower (if at all) than the old approach of only using `collate_fn`.
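The chaining behavior described for `ChainDataset` can be sketched in plain Python. `ToyChainDataset` is an illustrative stand-in for the semantics only; the real `torch.utils.data.ChainDataset` subclasses `IterableDataset` and is handed the constituent iterable datasets in the same way.

```python
from itertools import chain

class ToyChainDataset:
    """Toy analog of ChainDataset: iterate the constituent iterable
    datasets one after another, without materializing them."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __iter__(self):
        # Yield every sample of the first dataset, then the second, etc.
        return chain.from_iterable(self.datasets)

combined = ToyChainDataset(range(3), ["x", "y"])
assert list(combined) == [0, 1, 2, "x", "y"]
```

This is the streaming counterpart of `ConcatDataset`: because iterable datasets may not support `len()` or random access, samples are consumed in order rather than indexed.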