[RFC, Tracker] DataLoader improvements #41292

Open

VitalyFedyunin opened this issue Jul 11, 2020 · 19 comments
Labels
module: dataloader Related to torch.utils.data.DataLoader and Sampler · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@VitalyFedyunin
Contributor

VitalyFedyunin commented Jul 11, 2020

This issue was created to track all current problems of the DataLoader (and related components such as Dataset, Sampler, Transforms). It is focused on what we want to achieve; implementation details will follow after prioritizing.

Known problems

Documentation

  • Clear initialization process doc. Some users get confused about when the Dataset / DataLoader gets initialized. The order must be clear and understandable.

  • Clear multiprocessing / multithreading docs, with recipes (torchaudio is a good example), especially covering the cases where threading meets forking and where CUDA meets forking.

Benchmarking

  • To make sure that performance improvements are trackable and we are not introducing slowdowns, we need to write and document benchmark scripts and a methodology.

  • Add documentation on how to measure the performance of data loading, so users can identify bottlenecks (see the timing sketch after this list).

  • Ability to benchmark CPU vs. GPU transforms.
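A minimal sketch of such a measurement (assuming the loader yields at least warmup + num_batches batches), timing the DataLoader on its own so the result is not dominated by the model:

import time

def time_data_loading(loader, warmup=5, num_batches=100):
    # Iterate over the DataLoader alone, so the numbers reflect data-loading
    # throughput rather than model speed.
    it = iter(loader)
    for _ in range(warmup):          # skip worker start-up / first-batch cost
        next(it)
    start = time.perf_counter()
    for _ in range(num_batches):
        next(it)
    elapsed = time.perf_counter() - start
    print(f"{num_batches / elapsed:.1f} batches/s "
          f"({1e3 * elapsed / num_batches:.1f} ms/batch)")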

Improvements

Considering Improvements

  • C++ implementation of DataLoader components for performance / C++ parity

cc @ssnl
cc @vadimkantorov

@VitalyFedyunin VitalyFedyunin added the module: dataloader and triaged labels Jul 11, 2020
vincentqb referenced this issue in pytorch/audio Jul 16, 2020
@vadimkantorov
Contributor

Some random comments:

  1. The worker process lifecycle should be crystal-clear. Some external libraries need to do extra initialization (and sometimes teardown) once: free(): invalid pointer | data loader + torchaudio + SoxEffectsChain audio#271 (comment).

  2. For people not knowledgeable in linux process/thread model, it would be good to have some recipes in docs on those init things (torchaudio is a good example)

  3. Also it would be good to have in docs a rundown of common issues with tqdm/librosa/opencv related to forking/multi-threading/omp/cuda init.

  4. Another thing is that samplers may wish to support saving/restoring their state (maybe already done). This is tricky to get done, as it should be very clear what actually gets saved (see the sketch after this list).

  5. Another thing is that, with more transforms implemented as general PyTorch ops working on CUDA, it becomes less clear which ops should be done on the CPU and whether doing the transforms within the data loader benefits from multi-threading at all. Maybe some fresh benchmarks for vision workloads would be nice.

  6. I heard that at some point training on ImageNet on Linux was tricky because OS disk cache went into some bad state and a manual periodic reset of disk cache was required. Is it still relevant?

  7. Is CPU process pinning useful? (especially for academic machines where a single node with a few GPUs may be shared by a few people, all running data loaders with many threads) I saw people using it. Some docs / recommendations on this would be great

  8. It would be good to have some advice/recipes on measuring perf of data loading (and maybe measuring some of the issues above)
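On item 4, a minimal sketch of what a resumable sampler could look like (the state_dict / load_state_dict names are hypothetical, not part of today's Sampler API, and DataLoader index prefetching is ignored here):

import torch
from torch.utils.data import Sampler

class ResumableRandomSampler(Sampler):
    def __init__(self, data_source, seed=0):
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0
        self.position = 0

    def __iter__(self):
        # The permutation is derived from (seed, epoch), so it can be reproduced after a restore.
        g = torch.Generator().manual_seed(self.seed + self.epoch)
        order = torch.randperm(len(self.data_source), generator=g).tolist()
        for idx in order[self.position:]:
            self.position += 1
            yield idx
        # Epoch finished: advance the epoch counter and start over from the beginning.
        self.epoch += 1
        self.position = 0

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        # Being explicit about what gets saved: the seed, the epoch, and the position within it.
        return {"seed": self.seed, "epoch": self.epoch, "position": self.position}

    def load_state_dict(self, state):
        self.seed, self.epoch, self.position = state["seed"], state["epoch"], state["position"]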

@VitalyFedyunin
Contributor Author

Thanks @vadimkantorov for the detailed reply. I've updated the issue with it. See some comments inline.

Some random comments:

  1. The worker process lifecycle should be crystal-clear. Some external libraries need to do extra initialization (and sometimes teardown) once: pytorch/audio#271 (comment).

Captured.

  2. For people not knowledgeable in linux process/thread model, it would be good to have some recipes in docs on those init things (torchaudio is a good example)

Do you have a URL of a good example?

  3. Also it would be good to have in docs a rundown of common issues with tqdm/librosa/opencv related to forking/multi-threading/omp/cuda init.

Agree.

  4. Another thing is that samplers may wish to support saving/restoring their state (maybe already done). This is tricky to get done, as it should be very clear what actually gets saved.

Missed the Sampler state. Captured as a requirement.

  5. Another thing is that, with more transforms implemented as general PyTorch ops working on CUDA, it becomes less clear which ops should be done on the CPU and whether doing the transforms within the data loader benefits from multi-threading at all. Maybe some fresh benchmarks for vision workloads would be nice.

Captured

  6. I heard that at some point training on ImageNet on Linux was tricky because OS disk cache went into some bad state and a manual periodic reset of disk cache was required. Is it still relevant?

Haven't heard of it. Maybe someone else has info.

  7. Is CPU process pinning useful? (especially for academic machines where a single node with a few GPUs may be shared by a few people, all running data loaders with many threads) I saw people using it. Some docs / recommendations on this would be great

I haven't seen feature requests about pinning. How much of it needs to be a built-in feature compared to taskset?

  8. It would be good to have some advice/recipes on measuring perf of data loading (and maybe measuring some of the issues above)

Captured.

@vadimkantorov
Contributor

vadimkantorov commented Jul 16, 2020

Do you have url of good example?

I guess (1) and (2) are the same things :) I never finally tried doing atexit in pytorch/audio#271 (comment). torchaudio currently requires initialization to use sox: https://pytorch.org/audio/#torchaudio.initialize_sox. Until very recently, ffmpeg required initialization as well. I'm never sure if we must do this initialization in every worker or if initialization in the main process is enough.
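For the per-worker variant, a minimal sketch with the current DataLoader (dataset is a placeholder; note there is no symmetric per-worker teardown hook today):

import torchaudio
from torch.utils.data import DataLoader

def init_worker(worker_id):
    # Runs once in every worker process, after it is started and before it loads any samples.
    torchaudio.initialize_sox()

loader = DataLoader(dataset, batch_size=32, num_workers=4, worker_init_fn=init_worker)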

Haven't heard of it. Maybe someone else have info.

If I'm not mistaken, I heard about this as a battle story from @szagoruyko who shared this advice received from @soumith (again if I'm not mixing things up), maybe 4 years ago

I haven't seen feature requests about pinning. How much of it must be a feature in compare to taskset?

Yeah, I saw people using taskset sometimes (and that's what I meant by pinning), but I never saw anyone discussing what the benefits are and when you have to use it.
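For reference, the same effect can also be achieved from inside the DataLoader rather than via taskset; a minimal sketch for Linux (dataset and the core counts are placeholders):

import os
from torch.utils.data import DataLoader

CPUS_PER_WORKER = 2  # arbitrary, for illustration

def pin_worker(worker_id):
    # Linux-only: restrict this worker process to its own small slice of CPU cores,
    # similar in spirit to launching the whole job under taskset.
    first = worker_id * CPUS_PER_WORKER
    os.sched_setaffinity(0, range(first, first + CPUS_PER_WORKER))

loader = DataLoader(dataset, batch_size=32, num_workers=4, worker_init_fn=pin_worker)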

@VitalyFedyunin VitalyFedyunin changed the title [Draft, RFC, Tracker] DataLoader improvements [RFC, Tracker] DataLoader improvements Jul 16, 2020
@rwightman

Some great ideas and improvements listed here. Balancing I/O and CPU utilization is a big challenge with the current DataLoader design.

Chunking the data into shards (the webdata idea, tfrecords, riegeli, etc.) can help significantly with the I/O, but there is huge room for improvement even with individual files if the I/O is dispatched and coalesced efficiently.

I've spent time thinking about this, drawing on past experience building real-time video storage subsystems; moving the I/O layer down to C++ could be a big win. It's really hard to dispatch enough I/O requests simultaneously to fill SATA queues from synchronous Python calls like PIL.open(). From C/C++ it's fairly straightforward to flood controllers with I/O calls using reactor-style I/O multiplexing (select()/poll()) or proactor-style approaches (Windows overlapped I/O). Maybe some of this could be done in Python, but Python threading and async interfaces have always struck me as quite limiting and generally 'not worth the effort' given the ever-present GIL and the near impossibility of writing low-context-switch code (no mutexes, no mallocs) in Python.

However, the above makes the boundary between Python and C++ potentially quite awkward. Once you have the data hot from the I/O, it's best to do some decoding and possibly augmentation while you have it, but that could then result in a loss of the Python transform pipeline's flexibility.

@elistevens
Contributor

I'd like to have an option where a batch is collected from multiple loader processes, rather than all from a single process. The general use case is:

  • Have many files on disk, each file containing data for multiple samples. Files must be loaded, parsed, and stored in RAM in their entirety before any samples can be extracted.
  • Loading and parsing a file is expensive.
  • Extracting each sample after that is cheap.
  • Each loader process can only keep a handful of files in RAM at a time (fewer than the batch size).

Currently, this situation results in batches that have samples from a handful of files (the next batch comes from the next loader's handful), rather than from the entire set of files that can be kept in RAM, which reduces the variation inside a batch (see the sketch below).

The workaround is to extract and cache samples ahead of time, but that makes experimenting with the details of what a sample is expensive.
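A minimal sketch of this pattern (load_file is a hypothetical, expensive parser; the DataLoader wiring is illustrative), which also shows why each batch can only mix samples from the handful of files a single worker currently holds:

from torch.utils.data import IterableDataset, DataLoader, get_worker_info

class FileBackedDataset(IterableDataset):
    def __init__(self, file_paths, load_file):
        self.file_paths = file_paths
        self.load_file = load_file  # expensive: parses a whole file into a list of samples

    def __iter__(self):
        info = get_worker_info()
        # Each worker gets a disjoint subset of files, so a worker can only ever
        # assemble batches from the few files it has parsed into RAM itself.
        paths = self.file_paths if info is None else self.file_paths[info.id::info.num_workers]
        for path in paths:
            for sample in self.load_file(path):  # cheap once the file is in RAM
                yield sample

# With batch_size=B, every batch a worker emits is drawn from its current handful of files.
# loader = DataLoader(FileBackedDataset(paths, load_file), batch_size=64, num_workers=4)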

@rwightman

@elistevens sounds equivalent to the fairly common case you run into where you chunk data with some sort of compression .... parquet files, video codecs with IDR frame intervals, etc ... although sometimes you can at least start decoding with partial data (video) or skip some data because there is internal structure or external metadata about seek points, packets, etc

multi-threading in the same memory space is usually much more effective there than trying to do it across process boundaries with IPC or shared memory

@elistevens
Contributor

In my actual use case it was CT scans converted into .mhd files, read via https://simpleitk.readthedocs.io/en/master/IO.html, but I suspect it's a general usage pattern.

Not impossible to work around, obviously, but the current API doesn't help.

@VitalyFedyunin
Contributor Author

@rwightman @elistevens thank you for the input; we realize that both multi-processing and multi-threading are equally important for the framework, and we will look for options to make them interchangeable.

@vadimkantorov
Contributor

One more related dataloader issue: #22924 about "Make DataLoader return usable traces in the case of Ctrl+C and similar OS signals."

@vadimkantorov
Contributor

#13246: this would be a huge problem if the processes are kept around and not recreated.

@tmbdev

tmbdev commented Aug 5, 2020

We've developed an open source framework called Tensorcom (based on the WebDataset I/O system, though it can also be used with existing Dataset implementations). It can be used instead of DataLoader.

Worker processes are started up independently of the main training job; they can be started up before or after, as client or servers, and batching can happen either in the workers or the main job.

Worker processes don't have to run on the same machine as the deep learning job, and for hyperparameter searches, a single set of worker processes can broadcast training samples to a large number of training jobs.

Making worker processes explicit and separate from the main training job has a number of advantages: it simplifies testing and debugging, it allows performance tuning separate from the deep learning job, you can adjust the number of workers while the DL job is running, it makes I/O pipelines independent of the specific deep learning job, and it works really well with K8s (where worker processes and DL jobs can just be configured and scaled as separate ReplicaSets or StatefulSets).

Tensorcom can be used as a replacement for DataLoader (and the Tensorcom workers can be started up by the DL job), but more importantly, it can take advantage of RDMA and direct-to-GPU hardware where available.

We're currently still working on WebDataset integration into PyTorch. Tensorcom works very well and efficiently, but it's still largely undocumented and the APIs may still change. I just wanted to give a heads up.

@ppwwyyxx
Contributor

Some thoughts on batching:
We have the need to create data loaders that do not do batching. This is doable now, but not very elegant, as batching seems to be a pretty fundamental assumption (in the batch_sampler argument).
Also, the batch_sampler argument of torch.utils.data.DataLoader is supposed to return indices, and that's a limitation on how batching can be done. We often do dynamic batching based on the size/length of the possibly transformed data. In this case the batching can no longer be pre-determined as a static group of indices.

These issues can be resolved to some extent today using IterableDataset but hopefully the new design can take these issues into account and make fewer assumptions on batching.

@VitalyFedyunin
Contributor Author

@ppwwyyxx do you have an example or pseudocode of how you calculate the dynamic batch size based on size/length?

@ppwwyyxx
Contributor

In order to only put samples that have similar sizes into a batch, we first created a dataloader that's not batched, i.e. for d in data_loader returns individual (transformed) samples. Then we do

from collections import defaultdict

def dynamic_batches(data_loader, batch_size, compute_bucket_id, collate_fn):
    buckets = defaultdict(list)
    for d in data_loader:
        bucket = buckets[compute_bucket_id(d)]  # determine which bucket based on d's size
        bucket.append(d)
        if len(bucket) == batch_size:
            yield collate_fn(bucket)  # emit a batch of similarly-sized samples
            del bucket[:]

@tmbdev

tmbdev commented Aug 18, 2020

We have the need to create data loaders that do not do batching. This is doable now, but not very elegant, as batching seems to be a pretty fundamental assumption (in the batch_sampler argument).

Batching serves two purposes: it's needed for deep learning, but it's also needed for efficient data transfer from workers to the main process.

These issues can be resolved to some extent today using IterableDataset but hopefully the new design can take these issues into account and make fewer assumptions on batching.

I think a redesign should be based on a general pipeline abstraction; batching and unbatching just become pipeline stages. There are initial sources of data, and a primitive for running multiple data pipelines in parallel and combining the results.

We have tried to provide something like that in WebDataset (while still staying somewhat close to the existing Dataset/DataLoader framework).

So, a regular multi-worker loader with shuffling across batches would look like (note the use of batching for efficient data transfers from workers to the main process):

dataset = wds.Dataset(url).decode().to_tuple("png", "cls").batched(64)
loader = wds.MultiDataset(dataset, workers=4).unbatched().shuffle(1000).batched(64)

For variable-sized batching, you would use something like:

def bucket_batching(samples):
    ... as above ...

dataset = wds.Dataset(url).decode().to_tuple("png", "cls").batched(64, combine_tensors=False)
loader = wds.MultiDataset(dataset, workers=4).unbatched().shuffle(1000).pipe(bucket_batching)

For multi-node distributed preprocessing with buckets, you can write:

# CPU node
dataset = wds.Dataset(url).decode().to_tuple("png", "cls").shuffle(1000).pipe(bucket_batching)
tensorcom.Connection("zpush:https://gpunode:5678").serve(dataset)
# GPU node
loader = tensorcom.Connection("zpull:https://*:5678")

All of this essentially works today in WebDataset and Tensorcom (arguments may be slightly different from these examples in the released versions).

@hudeven

hudeven commented Aug 25, 2020

Some feedback from the PyText use case:

  1. Similar to what @ppwwyyxx mentioned, PyText also needs a custom batcher on transformed data with IterableDataset. We typically 1) transform a sentence into a sequence of tokens, 2) sort by sequence length, 3) bundle sentences of similar length into the same batch, 4) pad the batch (see the sketch after this list). Currently, we made a custom iterator with our batcher. As we are migrating to the Dataset/DataLoader abstraction, it would be great to have custom batcher (re-batch after transform) support in PyTorch: https://github.com/facebookresearch/pytext/blob/master/pytext/data/data.py#L86

  2. @VitalyFedyunin do you plan to build a unified Transform interface for both vision and text? We are building some text-domain Transforms mimicking torchvision's Transform. It might be worth having a common abstraction such that we could compose data pipelines for multi-modality use cases easily (fused models with text, image, and video features).

  3. We have an internal OnboxDataloader supporting checkpointing of data status, auto data sharding with PyTorch Elastic Trainer, etc. It has some overlap with WebDataset. It would be good to have a unified data API between internal and OSS, so that we can switch between them seamlessly.
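A minimal sketch of the batching steps described in item 1 (the tokenized inputs and padding value are placeholders; PyText's actual batcher lives in the linked data.py):

import torch
from torch.nn.utils.rnn import pad_sequence

def length_bucketed_batches(token_seqs, batch_size, pad_value=0):
    # 2) sort by sequence length so each batch holds similarly sized sentences
    ordered = sorted(token_seqs, key=len)
    # 3) bundle consecutive sequences into a batch, 4) pad the batch to a rectangle
    for start in range(0, len(ordered), batch_size):
        chunk = [torch.tensor(seq, dtype=torch.long) for seq in ordered[start:start + batch_size]]
        yield pad_sequence(chunk, batch_first=True, padding_value=pad_value)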

@oleg-kachan

oleg-kachan commented Sep 3, 2020

What is the current status of PyTorch support for NVIDIA DirectStorage / RTX IO direct data loading into the GPU?

@AntreasAntoniou

Just adding my own use case for why I think this is a must-have feature:

Imagine you are trying to scale up your batch size by some factor N, and you can only increase the number of your CPUs but not so much your RAM.

The default behaviour of the DataLoader is to spin up num_workers processes, each of which samples a full batch before returning it. The issue this causes is that, in order for me to benefit from the increased availability of CPUs and scale my batch size up efficiently, I basically need to convert a situation where:

batch_size=B, num_workers=W

to

batch_size=NxB, num_workers=NxW.

This will ensure that I am able to properly utilize the extra CPUs I can get access to. However, in doing so I am also effectively increasing the memory that I need from B x W to (N x B) x (N x W), which is problematic since I have access to more CPUs but not additional RAM. And in most cases, such as on cloud compute, CPU and RAM increase by the same factor, so an N^2 increase in memory could be very problematic.

I was wondering if there is a way to change the default behaviour of the DataLoader so that, instead of each worker loading a full batch, a pool of workers each samples single items, which are then assembled into a batch once they reach a certain number. That way I could do what I need above without all that extra memory, and potentially save memory as well.

Please let me know your thoughts.
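One workaround that is possible with today's DataLoader (a minimal sketch; dataset, W, N and B are placeholders): setting batch_size=None disables automatic batching, so each worker returns individual samples that the main process assembles into large batches, keeping per-worker memory at roughly a couple of prefetched samples instead of a full batch.

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

# Workers fetch single samples; no worker ever buffers a whole N*B batch.
sample_loader = DataLoader(dataset, batch_size=None, num_workers=N * W, shuffle=True)

def rebatch(loader, batch_size):
    buf = []
    for sample in loader:
        buf.append(sample)
        if len(buf) == batch_size:
            yield default_collate(buf)  # assemble the batch in the main process
            buf = []

for batch in rebatch(sample_loader, N * B):
    ...  # training step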

@VitalyFedyunin
Contributor Author

This functionality is going to be supported using DataPipes and the new DataLoader; the 1.12 release will likely cover this use case.
