Use joblib for parallel processing
wpreimes committed May 2, 2024
1 parent c3f5442 commit f6c2345
Showing 4 changed files with 149 additions and 64 deletions.
1 change: 1 addition & 0 deletions environment.yml
@@ -26,6 +26,7 @@ dependencies:
- more_itertools
- smecv_grid
- tqdm
- joblib
# Optional, for documentation and testing
- nbconvert
- sphinx_rtd_theme
1 change: 1 addition & 0 deletions setup.cfg
@@ -38,6 +38,7 @@ install_requires =
pyresample
tqdm
more_itertools
joblib
# The usage of test_requires is discouraged, see `Dependency Management` docs
# tests_require = pytest; pytest-cov
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
197 changes: 137 additions & 60 deletions src/repurpose/process.py
@@ -3,26 +3,29 @@
import warnings
import os

if 'numpy' in sys.modules:
warnings.warn("Numpy is already imported. Environment variables set in "
"repurpose.utils wont have any effect!")

# Note: Must be set BEFORE the first numpy import!!
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
os.environ['MKL_DYNAMIC'] = 'FALSE'
os.environ['OPENBLAS_NUM_THREADS'] = '1'
# if 'numpy' in sys.modules:
# warnings.warn("Numpy is already imported. Please make sure "
# "`repurpose.process` is imported before numpy to avoid "
# "numpy multi-threading.")
#
# # Note: Must be set BEFORE the first numpy import!!
# os.environ['MKL_NUM_THREADS'] = '1'
# os.environ['NUMEXPR_NUM_THREADS'] = '1'
# os.environ['OMP_NUM_THREADS'] = '1'
# os.environ['MKL_DYNAMIC'] = 'FALSE'
# os.environ['OPENBLAS_NUM_THREADS'] = '1'
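# Note: thread limiting for numerical libraries is now handled further down
# via joblib's `parallel_config(..., inner_max_num_threads=1)`, which is why
# the environment-variable block above is commented out.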

import numpy as np
from tqdm import tqdm
import logging
from multiprocessing import Pool
from datetime import datetime
import sys
from pathlib import Path
from typing import List
from typing import List, Any
from glob import glob
from joblib import Parallel, delayed, parallel_config
from logging.handlers import QueueHandler, QueueListener
from multiprocessing import Manager


class ImageBaseConnection:
Expand Down Expand Up @@ -102,8 +105,8 @@ def read(self, timestamp, **kwargs):


def rootdir() -> Path:
return Path(os.path.join(os.path.dirname(
os.path.abspath(__file__)))).parents[1]
p = str(os.path.join(os.path.dirname(os.path.abspath(__file__))))
return Path(p).parents[1]


def idx_chunks(idx, n=-1):
@@ -123,6 +126,63 @@ def idx_chunks(idx, n=-1):
for i in range(0, len(idx.values), n):
yield idx[i:i + n]
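# Illustrative sketch (not part of the commit, assumes a pandas index, as
# `idx.values` suggests): idx_chunks yields consecutive slices of size n, e.g.
#   idx = pd.date_range("2020-01-01", periods=4, freq="D")
#   chunks = list(idx_chunks(idx, n=2))  # two chunks of two timestamps each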

class ProgressParallel(Parallel):
def __init__(self, use_tqdm=True, total=None, desc="",
*args, **kwargs) -> None:
"""
Joblib parallel with progress bar
"""
self._use_tqdm = use_tqdm
self._total = total
self._desc = desc
super().__init__(*args, **kwargs)

def __call__(self, *args, **kwargs):
"""
Wraps progress bar around function calls
"""
with tqdm(
disable=not self._use_tqdm, total=self._total, desc=self._desc
) as self._pbar:
            return Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        """
        Update the progress bar after each successful call
        """
if self._total is None:
self._pbar.total = self.n_dispatched_tasks
self._pbar.n = self.n_completed_tasks
self._pbar.refresh()
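# Illustrative usage (not part of this commit): ProgressParallel is used like
# joblib.Parallel, but draws a tqdm bar as tasks complete, e.g.
#   squares = ProgressParallel(n_jobs=2, total=4, desc="squares")(
#       delayed(pow)(i, 2) for i in range(4))
#   # squares == [0, 1, 4, 9]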

def configure_worker_logger(log_queue, log_level):
worker_logger = logging.getLogger('worker')
if not worker_logger.hasHandlers():
h = QueueHandler(log_queue)
worker_logger.addHandler(h)
worker_logger.setLevel(log_level)
return worker_logger
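# Sketch of the queue-logging pattern this enables (illustrative, not part of
# the commit): workers push records onto a shared queue, and a listener in the
# main process drains them into real handlers.
#   q = Manager().Queue()
#   listener = QueueListener(q, logging.StreamHandler())
#   listener.start()
#   configure_worker_logger(q, "INFO").info("message from a worker")
#   listener.stop()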

def run_with_error_handling(FUNC,
ignore_errors=False, log_queue=None, log_level="WARNING",
**kwargs) -> Any:

if log_queue is not None:
logger = configure_worker_logger(log_queue, log_level)
logger_name = logger.name
kwargs['logger_name'] = logger_name
else:
logger = logging.getLogger()

r = None

try:
r = FUNC(**kwargs)
except Exception as e:
if ignore_errors:
logger.error(f"Error: {e}")
else:
raise e
return r
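# Illustrative behaviour (not part of the commit): with ignore_errors=True a
# failing call is logged and yields None instead of raising, e.g.
#   def fail(x):
#       raise ValueError(x)
#   assert run_with_error_handling(fail, ignore_errors=True, x=1) is None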

def parallel_process_async(
FUNC,
@@ -136,8 +196,11 @@ def parallel_process_async(
log_filename=None,
loglevel="WARNING",
verbose=False,
progress_bar_label="Processed"
) -> List:
progress_bar_label="Processed",
backend="loky",
sharedmem=False,
parallel_kwargs=None,
) -> list:
"""
Applies the passed function to all elements of the passed iterables.
Parallel function calls are processed ASYNCHRONOUSLY (ie order of
@@ -178,15 +241,21 @@
loglevel: str, optional (default: "WARNING")
Log level to use for logging. Must be one of
["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"].
verbose: float, optional (default: False)
verbose: bool, optional (default: False)
Print all logging messages to stdout, useful for debugging.
progress_bar_label: str, optional (default: "Processed")
Label to use for the progress bar.
backend: Literal["threading", "multiprocessing", "loky"], optional
    The backend to use for parallel execution (if n_proc > 1).
    Defaults to "loky". See the joblib docs for more info.
sharedmem: bool, optional (default: False)
    Activate joblib's shared memory option (can be slow).

Returns
-------
results: List
List of return values from each function call
results: list or None
    List of return values from each function call, or None if no call
    returned a value.
"""
if activate_logging:
logger = logging.getLogger()
@@ -213,7 +282,7 @@
if log_file:
os.makedirs(os.path.dirname(log_file), exist_ok=True)
logging.basicConfig(
filename=log_file,
filename=str(log_file),
level=loglevel.upper(),
format="%(levelname)s %(asctime)s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
@@ -246,48 +315,53 @@
kws.update(STATIC_KWARGS)
process_kwargs.append(kws)

if show_progress_bars:
pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
else:
pbar = None

results = []

def update(r) -> None:
if r is not None:
results.append(r)
if pbar is not None:
pbar.update()

def error(e) -> None:
if logger is not None:
logging.error(e)
if not ignore_errors:
raise e
if pbar is not None:
pbar.update()

if n_proc == 1:
logging.info("Processing metadata with {} process.".format(n_proc))
results = []
if show_progress_bars:
pbar = tqdm(total=len(process_kwargs), desc=progress_bar_label)
else:
pbar = None
for kwargs in process_kwargs:
try:
r = FUNC(**kwargs)
update(r)
except Exception as e:
error(e)
r = run_with_error_handling(FUNC, ignore_errors, **kwargs)
if r is not None:
results.append(r)
if pbar is not None:
pbar.update()
else:
with Pool(n_proc) as pool:
for kwds in process_kwargs:
pool.apply_async(
FUNC,
kwds=kwds,
callback=update,
error_callback=error,
)
pool.close()
pool.join()

if pbar is not None:
pbar.close()
logging.info(f"Processing metadata with {n_proc} processes.")
if logger is not None:
m = Manager()
q = m.Queue()
listener = QueueListener(q, *logger.handlers)
listener.start()
log_level = logger.getEffectiveLevel()
else:
q = None
log_level = None
listener = None

with parallel_config(backend=backend, inner_max_num_threads=1):
results: list = ProgressParallel(
use_tqdm=show_progress_bars,
n_jobs=n_proc,
verbose=0,
total=len(process_kwargs),
desc=progress_bar_label,
require='sharedmem' if sharedmem else None,
return_as="list",
**parallel_kwargs or dict(),
)(delayed(run_with_error_handling)(
FUNC, ignore_errors,
log_queue=q,
log_level=log_level,
**kwargs)
for kwargs in process_kwargs)

results = [r for r in results if r is not None]

if listener is not None:
listener.stop()

if logger is not None:
if verbose:
@@ -299,4 +373,7 @@ def error(e) -> None:
handler.close()
handlers.clear()

return results
if len(results) == 0:
return None
else:
return results
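For reference, a minimal usage sketch of the reworked entry point (illustrative, not part of the commit; the function, iterated kwargs, and static kwargs are passed positionally, as in the updated test below):

    from repurpose.process import parallel_process_async

    def power(x, p):
        return x ** p

    results = parallel_process_async(
        power, {'x': [1, 2, 3, 4]}, {'p': 2},
        n_proc=2, backend="loky", show_progress_bars=False)
    # joblib's return_as="list" preserves input order, so results should be
    # [1, 4, 9, 16]; this is why the updated test can assert the exact order
    # instead of sorting first.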
14 changes: 10 additions & 4 deletions tests/test_process.py
@@ -3,6 +3,7 @@
import time
import tempfile
import logging
import pytest

from repurpose.process import parallel_process_async, idx_chunks

@@ -24,12 +25,17 @@ def func(x: int, p: int):
logging.info(f'x={x}, p={p}')
return x**p

def test_apply_to_elements():

@pytest.mark.parametrize("n_proc,backend", [
("1", None), # backend doesn't matter in this case
("2", "threading"), ("2", "multiprocessing"), ("2", "loky")
])
def test_apply_to_elements(n_proc, backend):
iter_kwargs = {'x': [1, 2, 3, 4]}
static_kwargs = {'p': 2}
with tempfile.TemporaryDirectory() as log_path:
res = parallel_process_async(
func, iter_kwargs, static_kwargs, n_proc=1,
func, iter_kwargs, static_kwargs, n_proc=int(n_proc),
show_progress_bars=False, verbose=False, loglevel="DEBUG",
ignore_errors=True, log_path=log_path)
assert sorted(res) == [1, 4, 9, 16]
ignore_errors=True, log_path=log_path, backend=backend)
assert res == [1, 4, 9, 16]
