diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index a5ddceebf11..ff3e8ab7e63 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -72,8 +72,7 @@ jobs: runs-on: "ubuntu-latest" needs: detect-ci-trigger # temporarily skipping due to https://github.com/pydata/xarray/issues/6551 - # if: needs.detect-ci-trigger.outputs.triggered == 'false' - if: false + if: needs.detect-ci-trigger.outputs.triggered == 'false' defaults: run: shell: bash -l {0} @@ -105,10 +104,10 @@ jobs: - name: Install mypy run: | python -m pip install mypy - python -m mypy --install-types --non-interactive - name: Run mypy - run: python -m mypy + run: | + python -m mypy --install-types --non-interactive min-version-policy: name: Minimum Version Policy diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4ce6ebc523..6d6c94ff88f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.32.0 + rev: v2.32.1 hooks: - id: pyupgrade args: @@ -46,7 +46,7 @@ repos: # - id: velin # args: ["--write", "--compact"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.942 + rev: v0.950 hooks: - id: mypy # Copied from setup.cfg diff --git a/asv_bench/asv.conf.json b/asv_bench/asv.conf.json index 3e4137cf807..5de6d6a4f76 100644 --- a/asv_bench/asv.conf.json +++ b/asv_bench/asv.conf.json @@ -58,6 +58,8 @@ // "pip+emcee": [""], // emcee is only available for install with pip. // }, "matrix": { + "setuptools_scm[toml]": [""], // GH6609 + "setuptools_scm_git_archive": [""], // GH6609 "numpy": [""], "pandas": [""], "netcdf4": [""], @@ -65,6 +67,8 @@ "bottleneck": [""], "dask": [""], "distributed": [""], + "flox": [""], + "numpy_groupies": [""], "sparse": [""] }, diff --git a/asv_bench/benchmarks/groupby.py b/asv_bench/benchmarks/groupby.py index fa93ce9e8b5..490c2ccbd4c 100644 --- a/asv_bench/benchmarks/groupby.py +++ b/asv_bench/benchmarks/groupby.py @@ -13,6 +13,7 @@ def setup(self, *args, **kwargs): { "a": xr.DataArray(np.r_[np.repeat(1, self.n), np.repeat(2, self.n)]), "b": xr.DataArray(np.arange(2 * self.n)), + "c": xr.DataArray(np.arange(2 * self.n)), } ) self.ds2d = self.ds1d.expand_dims(z=10) @@ -50,10 +51,11 @@ class GroupByDask(GroupBy): def setup(self, *args, **kwargs): requires_dask() super().setup(**kwargs) - self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)).chunk({"dim_0": 50}) - self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)).chunk( - {"dim_0": 50, "z": 5} - ) + + self.ds1d = self.ds1d.sel(dim_0=slice(None, None, 2)) + self.ds1d["c"] = self.ds1d["c"].chunk({"dim_0": 50}) + self.ds2d = self.ds2d.sel(dim_0=slice(None, None, 2)) + self.ds2d["c"] = self.ds2d["c"].chunk({"dim_0": 50, "z": 5}) self.ds1d_mean = self.ds1d.groupby("b").mean() self.ds2d_mean = self.ds2d.groupby("b").mean() diff --git a/asv_bench/benchmarks/polyfit.py b/asv_bench/benchmarks/polyfit.py new file mode 100644 index 00000000000..429ffa19baa --- /dev/null +++ b/asv_bench/benchmarks/polyfit.py @@ -0,0 +1,38 @@ +import numpy as np + +import xarray as xr + +from . import parameterized, randn, requires_dask + +NDEGS = (2, 5, 20) +NX = (10**2, 10**6) + + +class Polyval: + def setup(self, *args, **kwargs): + self.xs = {nx: xr.DataArray(randn((nx,)), dims="x", name="x") for nx in NX} + self.coeffs = { + ndeg: xr.DataArray( + randn((ndeg,)), dims="degree", coords={"degree": np.arange(ndeg)} + ) + for ndeg in NDEGS + } + + @parameterized(["nx", "ndeg"], [NX, NDEGS]) + def time_polyval(self, nx, ndeg): + x = self.xs[nx] + c = self.coeffs[ndeg] + xr.polyval(x, c).compute() + + @parameterized(["nx", "ndeg"], [NX, NDEGS]) + def peakmem_polyval(self, nx, ndeg): + x = self.xs[nx] + c = self.coeffs[ndeg] + xr.polyval(x, c).compute() + + +class PolyvalDask(Polyval): + def setup(self, *args, **kwargs): + requires_dask() + super().setup(*args, **kwargs) + self.xs = {k: v.chunk({"x": 10000}) for k, v in self.xs.items()} diff --git a/ci/install-upstream-wheels.sh b/ci/install-upstream-wheels.sh index 96a39ccd20b..ff5615c17c6 100755 --- a/ci/install-upstream-wheels.sh +++ b/ci/install-upstream-wheels.sh @@ -15,6 +15,7 @@ conda uninstall -y --force \ pint \ bottleneck \ sparse \ + flox \ h5netcdf \ xarray # to limit the runtime of Upstream CI @@ -47,4 +48,5 @@ python -m pip install \ git+https://github.com/pydata/sparse \ git+https://github.com/intake/filesystem_spec \ git+https://github.com/SciTools/nc-time-axis \ + git+https://github.com/dcherian/flox \ git+https://github.com/h5netcdf/h5netcdf diff --git a/ci/requirements/all-but-dask.yml b/ci/requirements/all-but-dask.yml index cb9ec8d3bc5..e20ec2016ed 100644 --- a/ci/requirements/all-but-dask.yml +++ b/ci/requirements/all-but-dask.yml @@ -13,6 +13,7 @@ dependencies: - cfgrib - cftime - coveralls + - flox - h5netcdf - h5py - hdf5 diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 6c389c22ce6..634140fe84b 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -10,6 +10,7 @@ dependencies: - cftime - dask-core - distributed + - flox - fsspec!=2021.7.0 - h5netcdf - h5py diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 516c964afc7..d37bb7dc44a 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -12,6 +12,7 @@ dependencies: - cftime - dask-core - distributed + - flox - fsspec!=2021.7.0 - h5netcdf - h5py diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 76e2b28093d..34879af730b 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -10,46 +10,46 @@ dependencies: - python=3.8 - boto3=1.13 - bottleneck=1.3 - # cartopy 0.18 conflicts with pynio - - cartopy=0.17 + - cartopy=0.19 - cdms2=3.1 - cfgrib=0.9 - - cftime=1.2 + - cftime=1.4 - coveralls - - dask-core=2.30 - - distributed=2.30 - - h5netcdf=0.8 - - h5py=2.10 - # hdf5 1.12 conflicts with h5py=2.10 + - dask-core=2021.04 + - distributed=2021.04 + - flox=0.5 + - h5netcdf=0.11 + - h5py=3.1 + # hdf5 1.12 conflicts with h5py=3.1 - hdf5=1.10 - hypothesis - iris=2.4 - lxml=4.6 # Optional dep of pydap - - matplotlib-base=3.3 + - matplotlib-base=3.4 - nc-time-axis=1.2 # netcdf follows a 1.major.minor[.patch] convention # (see https://github.com/Unidata/netcdf4-python/issues/1090) # bumping the netCDF4 version is currently blocked by #4491 - netcdf4=1.5.3 - - numba=0.51 - - numpy=1.18 + - numba=0.53 + - numpy=1.19 - packaging=20.0 - - pandas=1.1 - - pint=0.16 + - pandas=1.2 + - pint=0.17 - pip - pseudonetcdf=3.1 - pydap=3.2 - - pynio=1.5 + # - pynio=1.5.5 - pytest - pytest-cov - pytest-env - pytest-xdist - - rasterio=1.1 - - scipy=1.5 + - rasterio=1.2 + - scipy=1.6 - seaborn=0.11 - - sparse=0.11 + - sparse=0.12 - toolz=0.11 - typing_extensions=3.7 - - zarr=2.5 + - zarr=2.8 - pip: - numbagg==0.1 diff --git a/doc/getting-started-guide/installing.rst b/doc/getting-started-guide/installing.rst index 0668853946f..faa0fba5dd3 100644 --- a/doc/getting-started-guide/installing.rst +++ b/doc/getting-started-guide/installing.rst @@ -7,9 +7,9 @@ Required dependencies --------------------- - Python (3.8 or later) -- `numpy `__ (1.18 or later) +- `numpy `__ (1.19 or later) - `packaging `__ (20.0 or later) -- `pandas `__ (1.1 or later) +- `pandas `__ (1.2 or later) .. _optional-dependencies: diff --git a/doc/internals/extending-xarray.rst b/doc/internals/extending-xarray.rst index 2951ce10f21..f8b61d12a2f 100644 --- a/doc/internals/extending-xarray.rst +++ b/doc/internals/extending-xarray.rst @@ -92,8 +92,8 @@ on ways to write new accessors and the philosophy behind the approach, see To help users keep things straight, please `let us know `_ if you plan to write a new accessor -for an open source library. In the future, we will maintain a list of accessors -and the libraries that implement them on this page. +for an open source library. Existing open source accessors and the libraries +that implement them are available in the list on the :ref:`ecosystem` page. To make documenting accessors with ``sphinx`` and ``sphinx.ext.autosummary`` easier, you can use `sphinx-autosummary-accessors`_. diff --git a/doc/user-guide/dask.rst b/doc/user-guide/dask.rst index 5110a970390..56717f5306e 100644 --- a/doc/user-guide/dask.rst +++ b/doc/user-guide/dask.rst @@ -84,7 +84,7 @@ argument to :py:func:`~xarray.open_dataset` or using the In this example ``latitude`` and ``longitude`` do not appear in the ``chunks`` dict, so only one chunk will be used along those dimensions. It is also -entirely equivalent to opening a dataset using :py:meth:`~xarray.open_dataset` +entirely equivalent to opening a dataset using :py:func:`~xarray.open_dataset` and then chunking the data using the ``chunk`` method, e.g., ``xr.open_dataset('example-data.nc').chunk({'time': 10})``. @@ -95,13 +95,21 @@ use :py:func:`~xarray.open_mfdataset`:: This function will automatically concatenate and merge datasets into one in the simple cases that it understands (see :py:func:`~xarray.combine_by_coords` -for the full disclaimer). By default, :py:meth:`~xarray.open_mfdataset` will chunk each +for the full disclaimer). By default, :py:func:`~xarray.open_mfdataset` will chunk each netCDF file into a single Dask array; again, supply the ``chunks`` argument to control the size of the resulting Dask arrays. In more complex cases, you can -open each file individually using :py:meth:`~xarray.open_dataset` and merge the result, as -described in :ref:`combining data`. Passing the keyword argument ``parallel=True`` to :py:meth:`~xarray.open_mfdataset` will speed up the reading of large multi-file datasets by +open each file individually using :py:func:`~xarray.open_dataset` and merge the result, as +described in :ref:`combining data`. Passing the keyword argument ``parallel=True`` to +:py:func:`~xarray.open_mfdataset` will speed up the reading of large multi-file datasets by executing those read tasks in parallel using ``dask.delayed``. +.. warning:: + + :py:func:`~xarray.open_mfdataset` called without ``chunks`` argument will return + dask arrays with chunk sizes equal to the individual files. Re-chunking + the dataset after creation with ``ds.chunk()`` will lead to an ineffective use of + memory and is not recommended. + You'll notice that printing a dataset still shows a preview of array values, even if they are actually Dask arrays. We can do this quickly with Dask because we only need to compute the first few values (typically from the first block). @@ -224,6 +232,7 @@ disk. available memory. .. note:: + For more on the differences between :py:meth:`~xarray.Dataset.persist` and :py:meth:`~xarray.Dataset.compute` see this `Stack Overflow answer `_ and the `Dask documentation `_. @@ -236,6 +245,11 @@ sizes of Dask arrays is done with the :py:meth:`~xarray.Dataset.chunk` method: rechunked = ds.chunk({"latitude": 100, "longitude": 100}) +.. warning:: + + Rechunking an existing dask array created with :py:func:`~xarray.open_mfdataset` + is not recommended (see above). + You can view the size of existing chunks on an array by viewing the :py:attr:`~xarray.Dataset.chunks` attribute: @@ -295,8 +309,7 @@ each block of your xarray object, you have three options: ``apply_ufunc`` ~~~~~~~~~~~~~~~ -Another option is to use xarray's :py:func:`~xarray.apply_ufunc`, which can -automate `embarrassingly parallel +:py:func:`~xarray.apply_ufunc` automates `embarrassingly parallel `__ "map" type operations where a function written for processing NumPy arrays should be repeatedly applied to xarray objects containing Dask arrays. It works similarly to @@ -542,18 +555,20 @@ larger chunksizes. Optimization Tips ----------------- -With analysis pipelines involving both spatial subsetting and temporal resampling, Dask performance can become very slow in certain cases. Here are some optimization tips we have found through experience: +With analysis pipelines involving both spatial subsetting and temporal resampling, Dask performance +can become very slow or memory hungry in certain cases. Here are some optimization tips we have found +through experience: -1. Do your spatial and temporal indexing (e.g. ``.sel()`` or ``.isel()``) early in the pipeline, especially before calling ``resample()`` or ``groupby()``. Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn't been implemented in Dask yet. (See `Dask issue #746 `_). +1. Do your spatial and temporal indexing (e.g. ``.sel()`` or ``.isel()``) early in the pipeline, especially before calling ``resample()`` or ``groupby()``. Grouping and resampling triggers some computation on all the blocks, which in theory should commute with indexing, but this optimization hasn't been implemented in Dask yet. (See `Dask issue #746 `_). More generally, ``groupby()`` is a costly operation and does not (yet) perform well on datasets split across multiple files (see :pull:`5734` and linked discussions there). 2. Save intermediate results to disk as a netCDF files (using ``to_netcdf()``) and then load them again with ``open_dataset()`` for further computations. For example, if subtracting temporal mean from a dataset, save the temporal mean to disk before subtracting. Again, in theory, Dask should be able to do the computation in a streaming fashion, but in practice this is a fail case for the Dask scheduler, because it tries to keep every chunk of an array that it computes in memory. (See `Dask issue #874 `_) -3. Specify smaller chunks across space when using :py:meth:`~xarray.open_mfdataset` (e.g., ``chunks={'latitude': 10, 'longitude': 10}``). This makes spatial subsetting easier, because there's no risk you will load chunks of data referring to different chunks (probably not necessary if you follow suggestion 1). +3. Specify smaller chunks across space when using :py:meth:`~xarray.open_mfdataset` (e.g., ``chunks={'latitude': 10, 'longitude': 10}``). This makes spatial subsetting easier, because there's no risk you will load subsets of data which span multiple chunks. On individual files, prefer to subset before chunking (suggestion 1). + +4. Chunk as early as possible, and avoid rechunking as much as possible. Always pass the ``chunks={}`` argument to :py:func:`~xarray.open_mfdataset` to avoid redundant file reads. -4. Using the h5netcdf package by passing ``engine='h5netcdf'`` to :py:meth:`~xarray.open_mfdataset` - can be quicker than the default ``engine='netcdf4'`` that uses the netCDF4 package. +5. Using the h5netcdf package by passing ``engine='h5netcdf'`` to :py:meth:`~xarray.open_mfdataset` can be quicker than the default ``engine='netcdf4'`` that uses the netCDF4 package. -5. Some dask-specific tips may be found `here `_. +6. Some dask-specific tips may be found `here `_. -6. The dask `diagnostics `_ can be - useful in identifying performance bottlenecks. +7. The dask `diagnostics `_ can be useful in identifying performance bottlenecks. diff --git a/doc/user-guide/terminology.rst b/doc/user-guide/terminology.rst index 1876058323e..c8cfdd5133d 100644 --- a/doc/user-guide/terminology.rst +++ b/doc/user-guide/terminology.rst @@ -27,7 +27,7 @@ complete examples, please consult the relevant documentation.* Variable A `NetCDF-like variable - `_ + `_ consisting of dimensions, data, and attributes which describe a single array. The main functional difference between variables and numpy arrays is that numerical operations on variables implement array broadcasting diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 4882402073c..c9ee52f3da0 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -41,10 +41,37 @@ New Features - Allow passing chunks in ``**kwargs`` form to :py:meth:`Dataset.chunk`, :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) By `Tom Nicholas `_. +- Expose `inline_array` kwarg from `dask.array.from_array` in :py:func:`open_dataset`, :py:meth:`Dataset.chunk`, + :py:meth:`DataArray.chunk`, and :py:meth:`Variable.chunk`. (:pull:`6471`) + By `Tom Nicholas `_. +- :py:meth:`xr.polyval` now supports :py:class:`Dataset` and :py:class:`DataArray` args of any shape, + is faster and requires less memory. (:pull:`6548`) + By `Michael Niklas `_. +- Improved overall typing. Breaking changes ~~~~~~~~~~~~~~~~ +- PyNIO support is now untested. The minimum versions of some dependencies were changed: + + =============== ===== ==== + Package Old New + =============== ===== ==== + cftime 1.2 1.4 + dask 2.30 2021.4 + distributed 2.30 2021.4 + h5netcdf 0.8 0.11 + matplotlib-base 3.3 3.4 + numba 0.51 0.53 + numpy 1.18 1.19 + pandas 1.1 1.2 + pint 0.16 0.17 + rasterio 1.1 1.2 + scipy 1.5 1.6 + sparse 0.11 0.12 + zarr 2.5 2.8 + =============== ===== ==== + - The Dataset and DataArray ``rename*`` methods do not implicitly add or drop indexes. (:pull:`5692`). By `Benoît Bovy `_. @@ -54,6 +81,10 @@ Breaking changes - Xarray's ufuncs have been removed, now that they can be replaced by numpy's ufuncs in all supported versions of numpy. By `Maximilian Roos `_. +- :py:meth:`xr.polyval` now uses the ``coord`` argument directly instead of its index coordinate. + (:pull:`6548`) + By `Michael Niklas `_. + Deprecations ~~~~~~~~~~~~ @@ -92,6 +123,12 @@ Bug fixes :pull:`6489`). By `Spencer Clark `_. - Dark themes are now properly detected in Furo-themed Sphinx documents (:issue:`6500`, :pull:`6501`). By `Kevin Paul `_. +- :py:meth:`isel` with `drop=True` works as intended with scalar :py:class:`DataArray` indexers. + (:issue:`6554`, :pull:`6579`) + By `Michael Niklas `_. +- Fixed silent overflow issue when decoding times encoded with 32-bit and below + unsigned integer data types (:issue:`6589`, :pull:`6598`). By `Spencer Clark + `_. Documentation ~~~~~~~~~~~~~ @@ -107,6 +144,11 @@ Performance - GroupBy binary operations are now vectorized. Previously this involved looping over all groups. (:issue:`5804`,:pull:`6160`) By `Deepak Cherian `_. +- Substantially improved GroupBy operations using `flox `_. + This is auto-enabled when ``flox`` is installed. Use ``xr.set_options(use_flox=False)`` to use + the old algorithm. (:issue:`4473`, :issue:`4498`, :issue:`659`, :issue:`2237`, :pull:`271`). + By `Deepak Cherian `_,`Anderson Banihirwe `_, + `Jimmy Westling `_. Internal Changes ~~~~~~~~~~~~~~~~ diff --git a/setup.cfg b/setup.cfg index 05b202810b4..f5dd4dde810 100644 --- a/setup.cfg +++ b/setup.cfg @@ -75,8 +75,8 @@ zip_safe = False # https://mypy.readthedocs.io/en/latest/installed_packages.htm include_package_data = True python_requires = >=3.8 install_requires = - numpy >= 1.18 - pandas >= 1.1 + numpy >= 1.19 + pandas >= 1.2 packaging >= 20.0 [options.extras_require] @@ -98,6 +98,7 @@ accel = scipy bottleneck numbagg + flox parallel = dask[complete] diff --git a/xarray/backends/api.py b/xarray/backends/api.py index 9967b0a08c0..95f8dbc6eaf 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -17,7 +17,7 @@ import numpy as np -from .. import backends, coding, conventions +from .. import backends, conventions from ..core import indexing from ..core.combine import ( _infer_concat_order_from_positions, @@ -35,7 +35,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None + Delayed = None # type: ignore DATAARRAY_NAME = "__xarray_dataarray_name__" @@ -274,6 +274,7 @@ def _chunk_ds( engine, chunks, overwrite_encoded_chunks, + inline_array, **extra_tokens, ): from dask.base import tokenize @@ -292,6 +293,7 @@ def _chunk_ds( overwrite_encoded_chunks=overwrite_encoded_chunks, name_prefix=name_prefix, token=token, + inline_array=inline_array, ) return backend_ds._replace(variables) @@ -303,6 +305,7 @@ def _dataset_from_backend_dataset( chunks, cache, overwrite_encoded_chunks, + inline_array, **extra_tokens, ): if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}: @@ -320,6 +323,7 @@ def _dataset_from_backend_dataset( engine, chunks, overwrite_encoded_chunks, + inline_array, **extra_tokens, ) @@ -346,6 +350,7 @@ def open_dataset( concat_characters=None, decode_coords=None, drop_variables=None, + inline_array=False, backend_kwargs=None, **kwargs, ): @@ -430,6 +435,12 @@ def open_dataset( A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + inline_array: bool, optional + How to include the array in the dask task graph. + By default(``inline_array=False``) the array is included in a task by + itself, and each chunk refers to that task by its key. With + ``inline_array=True``, Dask will instead inline the array directly + in the values of the task graph. See :py:func:`dask.array.from_array`. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -505,6 +516,7 @@ def open_dataset( chunks, cache, overwrite_encoded_chunks, + inline_array, drop_variables=drop_variables, **decoders, **kwargs, @@ -526,6 +538,7 @@ def open_dataarray( concat_characters=None, decode_coords=None, drop_variables=None, + inline_array=False, backend_kwargs=None, **kwargs, ): @@ -613,6 +626,12 @@ def open_dataarray( A variable or list of variables to exclude from being parsed from the dataset. This may be useful to drop variables with problems or inconsistent values. + inline_array: bool, optional + How to include the array in the dask task graph. + By default(``inline_array=False``) the array is included in a task by + itself, and each chunk refers to that task by its key. With + ``inline_array=True``, Dask will instead inline the array directly + in the values of the task graph. See :py:func:`dask.array.from_array`. backend_kwargs: dict Additional keyword arguments passed on to the engine open function, equivalent to `**kwargs`. @@ -660,6 +679,7 @@ def open_dataarray( chunks=chunks, cache=cache, drop_variables=drop_variables, + inline_array=inline_array, backend_kwargs=backend_kwargs, use_cftime=use_cftime, decode_timedelta=decode_timedelta, @@ -1277,28 +1297,40 @@ def _validate_region(ds, region): ) -def _validate_datatypes_for_zarr_append(dataset): - """DataArray.name and Dataset keys must be a string or None""" +def _validate_datatypes_for_zarr_append(zstore, dataset): + """If variable exists in the store, confirm dtype of the data to append is compatible with + existing dtype. + """ + + existing_vars = zstore.get_variables() - def check_dtype(var): + def check_dtype(vname, var): if ( - not np.issubdtype(var.dtype, np.number) - and not np.issubdtype(var.dtype, np.datetime64) - and not np.issubdtype(var.dtype, np.bool_) - and not coding.strings.is_unicode_dtype(var.dtype) - and not var.dtype == object + vname not in existing_vars + or np.issubdtype(var.dtype, np.number) + or np.issubdtype(var.dtype, np.datetime64) + or np.issubdtype(var.dtype, np.bool_) + or var.dtype == object ): - # and not re.match('^bytes[1-9]+$', var.dtype.name)): + # We can skip dtype equality checks under two conditions: (1) if the var to append is + # new to the dataset, because in this case there is no existing var to compare it to; + # or (2) if var to append's dtype is known to be easy-to-append, because in this case + # we can be confident appending won't cause problems. Examples of dtypes which are not + # easy-to-append include length-specified strings of type `|S*` or ` "Dataset": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "Dataset": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -2021,13 +2035,23 @@ def count( Data variables: da (labels) int64 1 2 2 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -2095,13 +2119,23 @@ def all( Data variables: da (labels) bool False True True """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -2169,13 +2203,23 @@ def any( Data variables: da (labels) bool True True True """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -2259,14 +2303,25 @@ def max( Data variables: da (labels) float64 nan 2.0 3.0 """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -2350,14 +2405,25 @@ def min( Data variables: da (labels) float64 nan 2.0 1.0 """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -2445,14 +2511,25 @@ def mean( Data variables: da (labels) float64 nan 2.0 2.0 """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -2557,15 +2634,27 @@ def prod( Data variables: da (labels) float64 nan 4.0 3.0 """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -2670,15 +2759,27 @@ def sum( Data variables: da (labels) float64 nan 4.0 4.0 """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -2780,15 +2881,27 @@ def std( Data variables: da (labels) float64 nan 0.0 1.414 """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -2890,15 +3003,27 @@ def var( Data variables: da (labels) float64 nan 0.0 2.0 """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -2997,7 +3122,7 @@ def median( class DatasetResampleReductions: - __slots__ = () + _obj: "Dataset" def reduce( self, @@ -3011,6 +3136,13 @@ def reduce( ) -> "Dataset": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "Dataset": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -3077,13 +3209,23 @@ def count( Data variables: da (time) int64 1 3 1 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -3151,13 +3293,23 @@ def all( Data variables: da (time) bool True True False """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -3225,13 +3377,23 @@ def any( Data variables: da (time) bool True True True """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -3315,14 +3477,25 @@ def max( Data variables: da (time) float64 1.0 3.0 nan """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -3406,14 +3579,25 @@ def min( Data variables: da (time) float64 1.0 1.0 nan """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - numeric_only=False, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + numeric_only=False, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + numeric_only=False, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -3501,14 +3685,25 @@ def mean( Data variables: da (time) float64 1.0 2.0 nan """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -3613,15 +3808,27 @@ def prod( Data variables: da (time) float64 nan 6.0 nan """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -3726,15 +3933,27 @@ def sum( Data variables: da (time) float64 nan 6.0 nan """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -3836,15 +4055,27 @@ def std( Data variables: da (time) float64 nan 1.0 nan """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -3946,15 +4177,27 @@ def var( Data variables: da (time) float64 nan 1.0 nan """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - numeric_only=True, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + numeric_only=True, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -4053,7 +4296,7 @@ def median( class DataArrayGroupByReductions: - __slots__ = () + _obj: "DataArray" def reduce( self, @@ -4067,6 +4310,13 @@ def reduce( ) -> "DataArray": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "DataArray": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -4128,12 +4378,21 @@ def count( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.count, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -4196,12 +4455,21 @@ def all( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -4264,12 +4532,21 @@ def any( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -4346,13 +4623,23 @@ def max( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -4429,13 +4716,23 @@ def min( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -4516,13 +4813,23 @@ def mean( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -4618,14 +4925,25 @@ def prod( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -4721,14 +5039,25 @@ def sum( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -4821,14 +5150,25 @@ def std( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -4921,14 +5261,25 @@ def var( Coordinates: * labels (labels) object 'a' 'b' 'c' """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, @@ -5019,7 +5370,7 @@ def median( class DataArrayResampleReductions: - __slots__ = () + _obj: "DataArray" def reduce( self, @@ -5033,6 +5384,13 @@ def reduce( ) -> "DataArray": raise NotImplementedError() + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "DataArray": + raise NotImplementedError() + def count( self, dim: Union[None, Hashable, Sequence[Hashable]] = None, @@ -5094,12 +5452,21 @@ def count( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.count, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="count", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.count, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def all( self, @@ -5162,12 +5529,21 @@ def all( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.array_all, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="all", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_all, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def any( self, @@ -5230,12 +5606,21 @@ def any( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.array_any, - dim=dim, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="any", + dim=dim, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.array_any, + dim=dim, + keep_attrs=keep_attrs, + **kwargs, + ) def max( self, @@ -5312,13 +5697,23 @@ def max( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.max, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="max", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.max, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def min( self, @@ -5395,13 +5790,23 @@ def min( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.min, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="min", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.min, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def mean( self, @@ -5482,13 +5887,23 @@ def mean( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.mean, - dim=dim, - skipna=skipna, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="mean", + dim=dim, + skipna=skipna, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.mean, + dim=dim, + skipna=skipna, + keep_attrs=keep_attrs, + **kwargs, + ) def prod( self, @@ -5584,14 +5999,25 @@ def prod( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.prod, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="prod", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.prod, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def sum( self, @@ -5687,14 +6113,25 @@ def sum( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.sum, - dim=dim, - skipna=skipna, - min_count=min_count, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="sum", + dim=dim, + skipna=skipna, + min_count=min_count, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.sum, + dim=dim, + skipna=skipna, + min_count=min_count, + keep_attrs=keep_attrs, + **kwargs, + ) def std( self, @@ -5787,14 +6224,25 @@ def std( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.std, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="std", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.std, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def var( self, @@ -5887,14 +6335,25 @@ def var( Coordinates: * time (time) datetime64[ns] 2001-01-31 2001-04-30 2001-07-31 """ - return self.reduce( - duck_array_ops.var, - dim=dim, - skipna=skipna, - ddof=ddof, - keep_attrs=keep_attrs, - **kwargs, - ) + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="var", + dim=dim, + skipna=skipna, + ddof=ddof, + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.var, + dim=dim, + skipna=skipna, + ddof=ddof, + keep_attrs=keep_attrs, + **kwargs, + ) def median( self, diff --git a/xarray/core/_typed_ops.pyi b/xarray/core/_typed_ops.pyi index e23b5848ff7..e5b3c9112c7 100644 --- a/xarray/core/_typed_ops.pyi +++ b/xarray/core/_typed_ops.pyi @@ -21,7 +21,7 @@ from .variable import Variable try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore # DatasetOpsMixin etc. are parent classes of Dataset etc. # Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally diff --git a/xarray/core/common.py b/xarray/core/common.py index 3db9b1cfa0c..75518716870 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -158,6 +158,10 @@ def _repr_html_(self): return f"
{escape(repr(self))}
" return formatting_html.array_repr(self) + def __format__(self: Any, format_spec: str) -> str: + # we use numpy: scalars will print fine and arrays will raise + return self.values.__format__(format_spec) + def _iter(self: Any) -> Iterator[Any]: for n in range(len(self)): yield self[n] @@ -450,7 +454,7 @@ def assign_coords(self, coords=None, **coords_kwargs): Examples -------- - Convert longitude coordinates from 0-359 to -180-179: + Convert `DataArray` longitude coordinates from 0-359 to -180-179: >>> da = xr.DataArray( ... np.random.rand(4), @@ -490,6 +494,54 @@ def assign_coords(self, coords=None, **coords_kwargs): >>> _ = da.assign_coords({"lon_2": ("lon", lon_2)}) + Note the same method applies to `Dataset` objects. + + Convert `Dataset` longitude coordinates from 0-359 to -180-179: + + >>> temperature = np.linspace(20, 32, num=16).reshape(2, 2, 4) + >>> precipitation = 2 * np.identity(4).reshape(2, 2, 4) + >>> ds = xr.Dataset( + ... data_vars=dict( + ... temperature=(["x", "y", "time"], temperature), + ... precipitation=(["x", "y", "time"], precipitation), + ... ), + ... coords=dict( + ... lon=(["x", "y"], [[260.17, 260.68], [260.21, 260.77]]), + ... lat=(["x", "y"], [[42.25, 42.21], [42.63, 42.59]]), + ... time=pd.date_range("2014-09-06", periods=4), + ... reference_time=pd.Timestamp("2014-09-05"), + ... ), + ... attrs=dict(description="Weather-related data"), + ... ) + >>> ds + + Dimensions: (x: 2, y: 2, time: 4) + Coordinates: + lon (x, y) float64 260.2 260.7 260.2 260.8 + lat (x, y) float64 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 ... 2014-09-09 + reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 20.0 20.8 21.6 22.4 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 2.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 2.0 + Attributes: + description: Weather-related data + >>> ds.assign_coords(lon=(((ds.lon + 180) % 360) - 180)) + + Dimensions: (x: 2, y: 2, time: 4) + Coordinates: + lon (x, y) float64 -99.83 -99.32 -99.79 -99.23 + lat (x, y) float64 42.25 42.21 42.63 42.59 + * time (time) datetime64[ns] 2014-09-06 2014-09-07 ... 2014-09-09 + reference_time datetime64[ns] 2014-09-05 + Dimensions without coordinates: x, y + Data variables: + temperature (x, y, time) float64 20.0 20.8 21.6 22.4 ... 30.4 31.2 32.0 + precipitation (x, y, time) float64 2.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 2.0 + Attributes: + description: Weather-related data + Notes ----- Since ``coords_kwargs`` is a dictionary, the order of your arguments diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 1834622d96e..81b5e3fd915 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -17,12 +17,15 @@ Iterable, Mapping, Sequence, + overload, ) import numpy as np from . import dtypes, duck_array_ops, utils from .alignment import align, deep_align +from .common import zeros_like +from .duck_array_ops import datetime_to_numeric from .indexes import Index, filter_indexes_from_coords from .merge import merge_attrs, merge_coordinates_without_align from .options import OPTIONS, _get_keep_attrs @@ -1843,36 +1846,111 @@ def where(cond, x, y, keep_attrs=None): ) -def polyval(coord, coeffs, degree_dim="degree"): +@overload +def polyval(coord: DataArray, coeffs: DataArray, degree_dim: Hashable) -> DataArray: + ... + + +@overload +def polyval(coord: DataArray, coeffs: Dataset, degree_dim: Hashable) -> Dataset: + ... + + +@overload +def polyval(coord: Dataset, coeffs: DataArray, degree_dim: Hashable) -> Dataset: + ... + + +@overload +def polyval(coord: Dataset, coeffs: Dataset, degree_dim: Hashable) -> Dataset: + ... + + +def polyval( + coord: Dataset | DataArray, + coeffs: Dataset | DataArray, + degree_dim: Hashable = "degree", +) -> Dataset | DataArray: """Evaluate a polynomial at specific values Parameters ---------- - coord : DataArray - The 1D coordinate along which to evaluate the polynomial. - coeffs : DataArray - Coefficients of the polynomials. - degree_dim : str, default: "degree" + coord : DataArray or Dataset + Values at which to evaluate the polynomial. + coeffs : DataArray or Dataset + Coefficients of the polynomial. + degree_dim : Hashable, default: "degree" Name of the polynomial degree dimension in `coeffs`. + Returns + ------- + DataArray or Dataset + Evaluated polynomial. + See Also -------- xarray.DataArray.polyfit - numpy.polyval + numpy.polynomial.polynomial.polyval """ - from .dataarray import DataArray - from .missing import get_clean_interp_index - x = get_clean_interp_index(coord, coord.name, strict=False) + if degree_dim not in coeffs._indexes: + raise ValueError( + f"Dimension `{degree_dim}` should be a coordinate variable with labels." + ) + if not np.issubdtype(coeffs[degree_dim].dtype, int): + raise ValueError( + f"Dimension `{degree_dim}` should be of integer dtype. Received {coeffs[degree_dim].dtype} instead." + ) + max_deg = coeffs[degree_dim].max().item() + coeffs = coeffs.reindex( + {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False + ) + coord = _ensure_numeric(coord) # type: ignore # https://github.com/python/mypy/issues/1533 ? - deg_coord = coeffs[degree_dim] + # using Horner's method + # https://en.wikipedia.org/wiki/Horner%27s_method + res = zeros_like(coord) + coeffs.isel({degree_dim: max_deg}, drop=True) + for deg in range(max_deg - 1, -1, -1): + res *= coord + res += coeffs.isel({degree_dim: deg}, drop=True) - lhs = DataArray( - np.vander(x, int(deg_coord.max()) + 1), - dims=(coord.name, degree_dim), - coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]}, - ) - return (lhs * coeffs).sum(degree_dim) + return res + + +def _ensure_numeric(data: T_Xarray) -> T_Xarray: + """Converts all datetime64 variables to float64 + + Parameters + ---------- + data : DataArray or Dataset + Variables with possible datetime dtypes. + + Returns + ------- + DataArray or Dataset + Variables with datetime64 dtypes converted to float64. + """ + from .dataset import Dataset + + def to_floatable(x: DataArray) -> DataArray: + if x.dtype.kind == "M": + # datetimes + return x.copy( + data=datetime_to_numeric( + x.data, + offset=np.datetime64("1970-01-01"), + datetime_unit="ns", + ), + ) + elif x.dtype.kind == "m": + # timedeltas + return x.astype(float) + return x + + if isinstance(data, Dataset): + return data.map(to_floatable) + else: + return to_floatable(data) def _calc_idxminmax( diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 0e0229cc3ca..e114c238b72 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -1,14 +1,11 @@ import warnings import numpy as np -from packaging.version import Version - -from .pycompat import dask_version try: import dask.array as da except ImportError: - da = None + da = None # type: ignore def _validate_pad_output_shape(input_shape, pad_width, output_shape): @@ -57,127 +54,7 @@ def pad(array, pad_width, mode="constant", **kwargs): return padded -if dask_version > Version("2.30.0"): - ensure_minimum_chunksize = da.overlap.ensure_minimum_chunksize -else: - - # copied from dask - def ensure_minimum_chunksize(size, chunks): - """Determine new chunks to ensure that every chunk >= size - - Parameters - ---------- - size : int - The maximum size of any chunk. - chunks : tuple - Chunks along one axis, e.g. ``(3, 3, 2)`` - - Examples - -------- - >>> ensure_minimum_chunksize(10, (20, 20, 1)) - (20, 11, 10) - >>> ensure_minimum_chunksize(3, (1, 1, 3)) - (5,) - - See Also - -------- - overlap - """ - if size <= min(chunks): - return chunks - - # add too-small chunks to chunks before them - output = [] - new = 0 - for c in chunks: - if c < size: - if new > size + (size - c): - output.append(new - (size - c)) - new = size - else: - new += c - if new >= size: - output.append(new) - new = 0 - if c >= size: - new += c - if new >= size: - output.append(new) - elif len(output) >= 1: - output[-1] += new - else: - raise ValueError( - f"The overlapping depth {size} is larger than your " - f"array {sum(chunks)}." - ) - - return tuple(output) - - -if dask_version > Version("2021.03.0"): +if da is not None: sliding_window_view = da.lib.stride_tricks.sliding_window_view else: - - def sliding_window_view(x, window_shape, axis=None): - from dask.array.overlap import map_overlap - from numpy.core.numeric import normalize_axis_tuple - - from .npcompat import sliding_window_view as _np_sliding_window_view - - window_shape = ( - tuple(window_shape) if np.iterable(window_shape) else (window_shape,) - ) - - window_shape_array = np.array(window_shape) - if np.any(window_shape_array <= 0): - raise ValueError("`window_shape` must contain positive values") - - if axis is None: - axis = tuple(range(x.ndim)) - if len(window_shape) != len(axis): - raise ValueError( - f"Since axis is `None`, must provide " - f"window_shape for all dimensions of `x`; " - f"got {len(window_shape)} window_shape elements " - f"and `x.ndim` is {x.ndim}." - ) - else: - axis = normalize_axis_tuple(axis, x.ndim, allow_duplicate=True) - if len(window_shape) != len(axis): - raise ValueError( - f"Must provide matching length window_shape and " - f"axis; got {len(window_shape)} window_shape " - f"elements and {len(axis)} axes elements." - ) - - depths = [0] * x.ndim - for ax, window in zip(axis, window_shape): - depths[ax] += window - 1 - - # Ensure that each chunk is big enough to leave at least a size-1 chunk - # after windowing (this is only really necessary for the last chunk). - safe_chunks = tuple( - ensure_minimum_chunksize(d + 1, c) for d, c in zip(depths, x.chunks) - ) - x = x.rechunk(safe_chunks) - - # result.shape = x_shape_trimmed + window_shape, - # where x_shape_trimmed is x.shape with every entry - # reduced by one less than the corresponding window size. - # trim chunks to match x_shape_trimmed - newchunks = tuple( - c[:-1] + (c[-1] - d,) for d, c in zip(depths, x.chunks) - ) + tuple((window,) for window in window_shape) - - kwargs = dict( - depth=tuple((0, d) for d in depths), # Overlap on +ve side only - boundary="none", - meta=x._meta, - new_axis=range(x.ndim, x.ndim + len(axis)), - chunks=newchunks, - trim=False, - window_shape=window_shape, - axis=axis, - ) - - return map_overlap(_np_sliding_window_view, x, align_arrays=False, **kwargs) + sliding_window_view = None diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d15cbd00c0d..64c4e419788 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -17,6 +17,7 @@ import numpy as np import pandas as pd +from ..backends.common import AbstractDataStore, ArrayWriter from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex from ..plot.plot import _PlotMethods @@ -67,7 +68,7 @@ try: from dask.delayed import Delayed except ImportError: - Delayed = None + Delayed = None # type: ignore try: from cdms2 import Variable as cdms2_Variable except ImportError: @@ -77,7 +78,7 @@ except ImportError: iris_Cube = None - from .types import T_DataArray, T_Xarray + from .types import ErrorChoice, ErrorChoiceWithWarn, T_DataArray, T_Xarray def _infer_coords_and_dims( @@ -1113,6 +1114,7 @@ def chunk( name_prefix: str = "xarray-", token: str = None, lock: bool = False, + inline_array: bool = False, **chunks_kwargs: Any, ) -> DataArray: """Coerce this array's data into a dask arrays with the given chunks. @@ -1137,6 +1139,9 @@ def chunk( lock : optional Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + inline_array: optional + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided. @@ -1144,6 +1149,13 @@ def chunk( Returns ------- chunked : xarray.DataArray + + See Also + -------- + DataArray.chunks + DataArray.chunksizes + xarray.unify_chunks + dask.array.from_array """ if chunks is None: warnings.warn( @@ -1162,7 +1174,11 @@ def chunk( chunks = either_dict_or_kwargs(chunks, chunks_kwargs, "chunk") ds = self._to_temp_dataset().chunk( - chunks, name_prefix=name_prefix, token=token, lock=lock + chunks, + name_prefix=name_prefix, + token=token, + lock=lock, + inline_array=inline_array, ) return self._from_temp_dataset(ds) @@ -1170,7 +1186,7 @@ def isel( self, indexers: Mapping[Any, Any] = None, drop: bool = False, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **indexers_kwargs: Any, ) -> DataArray: """Return a new DataArray whose data is given by integer indexing @@ -1185,7 +1201,7 @@ def isel( If DataArrays are passed as indexers, xarray-style indexing will be carried out. See :ref:`indexing` for the details. One of indexers or indexers_kwargs must be provided. - drop : bool, optional + drop : bool, default: False If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. missing_dims : {"raise", "warn", "ignore"}, default: "raise" @@ -2334,7 +2350,7 @@ def transpose( self, *dims: Hashable, transpose_coords: bool = True, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> DataArray: """Return a new DataArray object with transposed dimensions. @@ -2385,7 +2401,7 @@ def T(self) -> DataArray: return self.transpose() def drop_vars( - self, names: Hashable | Iterable[Hashable], *, errors: str = "raise" + self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" ) -> DataArray: """Returns an array with dropped variables. @@ -2393,8 +2409,8 @@ def drop_vars( ---------- names : hashable or iterable of hashable Name(s) of variables to drop. - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if any of the variable + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the DataArray are dropped and no error is raised. @@ -2411,7 +2427,7 @@ def drop( labels: Mapping = None, dim: Hashable = None, *, - errors: str = "raise", + errors: ErrorChoice = "raise", **labels_kwargs, ) -> DataArray: """Backward compatible method based on `drop_vars` and `drop_sel` @@ -2430,7 +2446,7 @@ def drop_sel( self, labels: Mapping[Any, Any] = None, *, - errors: str = "raise", + errors: ErrorChoice = "raise", **labels_kwargs, ) -> DataArray: """Drop index labels from this DataArray. @@ -2439,8 +2455,8 @@ def drop_sel( ---------- labels : mapping of hashable to Any Index labels to drop - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the index labels passed are not in the dataset. If 'ignore', any given labels that are in the dataset are dropped and no error is raised. @@ -2875,7 +2891,9 @@ def to_masked_array(self, copy: bool = True) -> np.ma.MaskedArray: isnull = pd.isnull(values) return np.ma.MaskedArray(data=values, mask=isnull, copy=copy) - def to_netcdf(self, *args, **kwargs) -> bytes | Delayed | None: + def to_netcdf( + self, *args, **kwargs + ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """Write DataArray contents to a netCDF file. All parameters are passed directly to :py:meth:`xarray.Dataset.to_netcdf`. @@ -4586,7 +4604,7 @@ def query( queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **queries_kwargs: Any, ) -> DataArray: """Return a new data array indexed along the specified diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 76776b4bc44..b73cd797a8f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -32,6 +32,7 @@ import xarray as xr +from ..backends.common import ArrayWriter from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings from ..plot.dataset_plot import _Dataset_PlotMethods @@ -105,12 +106,12 @@ from ..backends import AbstractDataStore, ZarrStore from .dataarray import DataArray from .merge import CoercibleMapping - from .types import T_Xarray + from .types import ErrorChoice, ErrorChoiceWithWarn, T_Xarray try: from dask.delayed import Delayed except ImportError: - Delayed = None + Delayed = None # type: ignore # list of attributes of pd.DatetimeIndex that are ndarrays of time info @@ -239,6 +240,7 @@ def _maybe_chunk( lock=None, name_prefix="xarray-", overwrite_encoded_chunks=False, + inline_array=False, ): from dask.base import tokenize @@ -250,7 +252,7 @@ def _maybe_chunk( # subtle bugs result otherwise. see GH3350 token2 = tokenize(name, token if token else var._data, chunks) name2 = f"{name_prefix}{name}-{token2}" - var = var.chunk(chunks, name=name2, lock=lock) + var = var.chunk(chunks, name=name2, lock=lock, inline_array=inline_array) if overwrite_encoded_chunks and var.chunks is not None: var.encoding["chunks"] = tuple(x[0] for x in var.chunks) @@ -1686,7 +1688,7 @@ def to_netcdf( unlimited_dims: Iterable[Hashable] = None, compute: bool = True, invalid_netcdf: bool = False, - ) -> bytes | Delayed | None: + ) -> tuple[ArrayWriter, AbstractDataStore] | bytes | Delayed | None: """Write dataset contents to a netCDF file. Parameters @@ -1994,6 +1996,7 @@ def chunk( name_prefix: str = "xarray-", token: str = None, lock: bool = False, + inline_array: bool = False, **chunks_kwargs: Any, ) -> Dataset: """Coerce all arrays in this dataset into dask arrays with the given @@ -2018,6 +2021,9 @@ def chunk( lock : optional Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + inline_array: optional + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided @@ -2031,6 +2037,7 @@ def chunk( Dataset.chunks Dataset.chunksizes xarray.unify_chunks + dask.array.from_array """ if chunks is None and chunks_kwargs is None: warnings.warn( @@ -2058,7 +2065,7 @@ def chunk( return self._replace(variables) def _validate_indexers( - self, indexers: Mapping[Any, Any], missing_dims: str = "raise" + self, indexers: Mapping[Any, Any], missing_dims: ErrorChoiceWithWarn = "raise" ) -> Iterator[tuple[Hashable, int | slice | np.ndarray | Variable]]: """Here we make sure + indexer has a valid keys @@ -2163,7 +2170,7 @@ def isel( self, indexers: Mapping[Any, Any] = None, drop: bool = False, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **indexers_kwargs: Any, ) -> Dataset: """Returns a new dataset with each array indexed along the specified @@ -2182,14 +2189,14 @@ def isel( If DataArrays are passed as indexers, xarray-style indexing will be carried out. See :ref:`indexing` for the details. One of indexers or indexers_kwargs must be provided. - drop : bool, optional + drop : bool, default: False If ``drop=True``, drop coordinates variables indexed by integers instead of making them scalar. missing_dims : {"raise", "warn", "ignore"}, default: "raise" What to do if dimensions that should be selected from are not present in the Dataset: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **indexers_kwargs : {dim: indexer, ...}, optional The keyword arguments form of ``indexers``. @@ -2254,7 +2261,7 @@ def _isel_fancy( indexers: Mapping[Any, Any], *, drop: bool, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Dataset: valid_indexers = dict(self._validate_indexers(indexers, missing_dims)) @@ -2270,6 +2277,10 @@ def _isel_fancy( } if var_indexers: new_var = var.isel(indexers=var_indexers) + # drop scalar coordinates + # https://github.com/pydata/xarray/issues/6554 + if name in self.coords and drop and new_var.ndim == 0: + continue else: new_var = var.copy(deep=False) if name not in indexes: @@ -4520,7 +4531,7 @@ def _assert_all_in_dataset( ) def drop_vars( - self, names: Hashable | Iterable[Hashable], *, errors: str = "raise" + self, names: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" ) -> Dataset: """Drop variables from this dataset. @@ -4528,8 +4539,8 @@ def drop_vars( ---------- names : hashable or iterable of hashable Name(s) of variables to drop. - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if any of the variable + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the variable passed are not in the dataset. If 'ignore', any given names that are in the dataset are dropped and no error is raised. @@ -4546,6 +4557,23 @@ def drop_vars( if errors == "raise": self._assert_all_in_dataset(names) + # GH6505 + other_names = set() + for var in names: + maybe_midx = self._indexes.get(var, None) + if isinstance(maybe_midx, PandasMultiIndex): + idx_coord_names = set(maybe_midx.index.names + [maybe_midx.dim]) + idx_other_names = idx_coord_names - set(names) + other_names.update(idx_other_names) + if other_names: + names |= set(other_names) + warnings.warn( + f"Deleting a single level of a MultiIndex is deprecated. Previously, this deleted all levels of a MultiIndex. " + f"Please also drop the following variables: {other_names!r} to avoid an error in the future.", + DeprecationWarning, + stacklevel=2, + ) + assert_no_index_corrupted(self.xindexes, names) variables = {k: v for k, v in self._variables.items() if k not in names} @@ -4555,7 +4583,9 @@ def drop_vars( variables, coord_names=coord_names, indexes=indexes ) - def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): + def drop( + self, labels=None, dim=None, *, errors: ErrorChoice = "raise", **labels_kwargs + ): """Backward compatible method based on `drop_vars` and `drop_sel` Using either `drop_vars` or `drop_sel` is encouraged @@ -4604,15 +4634,15 @@ def drop(self, labels=None, dim=None, *, errors="raise", **labels_kwargs): ) return self.drop_sel(labels, errors=errors) - def drop_sel(self, labels=None, *, errors="raise", **labels_kwargs): + def drop_sel(self, labels=None, *, errors: ErrorChoice = "raise", **labels_kwargs): """Drop index labels from this dataset. Parameters ---------- labels : mapping of hashable to Any Index labels to drop - errors : {"raise", "ignore"}, optional - If 'raise' (default), raises a ValueError error if + errors : {"raise", "ignore"}, default: "raise" + If 'raise', raises a ValueError error if any of the index labels passed are not in the dataset. If 'ignore', any given labels that are in the dataset are dropped and no error is raised. @@ -4739,7 +4769,7 @@ def drop_isel(self, indexers=None, **indexers_kwargs): return ds def drop_dims( - self, drop_dims: Hashable | Iterable[Hashable], *, errors: str = "raise" + self, drop_dims: Hashable | Iterable[Hashable], *, errors: ErrorChoice = "raise" ) -> Dataset: """Drop dimensions and associated variables from this dataset. @@ -4779,7 +4809,7 @@ def drop_dims( def transpose( self, *dims: Hashable, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Dataset: """Return a new Dataset object with all array dimensions transposed. @@ -5278,7 +5308,7 @@ def map( args: Iterable[Any] = (), **kwargs: Any, ) -> Dataset: - """Apply a function to each variable in this dataset + """Apply a function to each data variable in this dataset Parameters ---------- @@ -7713,7 +7743,7 @@ def query( queries: Mapping[Any, Any] = None, parser: str = "pandas", engine: str = None, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **queries_kwargs: Any, ) -> Dataset: """Return a new dataset with each array indexed along the specified @@ -7746,7 +7776,7 @@ def query( Dataset: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions **queries_kwargs : {dim: query, ...}, optional diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 455ff96d38c..a8d7476f1ab 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -30,7 +30,7 @@ import dask.array as dask_array from dask.base import tokenize except ImportError: - dask_array = None + dask_array = None # type: ignore def _dask_or_eager_func( diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index e97499f06b4..fec8954c9e2 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -264,6 +264,10 @@ class GroupBy: "_stacked_dim", "_unique_coord", "_dims", + "_squeeze", + # Save unstacked object for flox + "_original_obj", + "_unstacked_group", "_bins", ) @@ -312,7 +316,7 @@ def __init__( if not hashable(group): raise TypeError( "`group` must be an xarray.DataArray or the " - "name of an xarray variable or dimension." + "name of an xarray variable or dimension. " f"Received {group!r} instead." ) group = obj[group] @@ -326,6 +330,10 @@ def __init__( if getattr(group, "name", None) is None: group.name = "group" + self._original_obj = obj + self._unstacked_group = group + self._bins = bins + group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj) (group_dim,) = group.dims @@ -342,7 +350,7 @@ def __init__( if bins is not None: if duck_array_ops.isnull(bins).all(): raise ValueError("All bin edges are NaN.") - binned = pd.cut(group.values, bins, **cut_kwargs) + binned, bins = pd.cut(group.values, bins, **cut_kwargs, retbins=True) new_dim_name = group.name + "_bins" group = DataArray(binned, group.coords, name=new_dim_name) full_index = binned.categories @@ -403,6 +411,7 @@ def __init__( self._full_index = full_index self._restore_coord_dims = restore_coord_dims self._bins = bins + self._squeeze = squeeze # cached attributes self._groups = None @@ -570,6 +579,121 @@ def _maybe_unstack(self, obj): obj._indexes = filter_indexes_from_coords(obj._indexes, set(obj.coords)) return obj + def _flox_reduce(self, dim, **kwargs): + """Adaptor function that translates our groupby API to that of flox.""" + from flox.xarray import xarray_reduce + + from .dataset import Dataset + + obj = self._original_obj + + # preserve current strategy (approximately) for dask groupby. + # We want to control the default anyway to prevent surprises + # if flox decides to change its default + kwargs.setdefault("method", "split-reduce") + + numeric_only = kwargs.pop("numeric_only", None) + if numeric_only: + non_numeric = { + name: var + for name, var in obj.data_vars.items() + if not (np.issubdtype(var.dtype, np.number) or (var.dtype == np.bool_)) + } + else: + non_numeric = {} + + # weird backcompat + # reducing along a unique indexed dimension with squeeze=True + # should raise an error + if ( + dim is None or dim == self._group.name + ) and self._group.name in obj.xindexes: + index = obj.indexes[self._group.name] + if index.is_unique and self._squeeze: + raise ValueError(f"cannot reduce over dimensions {self._group.name!r}") + + # group is only passed by resample + group = kwargs.pop("group", None) + if group is None: + if isinstance(self._unstacked_group, _DummyGroup): + group = self._unstacked_group.name + else: + group = self._unstacked_group + + unindexed_dims = tuple() + if isinstance(group, str): + if group in obj.dims and group not in obj._indexes and self._bins is None: + unindexed_dims = (group,) + group = self._original_obj[group] + + if isinstance(dim, str): + dim = (dim,) + elif dim is None: + dim = group.dims + elif dim is Ellipsis: + dim = tuple(self._original_obj.dims) + + # Do this so we raise the same error message whether flox is present or not. + # Better to control it here than in flox. + if any(d not in group.dims and d not in self._original_obj.dims for d in dim): + raise ValueError(f"cannot reduce over dimensions {dim}.") + + if self._bins is not None: + # TODO: fix this; When binning by time, self._bins is a DatetimeIndex + expected_groups = (np.array(self._bins),) + isbin = (True,) + # This is an annoying hack. Xarray returns np.nan + # when there are no observations in a bin, instead of 0. + # We can fake that here by forcing min_count=1. + if kwargs["func"] == "count": + if "fill_value" not in kwargs or kwargs["fill_value"] is None: + kwargs["fill_value"] = np.nan + # note min_count makes no sense in the xarray world + # as a kwarg for count, so this should be OK + kwargs["min_count"] = 1 + # empty bins have np.nan regardless of dtype + # flox's default would not set np.nan for integer dtypes + kwargs.setdefault("fill_value", np.nan) + else: + expected_groups = (self._unique_coord.values,) + isbin = False + + result = xarray_reduce( + self._original_obj.drop_vars(non_numeric), + group, + dim=dim, + expected_groups=expected_groups, + isbin=isbin, + **kwargs, + ) + + # Ignore error when the groupby reduction is effectively + # a reduction of the underlying dataset + result = result.drop_vars(unindexed_dims, errors="ignore") + + # broadcast and restore non-numeric data variables (backcompat) + for name, var in non_numeric.items(): + if all(d not in var.dims for d in dim): + result[name] = var.variable.set_dims( + (group.name,) + var.dims, (result.sizes[group.name],) + var.shape + ) + + if self._bins is not None: + # bins provided to flox are at full precision + # the bin edge labels have a default precision of 3 + # reassign to fix that. + new_coord = [ + pd.Interval(inter.left, inter.right) for inter in self._full_index + ] + result[self._group.name] = new_coord + # Fix dimension order when binning a dimension coordinate + # Needed as long as we do a separate code path for pint; + # For some reason Datasets and DataArrays behave differently! + if isinstance(self._obj, Dataset) and self._group_dim in self._obj.dims: + result = result.transpose(self._group.name, ...) + + return result + def fillna(self, value): """Fill missing values in this object by group. diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index e02e1f569b2..9884a756fe6 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -22,10 +22,10 @@ from . import formatting, nputils, utils from .indexing import IndexSelResult, PandasIndexingAdapter, PandasMultiIndexingAdapter -from .types import T_Index from .utils import Frozen, get_valid_numpy_dtype, is_dict_like, is_scalar if TYPE_CHECKING: + from .types import ErrorChoice, T_Index from .variable import Variable IndexVars = Dict[Any, "Variable"] @@ -1098,7 +1098,7 @@ def is_multi(self, key: Hashable) -> bool: return len(self._id_coord_names[self._coord_name_id[key]]) > 1 def get_all_coords( - self, key: Hashable, errors: str = "raise" + self, key: Hashable, errors: ErrorChoice = "raise" ) -> dict[Hashable, Variable]: """Return all coordinates having the same index. @@ -1106,7 +1106,7 @@ def get_all_coords( ---------- key : hashable Index key. - errors : {"raise", "ignore"}, optional + errors : {"raise", "ignore"}, default: "raise" If "raise", raises a ValueError if `key` is not in indexes. If "ignore", an empty tuple is returned instead. @@ -1129,7 +1129,7 @@ def get_all_coords( return {k: self._variables[k] for k in all_coord_names} def get_all_dims( - self, key: Hashable, errors: str = "raise" + self, key: Hashable, errors: ErrorChoice = "raise" ) -> Mapping[Hashable, int]: """Return all dimensions shared by an index. @@ -1137,7 +1137,7 @@ def get_all_dims( ---------- key : hashable Index key. - errors : {"raise", "ignore"}, optional + errors : {"raise", "ignore"}, default: "raise" If "raise", raises a ValueError if `key` is not in indexes. If "ignore", an empty tuple is returned instead. diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index c1a4d629f97..fa96bd6e150 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -11,7 +11,7 @@ from . import dask_array_compat except ImportError: - dask_array = None + dask_array = None # type: ignore[assignment] dask_array_compat = None # type: ignore[assignment] diff --git a/xarray/core/options.py b/xarray/core/options.py index 399afe90b66..d31f2577601 100644 --- a/xarray/core/options.py +++ b/xarray/core/options.py @@ -27,6 +27,7 @@ class T_Options(TypedDict): keep_attrs: Literal["default", True, False] warn_for_unclosed_files: bool use_bottleneck: bool + use_flox: bool OPTIONS: T_Options = { @@ -45,6 +46,7 @@ class T_Options(TypedDict): "file_cache_maxsize": 128, "keep_attrs": "default", "use_bottleneck": True, + "use_flox": True, "warn_for_unclosed_files": False, } @@ -70,6 +72,7 @@ def _positive_integer(value): "file_cache_maxsize": _positive_integer, "keep_attrs": lambda choice: choice in [True, False, "default"], "use_bottleneck": lambda value: isinstance(value, bool), + "use_flox": lambda value: isinstance(value, bool), "warn_for_unclosed_files": lambda value: isinstance(value, bool), } @@ -180,6 +183,9 @@ class set_options: use_bottleneck : bool, default: True Whether to use ``bottleneck`` to accelerate 1D reductions and 1D rolling reduction operations. + use_flox : bool, default: True + Whether to use ``numpy_groupies`` and `flox`` to + accelerate groupby and resampling reductions. warn_for_unclosed_files : bool, default: False Whether or not to issue a warning when unclosed files are deallocated. This is mostly useful for debugging. diff --git a/xarray/core/resample.py b/xarray/core/resample.py index ed665ad4048..bcc4bfb90cd 100644 --- a/xarray/core/resample.py +++ b/xarray/core/resample.py @@ -1,13 +1,15 @@ import warnings from typing import Any, Callable, Hashable, Sequence, Union +import numpy as np + from ._reductions import DataArrayResampleReductions, DatasetResampleReductions -from .groupby import DataArrayGroupByBase, DatasetGroupByBase +from .groupby import DataArrayGroupByBase, DatasetGroupByBase, GroupBy RESAMPLE_DIM = "__resample_dim__" -class Resample: +class Resample(GroupBy): """An object that extends the `GroupBy` object with additional logic for handling specialized re-sampling operations. @@ -21,6 +23,29 @@ class Resample: """ + def _flox_reduce(self, dim, **kwargs): + + from .dataarray import DataArray + + kwargs.setdefault("method", "cohorts") + + # now create a label DataArray since resample doesn't do that somehow + repeats = [] + for slicer in self._group_indices: + stop = ( + slicer.stop + if slicer.stop is not None + else self._obj.sizes[self._group_dim] + ) + repeats.append(stop - slicer.start) + labels = np.repeat(self._unique_coord.data, repeats) + group = DataArray(labels, dims=(self._group_dim,), name=self._unique_coord.name) + + result = super()._flox_reduce(dim=dim, group=group, **kwargs) + result = self._maybe_restore_empty_groups(result) + result = result.rename({RESAMPLE_DIM: self._group_dim}) + return result + def _upsample(self, method, *args, **kwargs): """Dispatch function to call appropriate up-sampling methods on data. @@ -158,7 +183,7 @@ def _interpolate(self, kind="linear"): ) -class DataArrayResample(DataArrayGroupByBase, DataArrayResampleReductions, Resample): +class DataArrayResample(Resample, DataArrayGroupByBase, DataArrayResampleReductions): """DataArrayGroupBy object specialized to time resampling operations over a specified dimension """ @@ -249,7 +274,7 @@ def apply(self, func, args=(), shortcut=None, **kwargs): return self.map(func=func, shortcut=shortcut, args=args, **kwargs) -class DatasetResample(DatasetGroupByBase, DatasetResampleReductions, Resample): +class DatasetResample(Resample, DatasetGroupByBase, DatasetResampleReductions): """DatasetGroupBy object specialized to resampling a specified dimension""" def __init__(self, *args, dim=None, resample_dim=None, **kwargs): diff --git a/xarray/core/types.py b/xarray/core/types.py index 3f368501b25..6dbc57ce797 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, TypeVar, Union +from typing import TYPE_CHECKING, Literal, TypeVar, Union import numpy as np @@ -16,7 +16,7 @@ try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore T_Dataset = TypeVar("T_Dataset", bound="Dataset") @@ -33,3 +33,6 @@ DaCompatible = Union["DataArray", "Variable", "DataArrayGroupBy", "ScalarOrArray"] VarCompatible = Union["Variable", "ScalarOrArray"] GroupByIncompatible = Union["Variable", "GroupBy"] + +ErrorChoice = Literal["raise", "ignore"] +ErrorChoiceWithWarn = Literal["raise", "warn", "ignore"] diff --git a/xarray/core/utils.py b/xarray/core/utils.py index bab6476e734..eda08becc20 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -29,6 +29,9 @@ import numpy as np import pandas as pd +if TYPE_CHECKING: + from .types import ErrorChoiceWithWarn + K = TypeVar("K") V = TypeVar("V") T = TypeVar("T") @@ -756,7 +759,9 @@ def __len__(self) -> int: def infix_dims( - dims_supplied: Collection, dims_all: Collection, missing_dims: str = "raise" + dims_supplied: Collection, + dims_all: Collection, + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Iterator: """ Resolves a supplied list containing an ellipsis representing other items, to @@ -804,7 +809,7 @@ def get_temp_dimname(dims: Container[Hashable], new_dim: Hashable) -> Hashable: def drop_dims_from_indexers( indexers: Mapping[Any, Any], dims: list | Mapping[Any, int], - missing_dims: str, + missing_dims: ErrorChoiceWithWarn, ) -> Mapping[Hashable, Any]: """Depending on the setting of missing_dims, drop any dimensions from indexers that are not present in dims. @@ -850,7 +855,7 @@ def drop_dims_from_indexers( def drop_missing_dims( - supplied_dims: Collection, dims: Collection, missing_dims: str + supplied_dims: Collection, dims: Collection, missing_dims: ErrorChoiceWithWarn ) -> Collection: """Depending on the setting of missing_dims, drop any dimensions from supplied_dims that are not present in dims. @@ -923,3 +928,21 @@ def iterate_nested(nested_list): yield from iterate_nested(item) else: yield item + + +def contains_only_dask_or_numpy(obj) -> bool: + """Returns True if xarray object contains only numpy or dask arrays. + + Expects obj to be Dataset or DataArray""" + from .dataarray import DataArray + from .pycompat import is_duck_dask_array + + if isinstance(obj, DataArray): + obj = obj._to_temp_dataset() + + return all( + [ + isinstance(var.data, np.ndarray) or is_duck_dask_array(var.data) + for var in obj.variables.values() + ] + ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 05c70390b46..1e684a72984 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -59,7 +59,7 @@ BASIC_INDEXING_TYPES = integer_types + (slice,) if TYPE_CHECKING: - from .types import T_Variable + from .types import ErrorChoiceWithWarn, T_Variable class MissingDimensionsError(ValueError): @@ -1023,6 +1023,7 @@ def chunk( ) = {}, name: str = None, lock: bool = False, + inline_array: bool = False, **chunks_kwargs: Any, ) -> Variable: """Coerce this array's data into a dask array with the given chunks. @@ -1046,6 +1047,9 @@ def chunk( lock : optional Passed on to :py:func:`dask.array.from_array`, if the array is not already as dask array. + inline_array: optional + Passed on to :py:func:`dask.array.from_array`, if the array is not + already as dask array. **chunks_kwargs : {dim: chunks, ...}, optional The keyword arguments form of ``chunks``. One of chunks or chunks_kwargs must be provided. @@ -1053,6 +1057,13 @@ def chunk( Returns ------- chunked : xarray.Variable + + See Also + -------- + Variable.chunks + Variable.chunksizes + xarray.unify_chunks + dask.array.from_array """ import dask.array as da @@ -1098,7 +1109,9 @@ def chunk( if utils.is_dict_like(chunks): chunks = tuple(chunks.get(n, s) for n, s in enumerate(self.shape)) - data = da.from_array(data, chunks, name=name, lock=lock, **kwargs) + data = da.from_array( + data, chunks, name=name, lock=lock, inline_array=inline_array, **kwargs + ) return self._replace(data=data) @@ -1159,7 +1172,7 @@ def _to_dense(self): def isel( self: T_Variable, indexers: Mapping[Any, Any] = None, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", **indexers_kwargs: Any, ) -> T_Variable: """Return a new array indexed along the specified dimension(s). @@ -1173,7 +1186,7 @@ def isel( What to do if dimensions that should be selected from are not present in the DataArray: - "raise": raise an exception - - "warning": raise a warning, and ignore the missing dimensions + - "warn": raise a warning, and ignore the missing dimensions - "ignore": ignore the missing dimensions Returns @@ -1436,7 +1449,7 @@ def roll(self, shifts=None, **shifts_kwargs): def transpose( self, *dims, - missing_dims: str = "raise", + missing_dims: ErrorChoiceWithWarn = "raise", ) -> Variable: """Return a new Variable object with transposed dimensions. @@ -2710,7 +2723,7 @@ def values(self, values): f"Please use DataArray.assign_coords, Dataset.assign_coords or Dataset.assign as appropriate." ) - def chunk(self, chunks={}, name=None, lock=False): + def chunk(self, chunks={}, name=None, lock=False, inline_array=False): # Dummy - do not chunk. This method is invoked e.g. by Dataset.chunk() return self.copy(deep=False) diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 7872fec2e62..65f0bc08261 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -78,6 +78,8 @@ def _importorskip(modname, minversion=None): has_cartopy, requires_cartopy = _importorskip("cartopy") has_pint, requires_pint = _importorskip("pint") has_numexpr, requires_numexpr = _importorskip("numexpr") +has_flox, requires_flox = _importorskip("flox") + # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 81bfeb11a1e..6b6f6e462bd 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -82,7 +82,11 @@ _NON_STANDARD_CALENDARS, _STANDARD_CALENDARS, ) -from .test_dataset import create_append_test_data, create_test_data +from .test_dataset import ( + create_append_string_length_mismatch_test_data, + create_append_test_data, + create_test_data, +) try: import netCDF4 as nc4 @@ -2112,6 +2116,17 @@ def test_append_with_existing_encoding_raises(self): encoding={"da": {"compressor": None}}, ) + @pytest.mark.parametrize("dtype", ["U", "S"]) + def test_append_string_length_mismatch_raises(self, dtype): + ds, ds_to_append = create_append_string_length_mismatch_test_data(dtype) + with self.create_zarr_target() as store_target: + ds.to_zarr(store_target, mode="w") + with pytest.raises(ValueError, match="Mismatched dtypes for variable"): + ds_to_append.to_zarr( + store_target, + append_dim="time", + ) + def test_check_encoding_is_consistent_after_append(self): ds, ds_to_append, _ = create_append_test_data() @@ -3825,6 +3840,27 @@ def test_load_dataarray(self): # load_dataarray ds.to_netcdf(tmp) + @pytest.mark.skipif( + ON_WINDOWS, + reason="counting number of tasks in graph fails on windows for some reason", + ) + def test_inline_array(self): + with create_tmp_file() as tmp: + original = Dataset({"foo": ("x", np.random.randn(10))}) + original.to_netcdf(tmp) + chunks = {"time": 10} + + def num_graph_nodes(obj): + return len(obj.__dask_graph__()) + + not_inlined = open_dataset(tmp, inline_array=False, chunks=chunks) + inlined = open_dataset(tmp, inline_array=True, chunks=chunks) + assert num_graph_nodes(inlined) < num_graph_nodes(not_inlined) + + not_inlined = open_dataarray(tmp, inline_array=False, chunks=chunks) + inlined = open_dataarray(tmp, inline_array=True, chunks=chunks) + assert num_graph_nodes(inlined) < num_graph_nodes(not_inlined) + @requires_scipy_or_netCDF4 @requires_pydap diff --git a/xarray/tests/test_coding_times.py b/xarray/tests/test_coding_times.py index f10523aecbb..a5344fe4c85 100644 --- a/xarray/tests/test_coding_times.py +++ b/xarray/tests/test_coding_times.py @@ -1121,3 +1121,30 @@ def test_should_cftime_be_used_target_not_npable(): ValueError, match="Calendar 'noleap' is only valid with cftime." ): _should_cftime_be_used(src, "noleap", False) + + +@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64]) +def test_decode_cf_datetime_uint(dtype): + units = "seconds since 2018-08-22T03:23:03Z" + num_dates = dtype(50) + result = decode_cf_datetime(num_dates, units) + expected = np.asarray(np.datetime64("2018-08-22T03:23:53", "ns")) + np.testing.assert_equal(result, expected) + + +@requires_cftime +def test_decode_cf_datetime_uint64_with_cftime(): + units = "days since 1700-01-01" + num_dates = np.uint64(182621) + result = decode_cf_datetime(num_dates, units) + expected = np.asarray(np.datetime64("2200-01-01", "ns")) + np.testing.assert_equal(result, expected) + + +@requires_cftime +def test_decode_cf_datetime_uint64_with_cftime_overflow_error(): + units = "microseconds since 1700-01-01" + calendar = "360_day" + num_dates = np.uint64(1_000_000 * 86_400 * 360 * 500_000) + with pytest.raises(OverflowError): + decode_cf_datetime(num_dates, units, calendar) diff --git a/xarray/tests/test_computation.py b/xarray/tests/test_computation.py index 7a397428ba3..ec8b5a5bc7c 100644 --- a/xarray/tests/test_computation.py +++ b/xarray/tests/test_computation.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import functools import operator import pickle @@ -22,6 +24,7 @@ unified_dim_sizes, ) from xarray.core.pycompat import dask_version +from xarray.core.types import T_Xarray from . import has_dask, raise_if_dask_computes, requires_dask @@ -1933,37 +1936,113 @@ def test_where_attrs() -> None: assert actual.attrs == {} -@pytest.mark.parametrize("use_dask", [True, False]) -@pytest.mark.parametrize("use_datetime", [True, False]) -def test_polyval(use_dask, use_datetime) -> None: - if use_dask and not has_dask: - pytest.skip("requires dask") - - if use_datetime: - xcoord = xr.DataArray( - pd.date_range("2000-01-01", freq="D", periods=10), dims=("x",), name="x" - ) - x = xr.core.missing.get_clean_interp_index(xcoord, "x") - else: - x = np.arange(10) - xcoord = xr.DataArray(x, dims=("x",), name="x") - - da = xr.DataArray( - np.stack((1.0 + x + 2.0 * x**2, 1.0 + 2.0 * x + 3.0 * x**2)), - dims=("d", "x"), - coords={"x": xcoord, "d": [0, 1]}, - ) - coeffs = xr.DataArray( - [[2, 1, 1], [3, 2, 1]], - dims=("d", "degree"), - coords={"d": [0, 1], "degree": [2, 1, 0]}, - ) +@pytest.mark.parametrize("use_dask", [False, True]) +@pytest.mark.parametrize( + ["x", "coeffs", "expected"], + [ + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}), + xr.DataArray([9, 2 + 6 + 16, 2 + 9 + 36], dims="x"), + id="simple", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [[0, 1], [0, 1]], dims=("y", "degree"), coords={"degree": [0, 1]} + ), + xr.DataArray([[1, 1], [2, 2], [3, 3]], dims=("x", "y")), + id="broadcast-x", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray( + [[0, 1], [1, 0], [1, 1]], + dims=("x", "degree"), + coords={"degree": [0, 1]}, + ), + xr.DataArray([1, 1, 1 + 3], dims="x"), + id="shared-dim", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([1, 0, 0], dims="degree", coords={"degree": [2, 1, 0]}), + xr.DataArray([1, 2**2, 3**2], dims="x"), + id="reordered-index", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.DataArray([5], dims="degree", coords={"degree": [3]}), + xr.DataArray([5, 5 * 2**3, 5 * 3**3], dims="x"), + id="sparse-index", + ), + pytest.param( + xr.DataArray([1, 2, 3], dims="x"), + xr.Dataset( + {"a": ("degree", [0, 1]), "b": ("degree", [1, 0])}, + coords={"degree": [0, 1]}, + ), + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [1, 1, 1])}), + id="array-dataset", + ), + pytest.param( + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("x", [2, 3, 4])}), + xr.DataArray([1, 1], dims="degree", coords={"degree": [0, 1]}), + xr.Dataset({"a": ("x", [2, 3, 4]), "b": ("x", [3, 4, 5])}), + id="dataset-array", + ), + pytest.param( + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [2, 3, 4])}), + xr.Dataset( + {"a": ("degree", [0, 1]), "b": ("degree", [1, 1])}, + coords={"degree": [0, 1]}, + ), + xr.Dataset({"a": ("x", [1, 2, 3]), "b": ("y", [3, 4, 5])}), + id="dataset-dataset", + ), + pytest.param( + xr.DataArray(pd.date_range("1970-01-01", freq="s", periods=3), dims="x"), + xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}), + xr.DataArray( + [0, 1e9, 2e9], + dims="x", + coords={"x": pd.date_range("1970-01-01", freq="s", periods=3)}, + ), + id="datetime", + ), + pytest.param( + xr.DataArray( + np.array([1000, 2000, 3000], dtype="timedelta64[ns]"), dims="x" + ), + xr.DataArray([0, 1], dims="degree", coords={"degree": [0, 1]}), + xr.DataArray([1000.0, 2000.0, 3000.0], dims="x"), + id="timedelta", + ), + ], +) +def test_polyval( + use_dask: bool, + x: xr.DataArray | xr.Dataset, + coeffs: xr.DataArray | xr.Dataset, + expected: xr.DataArray | xr.Dataset, +) -> None: if use_dask: - coeffs = coeffs.chunk({"d": 2}) + if not has_dask: + pytest.skip("requires dask") + coeffs = coeffs.chunk({"degree": 2}) + x = x.chunk({"x": 2}) + with raise_if_dask_computes(): + actual = xr.polyval(coord=x, coeffs=coeffs) # type: ignore + xr.testing.assert_allclose(actual, expected) - da_pv = xr.polyval(da.x, coeffs) - xr.testing.assert_allclose(da, da_pv.T) +def test_polyval_degree_dim_checks(): + x = (xr.DataArray([1, 2, 3], dims="x"),) + coeffs = xr.DataArray([2, 3, 4], dims="degree", coords={"degree": [0, 1, 2]}) + with pytest.raises(ValueError): + xr.polyval(x, coeffs.drop_vars("degree")) + with pytest.raises(ValueError): + xr.polyval(x, coeffs.assign_coords(degree=coeffs.degree.astype(float))) @pytest.mark.parametrize("use_dask", [False, True]) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b8c9edd7258..8e1099b7e33 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -2360,10 +2360,11 @@ def test_drop_coordinates(self): assert_identical(actual, renamed) def test_drop_multiindex_level(self): - with pytest.raises( - ValueError, match=r"cannot remove coordinate.*corrupt.*index " - ): - self.mda.drop_vars("level_1") + # GH6505 + expected = self.mda.drop_vars(["x", "level_1", "level_2"]) + with pytest.warns(DeprecationWarning): + actual = self.mda.drop_vars("level_1") + assert_identical(expected, actual) def test_drop_all_multiindex_levels(self): dim_levels = ["x", "level_1", "level_2"] diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index c0f7f09ff61..263237d9d30 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -76,6 +76,8 @@ def create_append_test_data(seed=None): time2 = pd.date_range("2000-02-01", periods=nt2) string_var = np.array(["ae", "bc", "df"], dtype=object) string_var_to_append = np.array(["asdf", "asdfg"], dtype=object) + string_var_fixed_length = np.array(["aa", "bb", "cc"], dtype="|S2") + string_var_fixed_length_to_append = np.array(["dd", "ee"], dtype="|S2") unicode_var = ["áó", "áó", "áó"] datetime_var = np.array( ["2019-01-01", "2019-01-02", "2019-01-03"], dtype="datetime64[s]" @@ -94,6 +96,9 @@ def create_append_test_data(seed=None): dims=["lat", "lon", "time"], ), "string_var": xr.DataArray(string_var, coords=[time1], dims=["time"]), + "string_var_fixed_length": xr.DataArray( + string_var_fixed_length, coords=[time1], dims=["time"] + ), "unicode_var": xr.DataArray( unicode_var, coords=[time1], dims=["time"] ).astype(np.unicode_), @@ -112,6 +117,9 @@ def create_append_test_data(seed=None): "string_var": xr.DataArray( string_var_to_append, coords=[time2], dims=["time"] ), + "string_var_fixed_length": xr.DataArray( + string_var_fixed_length_to_append, coords=[time2], dims=["time"] + ), "unicode_var": xr.DataArray( unicode_var[:nt2], coords=[time2], dims=["time"] ).astype(np.unicode_), @@ -137,6 +145,33 @@ def create_append_test_data(seed=None): return ds, ds_to_append, ds_with_new_var +def create_append_string_length_mismatch_test_data(dtype): + def make_datasets(data, data_to_append): + ds = xr.Dataset( + {"temperature": (["time"], data)}, + coords={"time": [0, 1, 2]}, + ) + ds_to_append = xr.Dataset( + {"temperature": (["time"], data_to_append)}, coords={"time": [0, 1, 2]} + ) + assert all(objp.data.flags.writeable for objp in ds.variables.values()) + assert all( + objp.data.flags.writeable for objp in ds_to_append.variables.values() + ) + return ds, ds_to_append + + u2_strings = ["ab", "cd", "ef"] + u5_strings = ["abc", "def", "ghijk"] + + s2_strings = np.array(["aa", "bb", "cc"], dtype="|S2") + s3_strings = np.array(["aaa", "bbb", "ccc"], dtype="|S3") + + if dtype == "U": + return make_datasets(u2_strings, u5_strings) + elif dtype == "S": + return make_datasets(s2_strings, s3_strings) + + def create_test_multiindex(): mindex = pd.MultiIndex.from_product( [["a", "b"], [1, 2]], names=("level_1", "level_2") @@ -1175,6 +1210,25 @@ def test_isel_fancy(self): assert_array_equal(actual["var2"], expected_var2) assert_array_equal(actual["var3"], expected_var3) + # test that drop works + ds = xr.Dataset({"a": (("x",), [1, 2, 3])}, coords={"b": (("x",), [5, 6, 7])}) + + actual = ds.isel({"x": 1}, drop=False) + expected = xr.Dataset({"a": 2}, coords={"b": 6}) + assert_identical(actual, expected) + + actual = ds.isel({"x": 1}, drop=True) + expected = xr.Dataset({"a": 2}) + assert_identical(actual, expected) + + actual = ds.isel({"x": DataArray(1)}, drop=False) + expected = xr.Dataset({"a": 2}, coords={"b": 6}) + assert_identical(actual, expected) + + actual = ds.isel({"x": DataArray(1)}, drop=True) + expected = xr.Dataset({"a": 2}) + assert_identical(actual, expected) + def test_isel_dataarray(self): """Test for indexing by DataArray""" data = create_test_data() @@ -2399,11 +2453,10 @@ def test_drop_variables(self): def test_drop_multiindex_level(self): data = create_test_multiindex() - - with pytest.raises( - ValueError, match=r"cannot remove coordinate.*corrupt.*index " - ): - data.drop_vars("level_1") + expected = data.drop_vars(["x", "level_1", "level_2"]) + with pytest.warns(DeprecationWarning): + actual = data.drop_vars("level_1") + assert_identical(expected, actual) def test_drop_index_labels(self): data = Dataset({"A": (["x", "y"], np.random.randn(2, 3)), "x": ["a", "b"]}) diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 4bbf41c7b38..a5c044d8ea7 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -9,7 +9,7 @@ import xarray as xr from xarray.core import formatting -from . import requires_netCDF4 +from . import requires_dask, requires_netCDF4 class TestFormatting: @@ -418,6 +418,26 @@ def test_array_repr_variable(self) -> None: with xr.set_options(display_expand_data=False): formatting.array_repr(var) + @requires_dask + def test_array_scalar_format(self) -> None: + var = xr.DataArray(0) + assert var.__format__("") == "0" + assert var.__format__("d") == "0" + assert var.__format__(".2f") == "0.00" + + var = xr.DataArray([0.1, 0.2]) + assert var.__format__("") == "[0.1 0.2]" + with pytest.raises(TypeError) as excinfo: + var.__format__(".2f") + assert "unsupported format string passed to" in str(excinfo.value) + + # also check for dask + var = var.chunk(chunks={"dim_0": 1}) + assert var.__format__("") == "[0.1 0.2]" + with pytest.raises(TypeError) as excinfo: + var.__format__(".2f") + assert "unsupported format string passed to" in str(excinfo.value) + def test_inline_variable_array_repr_custom_repr() -> None: class CustomArray: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index b4b93d1dba3..8c745dc640d 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -17,6 +17,7 @@ assert_identical, create_test_data, requires_dask, + requires_flox, requires_scipy, ) @@ -24,7 +25,10 @@ @pytest.fixture def dataset(): ds = xr.Dataset( - {"foo": (("x", "y", "z"), np.random.randn(3, 4, 2))}, + { + "foo": (("x", "y", "z"), np.random.randn(3, 4, 2)), + "baz": ("x", ["e", "f", "g"]), + }, {"x": ["a", "b", "c"], "y": [1, 2, 3, 4], "z": [1, 2]}, ) ds["boo"] = (("z", "y"), [["f", "g", "h", "j"]] * 2) @@ -71,6 +75,15 @@ def test_multi_index_groupby_map(dataset) -> None: assert_equal(expected, actual) +def test_reduce_numeric_only(dataset) -> None: + gb = dataset.groupby("x", squeeze=False) + with xr.set_options(use_flox=False): + expected = gb.sum() + with xr.set_options(use_flox=True): + actual = gb.sum() + assert_identical(expected, actual) + + def test_multi_index_groupby_sum() -> None: # regression test for GH873 ds = xr.Dataset( @@ -961,6 +974,17 @@ def test_groupby_dataarray_map_dataset_func(): assert_identical(actual, expected) +@requires_flox +@pytest.mark.parametrize("kwargs", [{"method": "map-reduce"}, {"engine": "numpy"}]) +def test_groupby_flox_kwargs(kwargs): + ds = Dataset({"a": ("x", range(5))}, {"c": ("x", [0, 0, 1, 1, 1])}) + with xr.set_options(use_flox=False): + expected = ds.groupby("c").mean() + with xr.set_options(use_flox=True): + actual = ds.groupby("c").mean(**kwargs) + assert_identical(expected, actual) + + class TestDataArrayGroupBy: @pytest.fixture(autouse=True) def setup(self): @@ -1016,19 +1040,22 @@ def test_groupby_properties(self): assert_array_equal(expected_groups[key], grouped.groups[key]) assert 3 == len(grouped) - def test_groupby_map_identity(self): + @pytest.mark.parametrize( + "by, use_da", [("x", False), ("y", False), ("y", True), ("abc", False)] + ) + @pytest.mark.parametrize("shortcut", [True, False]) + @pytest.mark.parametrize("squeeze", [True, False]) + def test_groupby_map_identity(self, by, use_da, shortcut, squeeze) -> None: expected = self.da - idx = expected.coords["y"] + if use_da: + by = expected.coords[by] def identity(x): return x - for g in ["x", "y", "abc", idx]: - for shortcut in [False, True]: - for squeeze in [False, True]: - grouped = expected.groupby(g, squeeze=squeeze) - actual = grouped.map(identity, shortcut=shortcut) - assert_identical(expected, actual) + grouped = expected.groupby(by, squeeze=squeeze) + actual = grouped.map(identity, shortcut=shortcut) + assert_identical(expected, actual) def test_groupby_sum(self): array = self.da @@ -1083,19 +1110,21 @@ def test_groupby_sum(self): assert_allclose(expected_sum_axis1, grouped.reduce(np.sum, "y")) assert_allclose(expected_sum_axis1, grouped.sum("y")) - def test_groupby_sum_default(self): + @pytest.mark.parametrize("method", ["sum", "mean", "median"]) + def test_groupby_reductions(self, method): array = self.da grouped = array.groupby("abc") - expected_sum_all = Dataset( + reduction = getattr(np, method) + expected = Dataset( { "foo": Variable( ["x", "abc"], np.array( [ - self.x[:, :9].sum(axis=-1), - self.x[:, 10:].sum(axis=-1), - self.x[:, 9:10].sum(axis=-1), + reduction(self.x[:, :9], axis=-1), + reduction(self.x[:, 10:], axis=-1), + reduction(self.x[:, 9:10], axis=-1), ] ).T, ), @@ -1103,7 +1132,14 @@ def test_groupby_sum_default(self): } )["foo"] - assert_allclose(expected_sum_all, grouped.sum(dim="y")) + with xr.set_options(use_flox=False): + actual_legacy = getattr(grouped, method)(dim="y") + + with xr.set_options(use_flox=True): + actual_npg = getattr(grouped, method)(dim="y") + + assert_allclose(expected, actual_legacy) + assert_allclose(expected, actual_npg) def test_groupby_count(self): array = DataArray( @@ -1318,13 +1354,23 @@ def test_groupby_bins(self): expected = DataArray( [1, 5], dims="dim_0_bins", coords={"dim_0_bins": bin_coords} ) - # the problem with this is that it overwrites the dimensions of array! - # actual = array.groupby('dim_0', bins=bins).sum() - actual = array.groupby_bins("dim_0", bins).map(lambda x: x.sum()) + actual = array.groupby_bins("dim_0", bins=bins).sum() + assert_identical(expected, actual) + + actual = array.groupby_bins("dim_0", bins=bins).map(lambda x: x.sum()) assert_identical(expected, actual) + # make sure original array dims are unchanged assert len(array.dim_0) == 4 + da = xr.DataArray(np.ones((2, 3, 4))) + bins = [-1, 0, 1, 2] + with xr.set_options(use_flox=False): + actual = da.groupby_bins("dim_0", bins).mean(...) + with xr.set_options(use_flox=True): + expected = da.groupby_bins("dim_0", bins).mean(...) + assert_allclose(actual, expected) + def test_groupby_bins_empty(self): array = DataArray(np.arange(4), [("x", range(4))]) # one of these bins will be empty @@ -1350,6 +1396,27 @@ def test_groupby_bins_multidim(self): actual = array.groupby_bins("lat", bins).map(lambda x: x.sum()) assert_identical(expected, actual) + bins = [-2, -1, 0, 1, 2] + field = DataArray(np.ones((5, 3)), dims=("x", "y")) + by = DataArray( + np.array([[-1.5, -1.5, 0.5, 1.5, 1.5] * 3]).reshape(5, 3), dims=("x", "y") + ) + actual = field.groupby_bins(by, bins=bins).count() + + bincoord = np.array( + [ + pd.Interval(left, right, closed="right") + for left, right in zip(bins[:-1], bins[1:]) + ], + dtype=object, + ) + expected = DataArray( + np.array([6, np.nan, 3, 6]), + dims="group_bins", + coords={"group_bins": bincoord}, + ) + assert_identical(actual, expected) + def test_groupby_bins_sort(self): data = xr.DataArray( np.arange(100), dims="x", coords={"x": np.linspace(-100, 100, num=100)} @@ -1357,6 +1424,12 @@ def test_groupby_bins_sort(self): binned_mean = data.groupby_bins("x", bins=11).mean() assert binned_mean.to_index().is_monotonic_increasing + with xr.set_options(use_flox=True): + actual = data.groupby_bins("x", bins=11).count() + with xr.set_options(use_flox=False): + expected = data.groupby_bins("x", bins=11).count() + assert_identical(actual, expected) + def test_groupby_assign_coords(self): array = DataArray([1, 2, 3, 4], {"c": ("x", [0, 0, 1, 1])}, dims="x") @@ -1769,7 +1842,7 @@ def test_resample_min_count(self): ], dim=actual["time"], ) - assert_equal(expected, actual) + assert_allclose(expected, actual) def test_resample_by_mean_with_keep_attrs(self): times = pd.date_range("2000-01-01", freq="6H", periods=10) @@ -1903,7 +1976,7 @@ def test_resample_ds_da_are_the_same(self): "x": np.arange(5), } ) - assert_identical( + assert_allclose( ds.resample(time="M").mean()["foo"], ds.foo.resample(time="M").mean() ) @@ -1916,6 +1989,3 @@ def func(arg1, arg2, arg3=0.0): expected = xr.Dataset({"foo": ("time", [3.0, 3.0, 3.0]), "time": times}) actual = ds.resample(time="D").map(func, args=(1.0,), arg3=1.0) assert_identical(expected, actual) - - -# TODO: move other groupby tests from test_dataset and test_dataarray over here diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 2bde7529d1e..1470706d0eb 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -10,7 +10,7 @@ try: from dask.array import from_array as dask_from_array except ImportError: - dask_from_array = lambda x: x + dask_from_array = lambda x: x # type: ignore try: import pint diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index c18b7d18c04..679733e1ecf 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5344,8 +5344,12 @@ def test_computation_objects(self, func, variant, dtype): units = extract_units(ds) args = [] if func.name != "groupby" else ["y"] - expected = attach_units(func(strip_units(ds)).mean(*args), units) - actual = func(ds).mean(*args) + # Doesn't work with flox because pint doesn't implement + # ufunc.reduceat or np.bincount + # kwargs = {"engine": "numpy"} if "groupby" in func.name else {} + kwargs = {} + expected = attach_units(func(strip_units(ds)).mean(*args, **kwargs), units) + actual = func(ds).mean(*args, **kwargs) assert_units_equal(expected, actual) assert_allclose(expected, actual) diff --git a/xarray/util/generate_ops.py b/xarray/util/generate_ops.py index f1fd6cbfeb2..0a382642708 100644 --- a/xarray/util/generate_ops.py +++ b/xarray/util/generate_ops.py @@ -210,7 +210,7 @@ def inplace(): try: from dask.array import Array as DaskArray except ImportError: - DaskArray = np.ndarray + DaskArray = np.ndarray # type: ignore # DatasetOpsMixin etc. are parent classes of Dataset etc. # Because of https://github.com/pydata/xarray/issues/5755, we redefine these. Generally diff --git a/xarray/util/generate_reductions.py b/xarray/util/generate_reductions.py index e79c94e8907..96b91c16906 100644 --- a/xarray/util/generate_reductions.py +++ b/xarray/util/generate_reductions.py @@ -23,13 +23,19 @@ from typing import TYPE_CHECKING, Any, Callable, Hashable, Optional, Sequence, Union from . import duck_array_ops +from .options import OPTIONS +from .utils import contains_only_dask_or_numpy if TYPE_CHECKING: from .dataarray import DataArray - from .dataset import Dataset''' + from .dataset import Dataset +try: + import flox +except ImportError: + flox = None # type: ignore''' -CLASS_PREAMBLE = """ +DEFAULT_PREAMBLE = """ class {obj}{cls}Reductions: __slots__ = () @@ -46,6 +52,54 @@ def reduce( ) -> "{obj}": raise NotImplementedError()""" +GROUPBY_PREAMBLE = """ + +class {obj}{cls}Reductions: + _obj: "{obj}" + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "{obj}": + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "{obj}": + raise NotImplementedError()""" + +RESAMPLE_PREAMBLE = """ + +class {obj}{cls}Reductions: + _obj: "{obj}" + + def reduce( + self, + func: Callable[..., Any], + dim: Union[None, Hashable, Sequence[Hashable]] = None, + *, + axis: Union[None, int, Sequence[int]] = None, + keep_attrs: bool = None, + keepdims: bool = False, + **kwargs: Any, + ) -> "{obj}": + raise NotImplementedError() + + def _flox_reduce( + self, + dim: Union[None, Hashable, Sequence[Hashable]], + **kwargs, + ) -> "{obj}": + raise NotImplementedError()""" + TEMPLATE_REDUCTION_SIGNATURE = ''' def {method}( self, @@ -113,11 +167,7 @@ def {method}( These could include dask-specific kwargs like ``split_every``.""" NAN_CUM_METHODS = ["cumsum", "cumprod"] - -NUMERIC_ONLY_METHODS = [ - "cumsum", - "cumprod", -] +NUMERIC_ONLY_METHODS = ["cumsum", "cumprod"] _NUMERIC_ONLY_NOTES = "Non-numeric variables will be removed prior to reducing." ExtraKwarg = collections.namedtuple("ExtraKwarg", "docs kwarg call example") @@ -182,6 +232,7 @@ def __init__( docref, docref_description, example_call_preamble, + definition_preamble, see_also_obj=None, ): self.datastructure = datastructure @@ -190,7 +241,7 @@ def __init__( self.docref = docref self.docref_description = docref_description self.example_call_preamble = example_call_preamble - self.preamble = CLASS_PREAMBLE.format(obj=datastructure.name, cls=cls) + self.preamble = definition_preamble.format(obj=datastructure.name, cls=cls) if not see_also_obj: self.see_also_obj = self.datastructure.name else: @@ -268,6 +319,53 @@ def generate_example(self, method): >>> {calculation}(){extra_examples}""" +class GroupByReductionGenerator(ReductionGenerator): + def generate_code(self, method): + extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] + + if self.datastructure.numeric_only: + extra_kwargs.append(f"numeric_only={method.numeric_only},") + + # numpy_groupies & flox do not support median + # https://github.com/ml31415/numpy-groupies/issues/43 + if method.name == "median": + indent = 12 + else: + indent = 16 + + if extra_kwargs: + extra_kwargs = textwrap.indent("\n" + "\n".join(extra_kwargs), indent * " ") + else: + extra_kwargs = "" + + if method.name == "median": + return f"""\ + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + else: + return f"""\ + if flox and OPTIONS["use_flox"] and contains_only_dask_or_numpy(self._obj): + return self._flox_reduce( + func="{method.name}", + dim=dim,{extra_kwargs} + # fill_value=fill_value, + keep_attrs=keep_attrs, + **kwargs, + ) + else: + return self.reduce( + duck_array_ops.{method.array_method}, + dim=dim,{extra_kwargs} + keep_attrs=keep_attrs, + **kwargs, + )""" + + class GenericReductionGenerator(ReductionGenerator): def generate_code(self, method): extra_kwargs = [kwarg.call for kwarg in method.extra_kwargs if kwarg.call] @@ -335,6 +433,7 @@ class DataStructure: docref_description="reduction or aggregation operations", example_call_preamble="", see_also_obj="DataArray", + definition_preamble=DEFAULT_PREAMBLE, ) DATAARRAY_GENERATOR = GenericReductionGenerator( cls="", @@ -344,39 +443,43 @@ class DataStructure: docref_description="reduction or aggregation operations", example_call_preamble="", see_also_obj="Dataset", + definition_preamble=DEFAULT_PREAMBLE, ) - -DATAARRAY_GROUPBY_GENERATOR = GenericReductionGenerator( +DATAARRAY_GROUPBY_GENERATOR = GroupByReductionGenerator( cls="GroupBy", datastructure=DATAARRAY_OBJECT, methods=REDUCTION_METHODS, docref="groupby", docref_description="groupby operations", example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, ) -DATAARRAY_RESAMPLE_GENERATOR = GenericReductionGenerator( +DATAARRAY_RESAMPLE_GENERATOR = GroupByReductionGenerator( cls="Resample", datastructure=DATAARRAY_OBJECT, methods=REDUCTION_METHODS, docref="resampling", docref_description="resampling operations", example_call_preamble='.resample(time="3M")', + definition_preamble=RESAMPLE_PREAMBLE, ) -DATASET_GROUPBY_GENERATOR = GenericReductionGenerator( +DATASET_GROUPBY_GENERATOR = GroupByReductionGenerator( cls="GroupBy", datastructure=DATASET_OBJECT, methods=REDUCTION_METHODS, docref="groupby", docref_description="groupby operations", example_call_preamble='.groupby("labels")', + definition_preamble=GROUPBY_PREAMBLE, ) -DATASET_RESAMPLE_GENERATOR = GenericReductionGenerator( +DATASET_RESAMPLE_GENERATOR = GroupByReductionGenerator( cls="Resample", datastructure=DATASET_OBJECT, methods=REDUCTION_METHODS, docref="resampling", docref_description="resampling operations", example_call_preamble='.resample(time="3M")', + definition_preamble=RESAMPLE_PREAMBLE, ) @@ -386,6 +489,7 @@ class DataStructure: p = Path(os.getcwd()) filepath = p.parent / "xarray" / "xarray" / "core" / "_reductions.py" + # filepath = p.parent / "core" / "_reductions.py" # Run from script location with open(filepath, mode="w", encoding="utf-8") as f: f.write(MODULE_PREAMBLE + "\n") for gen in [ diff --git a/xarray/util/print_versions.py b/xarray/util/print_versions.py index 561126ea05f..b8689e3a18f 100755 --- a/xarray/util/print_versions.py +++ b/xarray/util/print_versions.py @@ -122,6 +122,8 @@ def show_versions(file=sys.stdout): ("cupy", lambda mod: mod.__version__), ("pint", lambda mod: mod.__version__), ("sparse", lambda mod: mod.__version__), + ("flox", lambda mod: mod.__version__), + ("numpy_groupies", lambda mod: mod.__version__), # xarray setup/test ("setuptools", lambda mod: mod.__version__), ("pip", lambda mod: mod.__version__),