diff --git a/xarray/core/common.py b/xarray/core/common.py index 75518716870..626114a1f0f 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -13,6 +13,7 @@ Iterator, Mapping, TypeVar, + Union, overload, ) @@ -20,7 +21,7 @@ import pandas as pd from . import dtypes, duck_array_ops, formatting, formatting_html, ops -from .npcompat import DTypeLike +from .npcompat import DTypeLike, DTypeLikeSave from .options import OPTIONS, _get_keep_attrs from .pycompat import is_duck_dask_array from .rolling_exp import RollingExp @@ -1577,26 +1578,45 @@ def __getitem__(self, value): raise NotImplementedError() +DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]] + + @overload -def full_like( - other: Dataset, - fill_value, - dtype: DTypeLike | Mapping[Any, DTypeLike] = None, -) -> Dataset: +def full_like(other: DataArray, fill_value: Any, dtype: DTypeLikeSave) -> DataArray: + ... + + +@overload +def full_like(other: Dataset, fill_value: Any, dtype: DTypeMaybeMapping) -> Dataset: ... @overload -def full_like(other: DataArray, fill_value, dtype: DTypeLike = None) -> DataArray: +def full_like(other: Variable, fill_value: Any, dtype: DTypeLikeSave) -> Variable: ... @overload -def full_like(other: Variable, fill_value, dtype: DTypeLike = None) -> Variable: +def full_like( + other: Dataset | DataArray, fill_value: Any, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray: ... -def full_like(other, fill_value, dtype=None): +@overload +def full_like( + other: Dataset | DataArray | Variable, + fill_value: Any, + dtype: DTypeMaybeMapping = None, +) -> Dataset | DataArray | Variable: + ... + + +def full_like( + other: Dataset | DataArray | Variable, + fill_value: Any, + dtype: DTypeMaybeMapping = None, +) -> Dataset | DataArray | Variable: """Return a new object with the same shape and type as a given object. Parameters @@ -1711,26 +1731,26 @@ def full_like(other, fill_value, dtype=None): f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead." ) - if not isinstance(other, Dataset) and isinstance(dtype, Mapping): - raise ValueError( - "'dtype' cannot be dict-like when passing a DataArray or Variable" - ) - if isinstance(other, Dataset): if not isinstance(fill_value, dict): fill_value = {k: fill_value for k in other.data_vars.keys()} + dtype_: Mapping[Any, DTypeLikeSave] if not isinstance(dtype, Mapping): dtype_ = {k: dtype for k in other.data_vars.keys()} else: dtype_ = dtype data_vars = { - k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype_.get(k, None)) + k: _full_like_variable( + v.variable, fill_value.get(k, dtypes.NA), dtype_.get(k, None) + ) for k, v in other.data_vars.items() } return Dataset(data_vars, coords=other.coords, attrs=other.attrs) elif isinstance(other, DataArray): + if isinstance(dtype, Mapping): + raise ValueError("'dtype' cannot be dict-like when passing a DataArray") return DataArray( _full_like_variable(other.variable, fill_value, dtype), dims=other.dims, @@ -1739,12 +1759,16 @@ def full_like(other, fill_value, dtype=None): name=other.name, ) elif isinstance(other, Variable): + if isinstance(dtype, Mapping): + raise ValueError("'dtype' cannot be dict-like when passing a Variable") return _full_like_variable(other, fill_value, dtype) else: raise TypeError("Expected DataArray, Dataset, or Variable") -def _full_like_variable(other, fill_value, dtype: DTypeLike = None): +def _full_like_variable( + other: Variable, fill_value: Any, dtype: DTypeLike = None +) -> Variable: """Inner function of full_like, where other must be a variable""" from .variable import Variable @@ -1765,7 +1789,38 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None): return Variable(dims=other.dims, data=data, attrs=other.attrs) -def zeros_like(other, dtype: DTypeLike = None): +@overload +def zeros_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray: + ... + + +@overload +def zeros_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset: + ... + + +@overload +def zeros_like(other: Variable, dtype: DTypeLikeSave) -> Variable: + ... + + +@overload +def zeros_like( + other: Dataset | DataArray, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray: + ... + + +@overload +def zeros_like( + other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray | Variable: + ... + + +def zeros_like( + other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray | Variable: """Return a new object of zeros with the same shape and type as a given dataarray or dataset. @@ -1821,7 +1876,38 @@ def zeros_like(other, dtype: DTypeLike = None): return full_like(other, 0, dtype) -def ones_like(other, dtype: DTypeLike = None): +@overload +def ones_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray: + ... + + +@overload +def ones_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset: + ... + + +@overload +def ones_like(other: Variable, dtype: DTypeLikeSave) -> Variable: + ... + + +@overload +def ones_like( + other: Dataset | DataArray, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray: + ... + + +@overload +def ones_like( + other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray | Variable: + ... + + +def ones_like( + other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None +) -> Dataset | DataArray | Variable: """Return a new object of ones with the same shape and type as a given dataarray or dataset. diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 81b5e3fd915..8bd103af558 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -1905,7 +1905,7 @@ def polyval( coeffs = coeffs.reindex( {degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False ) - coord = _ensure_numeric(coord) # type: ignore # https://github.com/python/mypy/issues/1533 ? + coord = _ensure_numeric(coord) # using Horner's method # https://en.wikipedia.org/wiki/Horner%27s_method @@ -1917,7 +1917,7 @@ def polyval( return res -def _ensure_numeric(data: T_Xarray) -> T_Xarray: +def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray: """Converts all datetime64 variables to float64 Parameters diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 3d33631bebd..2e869dbe675 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import datetime as dt import warnings from functools import partial from numbers import Number -from typing import Any, Callable, Dict, Hashable, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence import numpy as np import pandas as pd @@ -17,8 +19,14 @@ from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables +if TYPE_CHECKING: + from .dataarray import DataArray + from .dataset import Dataset + -def _get_nan_block_lengths(obj, dim: Hashable, index: Variable): +def _get_nan_block_lengths( + obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable +): """ Return an object where each NaN element in 'obj' is replaced by the length of the gap the element is in. @@ -48,8 +56,8 @@ def _get_nan_block_lengths(obj, dim: Hashable, index: Variable): class BaseInterpolator: """Generic interpolator class for normalizing interpolation methods""" - cons_kwargs: Dict[str, Any] - call_kwargs: Dict[str, Any] + cons_kwargs: dict[str, Any] + call_kwargs: dict[str, Any] f: Callable method: str @@ -213,7 +221,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs): def get_clean_interp_index( - arr, dim: Hashable, use_coordinate: Union[str, bool] = True, strict: bool = True + arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True ): """Return index to use for x values in interpolation or curve fitting. @@ -300,10 +308,10 @@ def get_clean_interp_index( def interp_na( self, dim: Hashable = None, - use_coordinate: Union[bool, str] = True, + use_coordinate: bool | str = True, method: str = "linear", limit: int = None, - max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None, + max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None, keep_attrs: bool = None, **kwargs, ): diff --git a/xarray/core/npcompat.py b/xarray/core/npcompat.py index b5b98052fe9..85a8f88aba6 100644 --- a/xarray/core/npcompat.py +++ b/xarray/core/npcompat.py @@ -28,7 +28,17 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + List, + Literal, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) import numpy as np from packaging.version import Version @@ -36,6 +46,31 @@ # Type annotations stubs try: from numpy.typing import ArrayLike, DTypeLike + from numpy.typing._dtype_like import _DTypeLikeNested, _ShapeLike, _SupportsDType + + # Xarray requires a Mapping[Hashable, dtype] in many places which + # conflics with numpys own DTypeLike (with dtypes for fields). + # https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike + # This is a copy of this DTypeLike that allows only non-Mapping dtypes. + DTypeLikeSave = Union[ + np.dtype, + # default data type (float64) + None, + # array-scalar types and generic types + Type[Any], + # character codes, type strings or comma-separated fields, e.g., 'float64' + str, + # (flexible_dtype, itemsize) + Tuple[_DTypeLikeNested, int], + # (fixed_dtype, shape) + Tuple[_DTypeLikeNested, _ShapeLike], + # (base_dtype, new_dtype) + Tuple[_DTypeLikeNested, _DTypeLikeNested], + # because numpy does the same? + List[Any], + # anything with a dtype attribute + _SupportsDType[np.dtype], + ] except ImportError: # fall back for numpy < 1.20, ArrayLike adapted from numpy.typing._array_like from typing import Protocol @@ -46,8 +81,14 @@ class _SupportsArray(Protocol): def __array__(self) -> np.ndarray: ... + class _SupportsDTypeFallback(Protocol): + @property + def dtype(self) -> np.dtype: + ... + else: _SupportsArray = Any + _SupportsDTypeFallback = Any _T = TypeVar("_T") _NestedSequence = Union[ @@ -72,7 +113,16 @@ def __array__(self) -> np.ndarray: # with the same name (ArrayLike and DTypeLike from the try block) ArrayLike = _ArrayLikeFallback # type: ignore # fall back for numpy < 1.20 - DTypeLike = Union[np.dtype, str] # type: ignore[misc] + DTypeLikeSave = Union[ # type: ignore[misc] + np.dtype, + str, + None, + Type[Any], + Tuple[Any, Any], + List[Any], + _SupportsDTypeFallback, + ] + DTypeLike = DTypeLikeSave # type: ignore[misc] if Version(np.__version__) >= Version("1.20.0"): diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 263237d9d30..950f15e91df 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5559,7 +5559,7 @@ def test_binary_op_join_setting(self): actual = ds1 + ds2 assert_equal(actual, expected) - def test_full_like(self): + def test_full_like(self) -> None: # For more thorough tests, see test_variable.py # Note: testing data_vars with mismatched dtypes ds = Dataset( @@ -5572,8 +5572,9 @@ def test_full_like(self): actual = full_like(ds, 2) expected = ds.copy(deep=True) - expected["d1"].values = [2, 2, 2] - expected["d2"].values = [2.0, 2.0, 2.0] + # https://github.com/python/mypy/issues/3004 + expected["d1"].values = [2, 2, 2] # type: ignore + expected["d2"].values = [2.0, 2.0, 2.0] # type: ignore assert expected["d1"].dtype == int assert expected["d2"].dtype == float assert_identical(expected, actual) @@ -5581,8 +5582,8 @@ def test_full_like(self): # override dtype actual = full_like(ds, fill_value=True, dtype=bool) expected = ds.copy(deep=True) - expected["d1"].values = [True, True, True] - expected["d2"].values = [True, True, True] + expected["d1"].values = [True, True, True] # type: ignore + expected["d2"].values = [True, True, True] # type: ignore assert expected["d1"].dtype == bool assert expected["d2"].dtype == bool assert_identical(expected, actual) @@ -5788,7 +5789,7 @@ def test_ipython_key_completion(self): ds.data_vars[item] # should not raise assert sorted(actual) == sorted(expected) - def test_polyfit_output(self): + def test_polyfit_output(self) -> None: ds = create_test_data(seed=1) out = ds.polyfit("dim2", 2, full=False) @@ -5801,7 +5802,7 @@ def test_polyfit_output(self): out = ds.polyfit("time", 2) assert len(out.data_vars) == 0 - def test_polyfit_warnings(self): + def test_polyfit_warnings(self) -> None: ds = create_test_data(seed=1) with warnings.catch_warnings(record=True) as ws: diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 0168f19b921..886b0360c04 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -2480,7 +2480,7 @@ def test_datetime(self): assert np.ndarray == type(actual) assert np.dtype("datetime64[ns]") == actual.dtype - def test_full_like(self): + def test_full_like(self) -> None: # For more thorough tests, see test_variable.py orig = Variable( dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} @@ -2503,7 +2503,7 @@ def test_full_like(self): full_like(orig, True, dtype={"x": bool}) @requires_dask - def test_full_like_dask(self): + def test_full_like_dask(self) -> None: orig = Variable( dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} ).chunk(((1, 1), (2,))) @@ -2534,14 +2534,14 @@ def check(actual, expect_dtype, expect_values): else: assert not isinstance(v, np.ndarray) - def test_zeros_like(self): + def test_zeros_like(self) -> None: orig = Variable( dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} ) assert_identical(zeros_like(orig), full_like(orig, 0)) assert_identical(zeros_like(orig, dtype=int), full_like(orig, 0, dtype=int)) - def test_ones_like(self): + def test_ones_like(self) -> None: orig = Variable( dims=("x", "y"), data=[[1.5, 2.0], [3.1, 4.3]], attrs={"foo": "bar"} )