
Complementary gluon DataLoader improvements #13606

Merged
5 commits merged on Dec 14, 2018
18 changes: 12 additions & 6 deletions docs/faq/env_var.md
@@ -37,6 +37,12 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_CPU_NNPACK_NTHREADS
- Values: Int ```(default=4)```
- The number of threads used for NNPACK. The NNPACK package aims to provide high-performance implementations of some layers for multi-core CPUs. Check out [NNPACK](https://mxnet.io/faq/nnpack.html) to learn more.
* MXNET_MP_WORKER_NTHREADS
- Values: Int ```(default=1)```
- The number of scheduling threads on CPU given to multiprocess workers. Enlarging this number allows more operators to run in parallel in individual workers, but consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
* MXNET_MP_OPENCV_NUM_THREADS
- Values: Int ```(default=0)```
- The number of OpenCV execution threads given to multiprocess workers. OpenCV multithreading is disabled if `MXNET_MP_OPENCV_NUM_THREADS` < 1 (default). Enlarging this number may boost the performance of individual workers when executing underlying OpenCV functions, but consider reducing the overall `num_workers` to avoid thread contention (not available on Windows).
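
As a quick illustration, these knobs take effect in the worker processes forked by Gluon's `DataLoader`; the values below are hypothetical and workload-dependent:

```python
import os

# Hypothetical tuning: 2 scheduling threads and 2 OpenCV threads per worker.
# Must be set before the DataLoader creates its worker pool.
os.environ['MXNET_MP_WORKER_NTHREADS'] = '2'
os.environ['MXNET_MP_OPENCV_NUM_THREADS'] = '2'

import mxnet as mx
from mxnet.gluon.data import ArrayDataset, DataLoader

dataset = ArrayDataset(mx.nd.arange(100))
# Fewer workers compensate for the extra per-worker threads.
loader = DataLoader(dataset, batch_size=10, num_workers=2)
for batch in loader:
    pass
```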

## Memory Options

@@ -99,10 +105,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_KVSTORE_REDUCTION_NTHREADS
- Values: Int ```(default=4)```
- The number of CPU threads used for summing up big arrays on a single machine
- This will also be used for `dist_sync` kvstore to sum up arrays from different contexts on a single machine.
- This does not affect summing up of arrays from different machines on servers.
- Summing up of arrays for `dist_sync_device` kvstore is also unaffected as that happens on GPUs.

* MXNET_KVSTORE_BIGARRAY_BOUND
- Values: Int ```(default=1000000)```
- The minimum size of a "big array".
@@ -166,7 +172,7 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments can

* MXNET_CUDNN_AUTOTUNE_DEFAULT
- Values: 0, 1, or 2 ```(default=1)```
- The default value of cudnn auto tuning for convolution layers.
- Value of 0 means there is no auto tuning to pick the convolution algo
- Performance tests are run to pick the convolution algo when value is 1 or 2
- Value of 1 chooses the best algo in a limited workspace
@@ -190,12 +196,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments can
* MXNET_HOME
- Data directory in the filesystem for storage, for example when downloading gluon models.
- Default is `.mxnet` on *nix and `APPDATA/mxnet` on Windows.

* MXNET_MKLDNN_ENABLED
- Values: 0, 1 ```(default=1)```
- Flag to enable or disable MKLDNN accelerator. On by default.
- Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```)

* MXNET_MKLDNN_CACHE_NUM
- Values: Int ```(default=-1)```
- Sets the number of elements the MKLDNN cache can hold. Default is -1, which means the cache size is unbounded. Should only be set if your model has variable input shapes, as the cache may otherwise grow without bound. The value is proportional to the number of layers that use MKLDNN and the number of distinct input shapes they see.
40 changes: 30 additions & 10 deletions python/mxnet/gluon/data/dataloader.py
@@ -26,6 +26,7 @@
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
from multiprocessing.pool import ThreadPool
import threading
import numpy as np

@@ -384,8 +385,9 @@ def _worker_initializer(dataset):
global _worker_dataset
_worker_dataset = dataset

def _worker_fn(samples, batchify_fn):
def _worker_fn(samples, batchify_fn, dataset=None):
"""Function for processing data in worker process."""
# pylint: disable=unused-argument
# it is required that each worker process has to fork a new MXIndexedRecordIO handle
# preserving dataset as global variable can save tons of overhead and is safe in new process
global _worker_dataset
@@ -394,10 +396,14 @@ def _worker_fn(samples, batchify_fn):
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
return buf.getvalue()

def _thread_worker_fn(samples, batchify_fn, dataset):
"""Threadpool worker function for processing data."""
return batchify_fn([dataset[i] for i in samples])

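A note for reviewers: the process-worker path serializes each batch through `ForkingPickler` so NDArrays can travel via shared memory, while the new thread-worker path simply returns the batch object. A minimal sketch of that serialization round-trip, run outside any worker and assuming only that `mxnet` is importable:

```python
import io
import pickle
from multiprocessing.reduction import ForkingPickler

import mxnet as mx

batch = mx.nd.array([1.0, 2.0, 3.0])
buf = io.BytesIO()
# What _worker_fn does before handing bytes back to the parent.
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
# What the parent does on the receiving side.
restored = pickle.loads(buf.getvalue())
print((restored.asnumpy() == batch.asnumpy()).all())  # True
```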
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
worker_fn=_worker_fn, prefetch=0):
worker_fn=_worker_fn, prefetch=0, dataset=None):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
@@ -407,6 +413,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
self._iter = iter(self._batch_sampler)
self._worker_fn = worker_fn
self._pin_memory = pin_memory
self._dataset = dataset
# pre-fetch
for _ in range(prefetch):
self._push_next()
@@ -419,7 +426,8 @@ def _push_next(self):
r = next(self._iter, None)
if r is None:
return
async_ret = self._worker_pool.apply_async(self._worker_fn, (r, self._batchify_fn))
async_ret = self._worker_pool.apply_async(
self._worker_fn, (r, self._batchify_fn, self._dataset))
self._data_buffer[self._sent_idx] = async_ret
self._sent_idx += 1

@@ -432,7 +440,7 @@ def __next__(self):
assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
ret = self._data_buffer.pop(self._rcvd_idx)
batch = pickle.loads(ret.get())
batch = pickle.loads(ret.get()) if self._dataset is None else ret.get()
if self._pin_memory:
batch = _as_in_context(batch, context.cpu_pinned())
batch = batch[0] if len(batch) == 1 else batch
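
For context on the bookkeeping above: the iterator keeps `prefetch` requests in flight and pushes a new one each time a result is consumed. Stripped of MXNet specifics, the pattern reduces to this generic sketch (all names here are illustrative, not MXNet API):

```python
from multiprocessing.pool import ThreadPool

def prefetch_iter(pool, work_fn, tasks, prefetch=4):
    """Yield results in order while keeping up to `prefetch` tasks in flight."""
    tasks = iter(tasks)
    pending = {}
    sent = rcvd = 0

    def push():
        nonlocal sent
        task = next(tasks, None)
        if task is not None:
            pending[sent] = pool.apply_async(work_fn, (task,))
            sent += 1

    for _ in range(prefetch):  # prime the pipeline
        push()
    while rcvd < sent:
        result = pending.pop(rcvd)
        rcvd += 1
        push()  # keep the pipeline full
        yield result.get()

pool = ThreadPool(2)
print(list(prefetch_iter(pool, lambda x: x * x, range(8))))
```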
@@ -498,12 +506,18 @@ def default_batchify_fn(data):
but will consume more shared memory. Using a smaller number may forfeit the purpose of using
multiple worker processes; try reducing `num_workers` in this case.
Defaults to `num_workers * 2`.
thread_pool : bool, default False
If ``True``, use a thread pool instead of a multiprocessing pool. Using a thread pool
avoids shared memory usage. If `DataLoader` is more IO-bound, or the GIL is not a
limiting factor, the thread pool version may achieve better performance than
multiprocessing (see the usage sketch after this file's diff).

"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, prefetch=None):
num_workers=0, pin_memory=False, prefetch=None, thread_pool=False):
self._dataset = dataset
self._pin_memory = pin_memory
self._thread_pool = thread_pool

if batch_sampler is None:
if batch_size is None:
@@ -529,8 +543,11 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
self._worker_pool = None
self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
if self._num_workers > 0:
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
if self._thread_pool:
self._worker_pool = ThreadPool(self._num_workers)
else:
self._worker_pool = multiprocessing.Pool(
self._num_workers, initializer=_worker_initializer, initargs=[self._dataset])
if batchify_fn is None:
if num_workers > 0:
self._batchify_fn = default_mp_batchify_fn
@@ -551,14 +568,17 @@ def same_process_iter():

# multi-worker
return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
pin_memory=self._pin_memory, worker_fn=_worker_fn,
prefetch=self._prefetch)
pin_memory=self._pin_memory,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None)

def __len__(self):
return len(self._batch_sampler)

def __del__(self):
if self._worker_pool:
# manually terminate due to a bug that pool is not automatically terminated on linux
# manually terminate due to a bug that pool is not automatically terminated
# https://bugs.python.org/issue34172
assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
self._worker_pool.terminate()
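
Taken together, switching between the two pool implementations is a single constructor flag. A usage sketch with a toy dataset (contents hypothetical):

```python
import mxnet as mx
from mxnet.gluon.data import ArrayDataset, DataLoader

dataset = ArrayDataset(mx.nd.arange(1000))

# Default: multiprocessing pool; batches return to the parent via shared memory.
mp_loader = DataLoader(dataset, batch_size=50, num_workers=4)

# New in this PR: thread pool; no shared memory involved, and potentially
# faster when loading is IO-bound or the per-sample work releases the GIL.
thread_loader = DataLoader(dataset, batch_size=50, num_workers=4, thread_pool=True)

for batch in thread_loader:
    assert batch.shape == (50,)
```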
8 changes: 5 additions & 3 deletions src/initialize.cc
@@ -57,11 +57,13 @@ class LibraryInitializer {
Engine::Get()->Start();
},
[]() {
// Make children single threaded since they are typically workers
dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", 1);
// Conservative thread management for multiprocess workers
const size_t mp_worker_threads = dmlc::GetEnv("MXNET_MP_WORKER_NTHREADS", 1);
dmlc::SetEnv("MXNET_CPU_WORKER_NTHREADS", mp_worker_threads);
dmlc::SetEnv("OMP_NUM_THREADS", 1);
#if MXNET_USE_OPENCV && !__APPLE__
cv::setNumThreads(0); // disable opencv threading
const size_t mp_cv_num_threads = dmlc::GetEnv("MXNET_MP_OPENCV_NUM_THREADS", 0);
cv::setNumThreads(mp_cv_num_threads);  // 0 (the default) disables OpenCV threading
#endif // MXNET_USE_OPENCV
engine::OpenMP::Get()->set_enabled(false);
Engine::Get()->Start();
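
The hunk above runs in the child's at-fork hook. For readers more comfortable in Python, the same idea expressed with the standard library (a hedged sketch of the concept, not MXNet's implementation) would be:

```python
import os

def _after_fork_in_child():
    # Conservative thread management for the worker, mirroring initialize.cc:
    # honor MXNET_MP_WORKER_NTHREADS, keep OpenMP single-threaded.
    n = os.environ.get('MXNET_MP_WORKER_NTHREADS', '1')
    os.environ['MXNET_CPU_WORKER_NTHREADS'] = n
    os.environ['OMP_NUM_THREADS'] = '1'

# Python 3.7+: run the hook in every child process created by fork().
os.register_at_fork(after_in_child=_after_fork_in_child)
```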
7 changes: 4 additions & 3 deletions tests/python/unittest/test_gluon_data.py
@@ -156,9 +156,10 @@ def __getitem__(self, key):
@with_seed()
def test_multi_worker():
data = Dataset()
loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5)
for i, batch in enumerate(loader):
assert (batch.asnumpy() == i).all()
for thread_pool in [True, False]:
loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5, thread_pool=thread_pool)
for i, batch in enumerate(loader):
assert (batch.asnumpy() == i).all()

class _Dummy(Dataset):
"""Dummy dataset for randomized shape arrays."""