Skip to content

Commit

Permalink
Update parallel writing
Browse files Browse the repository at this point in the history
  • Loading branch information
wpreimes committed Jan 14, 2024
1 parent ce01f6e commit 6bdb7bf
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 77 deletions.
1 change: 0 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ dependencies:
- pynetcf>=0.5.0
- more_itertools
- smecv_grid
- sharedmem
- tqdm
# Optional, for documentation and testing
- nbconvert
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ install_requires =
pyresample
tqdm
more_itertools
sharedmem
# The usage of test_requires is discouraged, see `Dependency Management` docs
# tests_require = pytest; pytest-cov
# Require a specific Python version, e.g. Python 2.7 or >= 3.4
Expand Down
133 changes: 59 additions & 74 deletions src/repurpose/img2ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pygeogrids.netcdf as grid2nc
import pandas as pd
from pygeobase.object_base import Image
import sharedmem as mem
from multiprocessing import Manager


class Img2TsError(Exception):
Expand Down Expand Up @@ -194,13 +194,7 @@ def __init__(self,
self.orthogonal = None # to be set when reading data
self.timekey = None # to be set when reading data

# Multiprocessing only used when n_proc > 1 chosen
if platform.system().lower() != "linux" and n_proc != 1:
warnings.warn("Parallel processing is for now only supported "
"on Linux systems. Setting `n_proc=1`.")
self.n_proc = 1
else:
self.n_proc = n_proc
self.n_proc = n_proc

def _read_image(self, date, target_grid):
"""
Expand Down Expand Up @@ -344,6 +338,8 @@ def _write_orthogonal(self,
lats=cell_lats,
attributes=self.ts_attributes)

dataout.close()

def _write_non_orthogonal(self,
cell: int,
target_grid: CellGrid,
Expand Down Expand Up @@ -430,59 +426,6 @@ def _write_non_orthogonal(self,
attributes=self.ts_attributes,
dates_direct=True)

def _calc_cell(self, cell, img_stack_dict, timestamps, target_grid):
"""
Select time series cell data from global stack and write to netcdf
files.
Parameters
----------
cell: int
Cell number in the target grid
img_stack_dict: dict[str, mem.anonymousmemmap]
Dict containing the global image stacks to convert. Shared
between processes.
timestamps: numpy.ndarray
Array of datetime objects with same size as second dimension of
data arrays.
target_grid: CellGrid
Target points for resampling and storing the time series on.
"""""
# look where in the subset the data is
cell_index = np.where(
cell == target_grid.activearrcell)[0]

if cell_index.size == 0:
raise Img2TsError('cell not found in grid subset')

celldata = {}

for key in img_stack_dict.keys():
# rename variable in output dataset
if self.variable_rename is None:
var_new_name = str(key)
else:
var_new_name = self.variable_rename[key]

data = np.swapaxes(
img_stack_dict.get(key)[:, cell_index], 0, 1)

# change dtypes of output time series
if self.ts_dtypes is not None:
if type(self.ts_dtypes) == dict:
output_dtype = self.ts_dtypes[key]
else:
output_dtype = self.ts_dtypes
data = data.astype(output_dtype)

celldata[var_new_name] = data

if self.orthogonal:
self._write_orthogonal(cell, target_grid, celldata, timestamps)
elif not self.orthogonal:
# time information is contained in `celldata`
self._write_non_orthogonal(cell, target_grid, celldata)

def calc(self):
"""
Iterate through all images of the image stack and extract temporal
Expand All @@ -496,34 +439,76 @@ def calc(self):
for img_stack_dict, timestamps in self.img_bulk():
# ==================================================================
start_time = datetime.now()
cells = self.target_grid.get_cells()

# temporally drop grids, due to issue when pickling them...
target_grid = self.target_grid
input_grid = self.input_grid
self.target_grid = None
self.input_grid = None

from numpy.ctypeslib import as_ctypes
cells = target_grid.activearrcell

keys = list(img_stack_dict.keys())
for key in keys:
# rename variable in output dataset
if self.variable_rename is None:
var_new_name = str(key)
else:
var_new_name = self.variable_rename[key]

# change dtypes of output time series
if self.ts_dtypes is not None:
if type(self.ts_dtypes) == dict:
output_dtype = self.ts_dtypes[key]
else:
output_dtype = self.ts_dtypes
img_stack_dict[key] = img_stack_dict[key].astype(
output_dtype)

if var_new_name != key:
img_stack_dict[var_new_name] = img_stack_dict[key]
del img_stack_dict[key]

ITER_KWARGS = {'cell': [], 'celldata': []}

for cell in np.unique(target_grid.activearrcell):
cell_idx = np.where(cells == cell)[0]

if self.n_proc > 1:
# shared image stack between parallel processes
for k, v in img_stack_dict.items():
img_stack_dict[k] = mem.full_like(v, v)
if len(cell_idx) == 0:
continue

ITER_KWARGS['cell'].append(cell)

celldata = {}
for k in img_stack_dict.keys():
celldata[k] = np.swapaxes(
np.atleast_2d(img_stack_dict[k])[:, cell_idx], 0, 1)
img_stack_dict[k] = np.delete(img_stack_dict[k], cell_idx,
axis=1)

cells = np.delete(cells, cell_idx)

ITER_KWARGS['celldata'].append(celldata)

STATIC_KWARGS = {'target_grid': target_grid}

if self.orthogonal:
STATIC_KWARGS['timestamps'] = timestamps
FUNC = self._write_orthogonal
else:
# time information is contained in `celldata`
FUNC = self._write_non_orthogonal

parallel_process_async(
self._calc_cell,
ITER_KWARGS={'cell': cells},
STATIC_KWARGS={
'img_stack_dict': img_stack_dict,
'timestamps': timestamps,
'target_grid': target_grid,
},
FUNC=FUNC,
ITER_KWARGS=ITER_KWARGS,
STATIC_KWARGS=STATIC_KWARGS,
log_path=os.path.join(self.outputpath, '000_log'),
loglevel="INFO",
n_proc=self.n_proc,
show_progress_bars=False,
)

self.target_grid = target_grid
self.input_grid = input_grid

Expand Down
1 change: 0 additions & 1 deletion tests/test_img2ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,4 +298,3 @@ def test_img2ts_ortho_daily_no_resampling_missing_day():
assert dates_should == list(ts['time'])
nptest.assert_allclose(ds.dataset.variables['location_id'][:],
np.array([0, 1, 2, 3]))

0 comments on commit 6bdb7bf

Please sign in to comment.