
Python dataloader Improvements #13023

Open · 2 of 7 tasks
ssnl opened this issue Oct 23, 2018 · 6 comments
Labels: module: dataloader (Related to torch.utils.data.DataLoader and Sampler) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ssnl (Collaborator) commented Oct 23, 2018

@goldsborough and I are planning a series of improvements to the dataloader in both the C++ and Python APIs. This issue mainly focuses on the planned changes for the Python API.

  • Iterator Serialization ([feature request] Savable data loader/iterator #11813)

    The dataloader iterator is currently not picklable, due to the multiprocessing and multithreading attributes it holds. We should make it picklable as long as the dataset and the Sampler iterator are picklable; e.g., the __getstate__ could be (a fuller sketch appears right after this list)

      def __getstate__(self):
          return (self.loader, self.sampler_iter, self.base_seed)

    We will also make the iterators of the provided samplers serializable.

  • Examples of Bulk Loading

    The current data loading API seems to suggest that the dataloader is mainly suited to creating batches from random reads of the dataset. However, it supports bulk loading very well. For instance, this gist implements sharded/chunked bulk loading in just 40 lines (a minimal sketch of the same pattern appears in the discussion below). We will improve the documentation to include examples of such cases.

  • Worker Load Configurations

    We currently balance the load of workers by keeping the number of outstanding tasks per worker balanced. This can be a problem when the per-task workload is uneven, so we should make this behavior optional. Additionally, the maximum number of outstanding tasks (currently 2 * num_workers) should also become configurable.

  • Expose Sampler iterator ([feature request] [PyTorch] Dynamic Samplers. #7359)

    This would enable dynamic updates to the Sampler iterator state, e.g., dynamic reweighting of the samples. The API may be loader_iter.get_sampler_iter(). Since we always prefetch some number of batches, we also need to augment the existing documentation to reflect that this iterator may be ahead of the latest value returned by the data loader iterator.

    Edit: As @apaszke pointed out below, it is possible to allow for strict consistency by providing an interface that flushes the pipeline and asks the sampler iterator to produce new indices based on the updated state. But that design needs further consideration, and we don't plan to pursue it until there is an immediate need.

  • More Flexible worker_init_fn

    Currently, worker_init_fn takes only a single worker_idx argument, making it very difficult to initialize each worker's dataset object differently. Furthermore, it is impossible for dataset.__getitem__ in workers to communicate with the Sampler iterator to fetch more indices or to update the iterator state. We plan to augment its input arguments to include a wider range of objects it can access, without breaking BC and in a future-proof way.

    I haven't given much thought to the API design of this. But as a proof of concept, the API could be a get_worker_init_fn_arg argument, which would be called in the main process, take in a data_loader_iter_info "struct" containing fields referencing the dataset and the sampler_iter (and maybe more), and return a serializable object to be fed in as an additional argument of worker_init_fn in worker processes. Please let me know if you have suggestions!

  • Iterator-style Dataset

    We don't necessarily need to have a sampler. By allowing an iterator-style Dataset (rather than a stateless mapping), the workers can do interesting things like backfilling. This is entirely supported as of today, but we will make it nicer.

  • Bridging C++ and Python DataLoaders

    We will be providing a simple way to convert a C++ DataLoader into a Python one, with the same API as the existing Python DataLoader.
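
As promised under the first item, here is a minimal sketch of the pickling hooks. The attribute names follow the snippet in that item; _init_workers is a hypothetical helper standing in for however the multiprocessing state would be rebuilt, not a settled design:

    class _DataLoaderIter:
        def __getstate__(self):
            # Persist only the picklable pieces; workers, queues, and other
            # multiprocessing state are deliberately left out.
            return (self.loader, self.sampler_iter, self.base_seed)

        def __setstate__(self, state):
            self.loader, self.sampler_iter, self.base_seed = state
            # The multiprocessing machinery was not serialized and has to be
            # re-created on load.
            self._init_workers()  # hypothetical helper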

Our Plan to Make These Happen

I (@ssnl) will be focusing on the first four items while @goldsborough will implement the fifth. In addition to these, @goldsborough is also adding a bunch of exciting features into the C++ DataLoader to allow for greater flexibility (e.g., see #12960 #12999).

Let us know your thoughts and suggestions!

cc @soumith @fmassa @apaszke

apaszke (Contributor) commented Oct 30, 2018

A few questions regarding your points:

Iterator Serialization

Do you expect the iterator to resume from the beginning or from the particular place at which it was serialized? If you thought about the first case, why would we serialize the iterator instead of the data loader?

Examples of Bulk Loading

Yes, please. We might want to add code like that to the core as well.
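
For concreteness, a minimal sketch of the chunked bulk-loading pattern the gist above implements; ChunkDataset and the load_rows I/O helper are illustrative names under assumed semantics, not an existing API:

    from torch.utils.data import Dataset, DataLoader

    class ChunkDataset(Dataset):
        # Each "index" addresses a whole chunk of samples, so one __getitem__
        # call performs a single sequential bulk read rather than many random reads.
        def __init__(self, path, num_chunks, chunk_size):
            self.path = path
            self.num_chunks = num_chunks
            self.chunk_size = chunk_size

        def __len__(self):
            return self.num_chunks

        def __getitem__(self, chunk_idx):
            start = chunk_idx * self.chunk_size
            return load_rows(self.path, start, self.chunk_size)  # hypothetical I/O helper

    loader = DataLoader(
        ChunkDataset("data.bin", num_chunks=100, chunk_size=4096),
        batch_size=1,                         # each sample is already a full chunk
        num_workers=4,
        collate_fn=lambda chunks: chunks[0],  # unwrap the singleton list
    )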

Expose Sampler iterator

Sounds reasonable. The only downside I see with this approach is that it's impossible to do that while keeping the nice for-loop syntax, since Python calls iter() for you in that case, but I guess that's fine. I'm not sure if returning self from __iter__ is standard practice or not.

NB implementing a flag that enforces strict consistency is possible too (we could flush the pipeline, and retry with new sample indices), but it's not all that simple, so I'd wait until it's really needed before we proceed.

More flexible worker_init_fn

I don't think I fully understand the API. It seems to me that it would be much simpler to just accept an extra argument in the data loader, which would be a list of length equal to num_workers, where every element is a tuple of args to append to the worker_init_fn call.

Bridging C++ and Python DataLoaders

That is currently very hard because they have different semantics (C++ data loaders can only have a single iterator at any time, as they heavily rely on mutation of internal state). IMHO we should make C++ iterators behave exactly like those in Python.

@goldsborough is also adding a bunch of exciting features into the C++ DataLoader to allow for greater flexibility (e.g., see #12960 #12999)

I'd really like us to discuss those changes more before we introduce them, because currently we're making Python and C++ worlds more and more different, which contradicts the previous point.

ssnl (Collaborator, Author) commented Oct 31, 2018

Do you expect the iterator to resume from the beginning or from the particular place at which it was serialized? If you thought about the first case, why would we serialize the iterator instead of the data loader?

The second. However, considering that we cache out-of-order batches, this may not work perfectly.

I'm not sure if returning self from __iter__ is standard practice or not.

It's used a lot, but it's not standard, so we can't rely on it.

NB implementing a flag that enforces strict consistency is possible too (we could flush the pipeline, and retry with new sample indices), but it's not all that simple, so I'd wait until it's really needed before we proceed.

I agree.

just accept an extra argument in the data loader, which would be a list of length equal to num_workers, where every element is a tuple of args to append to the worker_init_fn call.

This solves only one of the two problems. The worker_init_fn you described still doesn't have access to the dataset object, which it may need to configure differently for different workers.

ssnl self-assigned this Oct 31, 2018
apaszke (Contributor) commented Nov 1, 2018

The second. However, considering that we cache out-of-order batches, this may not work perfectly.

Exactly. Those are very, very stateful objects that do a lot of preprocessing in the background, so I'm not sure whether that's even feasible or useful in many cases. I'd err on the side of simplicity and avoid implementing it.

It's used a lot, but not standard so we can't rely on that.

I'm not sure if I understand. Can you please elaborate?

The worker_init_fn you described still doesn't have access to the dataset object ...

Can you also please elaborate on the API and maybe post an example? I'm not exactly sure how someone would use it based on your description.

ssnl (Collaborator, Author) commented Nov 2, 2018

Exactly. Those are very, very stateful objects that do a lot of preprocessing in the background, so I'm not sure whether that's even feasible or useful in many cases. I'd err on the side of simplicity and avoid implementing it.

Being able to pause and resume is still very useful, especially when you have a huge number of samples per epoch. Since we are already able to serialize tensors, maybe this isn't too big a problem. I'll investigate and see.

It's used a lot, but it's not standard, so we can't rely on it.
I'm not sure if I understand. Can you please elaborate?

I meant that returning self from __iter__ can't be relied upon. And yes, you can't use the for-loop syntax in this case, but I would say that if you need advanced options like this, asking users to work with the iterator object explicitly is reasonable :).

Can you also please elaborate on the API and maybe post an example? I'm not exactly sure how would someone use it based on your description.

We abandoned that idea. The new idea is to introduce a function, maybe called torch.utils.data.get_worker_id, which returns the worker id in worker processes and raises an error in the main process.
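
A minimal sketch of how such a helper could be used from worker_init_fn, written against torch.utils.data.get_worker_info, the form under which this proposal eventually shipped per the commits referenced below; the shard and num_shards dataset fields are hypothetical:

    import torch.utils.data

    class ConfigurableDataset(torch.utils.data.Dataset):
        # hypothetical dataset whose reads depend on per-worker fields
        def __init__(self):
            self.shard, self.num_shards = 0, 1
        def __len__(self):
            return 8
        def __getitem__(self, i):
            return (self.shard, i)

    def worker_init_fn(worker_id):
        # Only meaningful inside a worker process; exposes the worker id
        # and that worker's own copy of the dataset.
        info = torch.utils.data.get_worker_info()
        info.dataset.shard = info.id
        info.dataset.num_shards = info.num_workers

    loader = torch.utils.data.DataLoader(
        ConfigurableDataset(), num_workers=4, worker_init_fn=worker_init_fn)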

facebook-github-bot pushed a commit that referenced this issue Dec 4, 2018
Summary:
As I am working on tasks in #13023, I realized how unreadable the code is, because all functions to be run in multiprocessing must be defined at the top (module) level. Adding more functionality to `dataloader.py` will only make things worse.

So in this PR, I refactor `dataloader.py` and move much of it into `data._utils`. E.g., the `_worker_loop` and related methods are now in `data._utils.worker`, signal-handling code in `data._utils.signal_handling`, collating code in `data._utils.collate`, etc. This split, IMHO, makes the code much clearer. I will base my future changes to DataLoader on top of it.

No functionality is changed, except that I added `torch._six.queue`.
Pull Request resolved: #14668

Reviewed By: soumith

Differential Revision: D13289919

Pulled By: ailzhang

fbshipit-source-id: d701bc7bb48f5dd7b163b5be941a9d27eb277a4c
facebook-github-bot pushed a commit that referenced this issue Dec 19, 2018
Summary:
Same as #14668, and was approved there.

ailzhang, please apply this patch to Horizon's `data_streamer.py`: https://gist.github.com/SsnL/020fdb3d6b7016d81b6ba1d04cc41459 Thank you!

(The original description from #14668 is reproduced in the commit message above.)
Pull Request resolved: #15331

Reviewed By: yf225

Differential Revision: D13503120

Pulled By: ailzhang

fbshipit-source-id: 94df16b4d80ad1102c437cde0d5a2e62cffe1f8e
facebook-github-bot pushed a commit to facebookresearch/ReAgent that referenced this issue Jun 21, 2019
Summary:
This is a modified version of pytorch/pytorch#14705, since the commit structure for that PR is quite messy.

1. Add `IterableDataset`, giving us two data loader modes: `Iterable` if the `dataset` is an instance of `IterableDataset`, and `Map` otherwise.
2. Add better support for non-batch loading (i.e., `batch_size=None` and `batch_sampler=None`). This is useful for things like bulk loading.
3. Refactor `DataLoaderIter` into two classes, `_SingleProcessDataLoaderIter` and `_MultiProcessingDataLoaderIter`, renaming some methods to be more generic, e.g., `get_batch` -> `get_data`.
4. Add `torch.utils.data.get_worker_info`, which returns worker information in a worker process (e.g., worker id, dataset object copy, etc.) and can be used in `IterableDataset.__iter__` and `worker_init_fn` for per-worker configuration.
5. Add `ChainDataset`, the analog of `ConcatDataset` for `IterableDataset`.
6. Import torch.utils.data in `torch/__init__.py`.
7. Add data loader examples and documentation.
8. Use `get_worker_info` to detect whether we are in a worker process in `default_collate`.

Closes pytorch/pytorch#17909, pytorch/pytorch#18096, pytorch/pytorch#19946, and some of pytorch/pytorch#13023
Pull Request resolved: pytorch/pytorch#19228

Reviewed By: bddppq

Differential Revision: D15058152

fbshipit-source-id: 9e081a901a071d7e4502b88054a34b450ab5ddde
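
To illustrate the `IterableDataset` plus `get_worker_info` combination this commit describes, a minimal sketch; the class name and the range-splitting scheme are illustrative, following the pattern in the commit's description:

    import math
    from torch.utils.data import IterableDataset, DataLoader, get_worker_info

    class RangeIterableDataset(IterableDataset):
        def __init__(self, start, end):
            self.start, self.end = start, end

        def __iter__(self):
            info = get_worker_info()
            if info is None:          # single-process loading: yield everything
                lo, hi = self.start, self.end
            else:                     # in a worker: claim only this worker's slice
                per_worker = math.ceil((self.end - self.start) / info.num_workers)
                lo = self.start + info.id * per_worker
                hi = min(lo + per_worker, self.end)
            return iter(range(lo, hi))

    # batch_size=None streams items through unbatched (the non-batch loading mode)
    loader = DataLoader(RangeIterableDataset(0, 10), num_workers=2, batch_size=None)
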
zou3519 added the module: dataloader and triaged labels Jul 8, 2019
byronyi commented Jun 12, 2020

Most parts of the improvement plan seem to have stalled while still being greatly needed. @ssnl has left FB and at best works on PyTorch part time.

Any updates from the team?

VitalyFedyunin (Contributor) commented

@byronyi we are treating DataLoader as a high-priority task for this half; I will publish a new DataLoader improvements RFC approximately next week.
