-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add temp_write_split * Note in the docs * Updated docstring * Add file_handling to api * Expand test coverage * Finish test coverage * Fix
- Loading branch information
Showing
7 changed files
with
171 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,8 @@ dependencies: | |
- xgcm | ||
- pip | ||
- cf_xarray | ||
- cftime | ||
- zarr | ||
- pip: | ||
- codecov | ||
- pytest-cov | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import xarray as xr | ||
|
||
try: | ||
from fastprogress.fastprogress import progress_bar | ||
|
||
fastprogress = 1 | ||
except ImportError: | ||
fastprogress = None | ||
|
||
|
||
import shutil | ||
|
||
|
||
def temp_write_split(
    ds_in,
    folder,
    method="dimension",
    dim="time",
    split_interval=40,
    zarr_write_kwargs=None,
    zarr_read_kwargs=None,
    file_name_pattern="temp_write_split",
    verbose=False,
):
    """Write a dataset to several temporary zarr stores and reload it.

    The input is written piecewise, either in slices along `dim` or one data
    variable at a time, and the pieces are then reopened with `xr.open_zarr`
    and recombined into a single dataset. Existing stores with the same name
    inside `folder` are deleted before writing.

    Parameters
    ----------
    ds_in : xr.Dataset
        input
    folder : pathlib.Path
        Target folder for temporary files
    method : str, optional
        Defines if the temporary files are split by an increment along a certain
        dimension("dimension") or by the variables of the dataset ("variables"),
        by default "dimension"
    dim : str, optional
        Dimension to split along (only relevant for `method="dimension"`),
        by default "time"
    split_interval : int, optional
        Steps along `dim` for each temporary file (only relevant for
        `method="dimension"`), by default 40
    zarr_write_kwargs : dict, optional
        Kwargs passed to `xr.to_zarr()`. `consolidated=False` is used unless
        specified otherwise. The dict passed in is not modified.
        By default None (no extra kwargs).
    zarr_read_kwargs : dict, optional
        Kwargs passed to `xr.open_zarr()`. `use_cftime=True` and
        `consolidated=False` are used unless specified otherwise. The dict
        passed in is not modified. By default None (no extra kwargs).
    file_name_pattern : str, optional
        Pattern used to name the temporary files, by default "temp_write_split"
    verbose : bool, optional
        Activates printing, by default False

    Returns
    -------
    ds_out : xr.Dataset
        reloaded dataset, with value identical to `ds_in`
    flist : list
        List of paths to temporary datasets written.

    Raises
    ------
    ValueError
        If `method` is neither "dimension" nor "variables".
    """
    # Fail fast, before any existing store is deleted or anything is written.
    if method not in ("dimension", "variables"):
        raise ValueError(f"Method '{method}' not recognized.")

    # Work on copies: `setdefault` below must neither leak into a shared
    # default argument nor modify the dictionaries owned by the caller.
    zarr_write_kwargs = dict(zarr_write_kwargs or {})
    zarr_read_kwargs = dict(zarr_read_kwargs or {})

    zarr_write_kwargs.setdefault("consolidated", False)
    zarr_read_kwargs.setdefault("use_cftime", True)
    zarr_read_kwargs.setdefault("consolidated", False)

    flist = []
    if method == "dimension":
        # The trailing `None` makes the last slice run to the end of `dim`.
        split_points = list(range(0, len(ds_in[dim]), split_interval)) + [None]
        if verbose:
            print(f" Split indicies: {split_points}")

        nsi = len(split_points) - 1
        # Only show a progress bar if the optional dependency was importable.
        progress = progress_bar(range(nsi)) if fastprogress else range(nsi)

        for si in progress:
            fname = folder.joinpath(f"{file_name_pattern}_{si}.zarr")
            if fname.exists():
                shutil.rmtree(fname)
            ds_in.isel({dim: slice(split_points[si], split_points[si + 1])}).to_zarr(
                fname, **zarr_write_kwargs
            )
            flist.append(fname)
        ds_out = xr.concat(
            [xr.open_zarr(f, **zarr_read_kwargs) for f in flist], dim=dim
        )
    else:  # method == "variables"
        # move all coords to data variables to avoid doubling up the writing
        # for expensive (time resolved) coords
        reset_coords = [co for co in ds_in.coords if co not in ds_in.dims]
        ds_in = ds_in.reset_coords(reset_coords)

        variables = list(ds_in.data_vars)
        if verbose:
            print(variables)
        for var in variables:
            fname = folder.joinpath(f"{file_name_pattern}_{var}.zarr")
            if fname.exists():
                # Removing a large existing store can take a long time.
                shutil.rmtree(fname)
            ds_in[var].to_dataset(name=var).to_zarr(fname, **zarr_write_kwargs)
            flist.append(fname)
        ds_out = xr.merge([xr.open_zarr(f, **zarr_read_kwargs) for f in flist])
        # Restore the coordinates that were demoted to data variables above.
        ds_out = ds_out.set_coords(reset_coords)
    return ds_out, flist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import xarray as xr | ||
import numpy as np | ||
import pytest | ||
import pathlib | ||
from xarrayutils.file_handling import temp_write_split | ||
|
||
|
||
@pytest.mark.parametrize("dask", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
@pytest.mark.parametrize("already_exists", [True, False])
@pytest.mark.parametrize("method", ["dimension", "variables", "wrong"])
def test_temp_write_split(dask, method, verbose, already_exists, tmpdir):
    """Round-trip a small dataset through `temp_write_split`.

    Covers both split methods, an invalid method, eager and dask-backed
    input, verbose output, and the case where a store with the target name
    already exists in the output folder.
    """
    folder = tmpdir.mkdir("sub")
    folder = pathlib.Path(folder)

    # create test dataset
    data = np.random.rand()
    time = xr.cftime_range("1850", freq="1AS", periods=12)
    ds = xr.DataArray(data, dims=["x", "y", "time"], coords={"time": time}).to_dataset(
        name="data"
    )
    if dask:
        ds = ds.chunk({"time": 1})

    if already_exists:
        # write a manual copy (with wrong data) to test the erasing. The store
        # name has to match what the chosen method writes, otherwise the
        # overwrite path is never exercised.
        stale_suffix = "data" if method == "variables" else "0"
        (ds.isel(time=0) + 100).to_zarr(
            folder.joinpath(f"temp_write_split_{stale_suffix}.zarr"),
            consolidated=True,
        )

    if method == "wrong":
        with pytest.raises(ValueError):
            temp_write_split(
                ds,
                folder,
                method=method,
                split_interval=1,
            )
    else:
        ds_reloaded, filelist = temp_write_split(
            ds,
            folder,
            method=method,
            verbose=verbose,
            split_interval=1,
        )
        xr.testing.assert_allclose(ds, ds_reloaded)
        # every reported temporary store must actually be on disk
        assert all(f.exists() for f in filelist)