Skip to content

Commit

Permalink
More file handling functions. (#95)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbusecke committed Mar 31, 2021
1 parent 131faed commit 7b70f15
Show file tree
Hide file tree
Showing 2 changed files with 279 additions and 7 deletions.
191 changes: 191 additions & 0 deletions xarrayutils/file_handling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import functools
import pathlib
import shutil
import warnings

import numpy as np
import xarray as xr


try:
from fastprogress.fastprogress import progress_bar

Expand Down Expand Up @@ -103,3 +108,189 @@ def temp_write_split(
else:
raise ValueError(f"Method '{method}' not recognized.")
return ds_out, flist


def maybe_create_folder(path):
    """Create *path* (including missing parents) if it does not exist yet.

    Emits a ``UserWarning`` when the folder is already present.

    Parameters
    ----------
    path : str or pathlib.Path
        Folder to create.

    Returns
    -------
    pathlib.Path
        The (possibly newly created) folder as a Path object.
    """
    folder = pathlib.Path(path)
    if folder.exists():
        warnings.warn(f"Folder {path} does already exist.", UserWarning)
    else:
        folder.mkdir(parents=True, exist_ok=True)
    return folder


def total_nested_size(nested):
    """Calculate the size of a nested dict full of xarray objects

    Parameters
    ----------
    nested : dict
        Input dictionary. Can have arbitrary nesting levels

    Returns
    -------
    float
        total size in bytes
    """
    sizes = []

    def _accumulate(item):
        # Leaves are xarray objects; anything else is treated as a mapping.
        if isinstance(item, (xr.Dataset, xr.DataArray)):
            sizes.append(item.nbytes)
        else:
            for value in item.values():
                _accumulate(value)

    _accumulate(nested)
    return np.sum(np.array(sizes))


def _maybe_pathlib(path):
if not isinstance(path, pathlib.PosixPath):
path = pathlib.Path(path)
return path


def _file_iszarr(path):
if ".nc" in str(path):
zarr = False
elif ".zarr" in str(path):
zarr = True
return zarr


def file_exist_check(filepath, check_zarr_consolidated_complete=True):
    """Check if a file exists, with some extra checks for zarr files

    Parameters
    ----------
    filepath : path
        path to the file to check
    check_zarr_consolidated_complete : bool, optional
        Check if .zmetadata file was written (consolidated metadata), by default True
    """
    filepath = _maybe_pathlib(filepath)
    exists = filepath.exists()
    # For zarr stores the presence of `.zmetadata` signals a complete,
    # consolidated write; without it the store may be partial.
    if _file_iszarr(filepath) and check_zarr_consolidated_complete:
        return exists and filepath.joinpath(".zmetadata").exists()
    return exists


def checkpath(func):
    """Decorator for writer functions of signature ``func(ds, path, **kwargs)``.

    Adds caching behavior around the wrapped writer:

    1. Skips writing when the target file already exists
       (unless ``overwrite=True`` is passed).
    2. Removes an existing (possibly incomplete) target before rewriting.
    3. Optionally reloads the written file and returns the reloaded
       dataset (``reload_saved=True``, the default).

    Keyword arguments consumed by the wrapper (not forwarded to ``func``):
    ``overwrite``, ``check_zarr_consolidated_complete``, ``reload_saved``,
    ``write_kwargs`` (forwarded to ``func``) and ``load_kwargs``
    (forwarded to the reloading function).
    """

    @functools.wraps(func)
    def wrapper_checkpath(*args, **kwargs):
        ds = args[0]
        path = _maybe_pathlib(args[1])

        # Pull wrapper-level options out of kwargs before calling func.
        overwrite = kwargs.pop("overwrite", False)
        check_zarr_consolidated_complete = kwargs.pop(
            "check_zarr_consolidated_complete", False
        )
        reload_saved = kwargs.pop("reload_saved", True)
        write_kwargs = kwargs.pop("write_kwargs", {})
        load_kwargs = kwargs.pop("load_kwargs", {})

        load_kwargs.setdefault("use_cftime", True)
        load_kwargs.setdefault("consolidated", True)
        # Keep read/write consolidation consistent unless explicitly overridden.
        write_kwargs.setdefault("consolidated", load_kwargs["consolidated"])

        zarr = _file_iszarr(path)
        check = file_exist_check(
            path, check_zarr_consolidated_complete=check_zarr_consolidated_complete
        )

        if check and not overwrite:
            print(f"File [{str(path)}] already exists. Skipping.")
        else:
            # The file might still exist (incomplete) and then needs to be removed.
            if path.exists():
                print(f"Removing file {str(path)}")
                if zarr:
                    # zarr stores are directories; requires `shutil` (module-level import).
                    shutil.rmtree(path)
                else:
                    path.unlink()

            func(ds, path, **write_kwargs)

        # Return either the input dataset or the freshly written file from disk.
        ds_out = ds
        if reload_saved:
            print("$ Reloading file")  # was an f-string with no placeholders
            consolidated = load_kwargs.pop("consolidated")
            if not zarr:
                ds_out = xr.open_dataset(str(path), **load_kwargs)
            else:
                ds_out = xr.open_zarr(
                    str(path), consolidated=consolidated, **load_kwargs
                )

        return ds_out

    return wrapper_checkpath


@checkpath
def write(
    ds,
    path,
    print_size=True,
    consolidated=True,
    **kwargs,
):
    """Convenience function to save large datasets.

    Performs the following additional steps (compared to e.g. xr.to_netcdf() or xr.to_zarr())

    1. Checks for existing files (with special checks for zarr files)
    2. Handles existing files via `overwrite` argument.
    3. Checks attributes for incompatible values
    4. Optional: Prints size of saved dataset
    5. Optional: Returns the saved dataset loaded from disk (e.g. for quality control)

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset
    path : pathlib.Path or str
        filepath to save to. Ending determines the output type (`.nc` for netcdf, `.zarr` for zarr)
    print_size : bool, optional
        If true prints the size of the dataset before saving, by default True
    consolidated : bool, optional
        Write consolidated metadata for zarr stores, by default True
    reload_saved : bool, optional
        If true the returned dataset is opened from the written file,
        otherwise the input is returned, by default True
    load_kwargs : dict
        Arguments passed to the reloading function (either xr.open_dataset or xr.open_zarr based on filename)
    write_kwargs : dict
        Arguments passed to the writing function (either xr.to_netcdf or xr.to_zarr based on filename)
    overwrite : bool, optional
        If True, overwrite existing files, by default False
    check_zarr_consolidated_complete: bool, optional
        If True check if `.zmetadata` is present in zarr store, and overwrite if not present, by default False

    Returns
    -------
    xr.Dataset
        Returns either the unmodified input dataset or a reloaded version from the written file

    Raises
    ------
    RuntimeError
        If any attribute value is an xarray object (these cannot be serialized).
    """
    # Serializing xarray objects stored in attrs fails downstream; fail
    # early with a clear message instead.
    for k, v in ds.attrs.items():
        if isinstance(v, (xr.Dataset, xr.DataArray)):
            raise RuntimeError(f"Found attr ({k}) with xarray values: {v}.")

    zarr = _file_iszarr(path)

    if print_size:
        print(f"$ Saving {ds.nbytes/1e9}GB to {path}")

    if zarr:
        ds.to_zarr(path, consolidated=consolidated, **kwargs)
    else:
        ds.to_netcdf(path, **kwargs)
95 changes: 88 additions & 7 deletions xarrayutils/test/test_file_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,33 @@
import numpy as np
import pytest
import pathlib
from xarrayutils.file_handling import temp_write_split
from xarrayutils.file_handling import (
temp_write_split,
maybe_create_folder,
total_nested_size,
write,
)


@pytest.fixture
def ds():
    """Small 3d test dataset with a cftime time axis, shared across tests."""
    # The data must have one axis per declared dim: a bare np.random.rand()
    # returns a 0-d scalar, which xr.DataArray rejects for 3 dims.
    # The last axis must match the 12 time coordinates.
    data = np.random.rand(2, 3, 12)
    time = xr.cftime_range("1850", freq="1AS", periods=12)
    ds = xr.DataArray(data, dims=["x", "y", "time"], coords={"time": time}).to_dataset(
        name="data"
    )
    return ds


@pytest.mark.parametrize("dask", [True, False])
@pytest.mark.parametrize("verbose", [True, False])
@pytest.mark.parametrize("already_exists", [True, False])
@pytest.mark.parametrize("method", ["dimension", "variables", "wrong"])
def test_temp_write_split(dask, method, verbose, already_exists, tmpdir):
def test_temp_write_split(ds, dask, method, verbose, already_exists, tmpdir):
folder = tmpdir.mkdir("sub")
folder = pathlib.Path(folder)

# create test dataset
data = np.random.rand()
time = xr.cftime_range("1850", freq="1AS", periods=12)
ds = xr.DataArray(data, dims=["x", "y", "time"], coords={"time": time}).to_dataset(
name="data"
)
if dask:
ds = ds.chunk({"time": 1})

Expand All @@ -44,3 +54,74 @@ def test_temp_write_split(dask, method, verbose, already_exists, tmpdir):
split_interval=1,
)
xr.testing.assert_allclose(ds, ds_reloaded)


@pytest.mark.parametrize("sub", ["sub", "nested/sub/path"])
def test_maybe_create_folder(sub, tmp_path):
    """Creating a missing folder succeeds; re-creating it warns."""
    target = pathlib.Path(tmp_path).joinpath(sub)

    maybe_create_folder(target)
    assert target.exists()

    # A second call must not fail, only warn.
    with pytest.warns(UserWarning):
        maybe_create_folder(target)


def test_total_nested_size(ds):
    """Size of a nested dict equals the sum of the contained dataset sizes."""
    # Broadcasted copies of the fixture with different sizes.
    exp_a = ds.copy(deep=True).expand_dims(new=2)
    exp_b = ds.copy(deep=True).expand_dims(new=5)
    exp_c = ds.copy(deep=True).expand_dims(new=4, new_new=10)

    # Arbitrarily nested dict of xarray objects.
    nested = {"experiment_a": exp_a, "experiment_b": {"label_x": exp_b, "label_y": exp_c}}

    expected = np.sum(np.array([item.nbytes for item in (exp_a, exp_b, exp_c)]))
    assert total_nested_size(nested) == expected


@pytest.mark.parametrize("strpath", [True, False])
@pytest.mark.parametrize("reload_saved", [True, False])
@pytest.mark.parametrize("overwrite", [True, False])
@pytest.mark.parametrize("filetype", [".nc", ".zarr"])
def test_write(ds, strpath, reload_saved, overwrite, filetype, tmpdir):
    """write() round-trips data, honors `overwrite`, and reloads when asked."""

    def _read_back(p):
        # Pick the reader matching the file extension.
        if filetype == ".nc":
            return xr.open_dataset(p, use_cftime=True)
        return xr.open_zarr(p, use_cftime=True)

    target = pathlib.Path(tmpdir).joinpath("file" + filetype)
    target_arg = str(target) if strpath else target

    # Initial write and round-trip check.
    write(ds, target)
    assert target.exists()
    xr.testing.assert_allclose(ds, _read_back(target))

    # Write a modified dataset on top of the existing file.
    ds_modified = ds * 4
    returned = write(
        ds_modified, target_arg, overwrite=overwrite, reload_saved=reload_saved
    )

    # On-disk content depends on whether overwriting was allowed.
    expected_on_disk = ds_modified if overwrite else ds
    xr.testing.assert_allclose(_read_back(target_arg), expected_on_disk)

    # Check what the call returned.
    returned = returned.load()
    if reload_saved:
        xr.testing.assert_allclose(returned, _read_back(target_arg))
    else:
        xr.testing.assert_allclose(returned, ds_modified)

0 comments on commit 7b70f15

Please sign in to comment.