Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

{full,zeros,ones}_like typing #6611

Merged
merged 9 commits into from
May 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 104 additions & 18 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
Iterator,
Mapping,
TypeVar,
Union,
overload,
)

import numpy as np
import pandas as pd

from . import dtypes, duck_array_ops, formatting, formatting_html, ops
from .npcompat import DTypeLike
from .npcompat import DTypeLike, DTypeLikeSave
from .options import OPTIONS, _get_keep_attrs
from .pycompat import is_duck_dask_array
from .rolling_exp import RollingExp
Expand Down Expand Up @@ -1577,26 +1578,45 @@ def __getitem__(self, value):
raise NotImplementedError()


DTypeMaybeMapping = Union[DTypeLikeSave, Mapping[Any, DTypeLikeSave]]


@overload
def full_like(
other: Dataset,
fill_value,
dtype: DTypeLike | Mapping[Any, DTypeLike] = None,
) -> Dataset:
def full_like(other: DataArray, fill_value: Any, dtype: DTypeLikeSave) -> DataArray:
...


@overload
def full_like(other: Dataset, fill_value: Any, dtype: DTypeMaybeMapping) -> Dataset:
...


@overload
def full_like(other: DataArray, fill_value, dtype: DTypeLike = None) -> DataArray:
def full_like(other: Variable, fill_value: Any, dtype: DTypeLikeSave) -> Variable:
...


@overload
def full_like(other: Variable, fill_value, dtype: DTypeLike = None) -> Variable:
def full_like(
other: Dataset | DataArray, fill_value: Any, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray:
...


def full_like(other, fill_value, dtype=None):
@overload
def full_like(
other: Dataset | DataArray | Variable,
fill_value: Any,
dtype: DTypeMaybeMapping = None,
) -> Dataset | DataArray | Variable:
...


def full_like(
other: Dataset | DataArray | Variable,
fill_value: Any,
dtype: DTypeMaybeMapping = None,
) -> Dataset | DataArray | Variable:
"""Return a new object with the same shape and type as a given object.

Parameters
Expand Down Expand Up @@ -1711,26 +1731,26 @@ def full_like(other, fill_value, dtype=None):
f"fill_value must be scalar or, for datasets, a dict-like. Received {fill_value} instead."
)

if not isinstance(other, Dataset) and isinstance(dtype, Mapping):
raise ValueError(
"'dtype' cannot be dict-like when passing a DataArray or Variable"
)

if isinstance(other, Dataset):
if not isinstance(fill_value, dict):
fill_value = {k: fill_value for k in other.data_vars.keys()}

dtype_: Mapping[Any, DTypeLikeSave]
if not isinstance(dtype, Mapping):
dtype_ = {k: dtype for k in other.data_vars.keys()}
else:
dtype_ = dtype

data_vars = {
k: _full_like_variable(v, fill_value.get(k, dtypes.NA), dtype_.get(k, None))
k: _full_like_variable(
v.variable, fill_value.get(k, dtypes.NA), dtype_.get(k, None)
)
for k, v in other.data_vars.items()
}
return Dataset(data_vars, coords=other.coords, attrs=other.attrs)
elif isinstance(other, DataArray):
if isinstance(dtype, Mapping):
raise ValueError("'dtype' cannot be dict-like when passing a DataArray")
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
return DataArray(
_full_like_variable(other.variable, fill_value, dtype),
dims=other.dims,
Expand All @@ -1739,12 +1759,16 @@ def full_like(other, fill_value, dtype=None):
name=other.name,
)
elif isinstance(other, Variable):
if isinstance(dtype, Mapping):
raise ValueError("'dtype' cannot be dict-like when passing a Variable")
return _full_like_variable(other, fill_value, dtype)
else:
raise TypeError("Expected DataArray, Dataset, or Variable")


def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
def _full_like_variable(
other: Variable, fill_value: Any, dtype: DTypeLike = None
) -> Variable:
"""Inner function of full_like, where other must be a variable"""
from .variable import Variable

Expand All @@ -1765,7 +1789,38 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
return Variable(dims=other.dims, data=data, attrs=other.attrs)


def zeros_like(other, dtype: DTypeLike = None):
@overload
def zeros_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray:
...


@overload
def zeros_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset:
...


@overload
def zeros_like(other: Variable, dtype: DTypeLikeSave) -> Variable:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For these ones which are generic over DataArray and Variable, we could define a TypeVar for them (but no need in this PR, and I'm not even sure whether it's a broader issue than just the *_like.

...


@overload
def zeros_like(
other: Dataset | DataArray, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray:
Comment on lines +1807 to +1810
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why this overload is required? I had thought that DataArray couldn't take a DTypeMaybeMapping? (and same principle for the next two)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to adjust these locally. An issue I hit is that so much of our typing is on Dataset rather than T_Dataset and changing one seems to require changing its dependencies, which makes it difficult to change gradually. 20 minutes into me trying to change this, I had 31 errors in 6 files!

To the extent you know which parts are Perfect vs Not-perfect-but-required-to-pass: If you want to add a comment for the ones the latter, that will make it easier for future travelers to know why things are as they are and hopefully change them.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was struggling a lot with TypeVar definitions.
Some other functions (like polyval) are calling zeros_like with DataArray | Dataset.

This means that a "simple" TypeVar["T", DataArray, Dataset] will not work.

I tried using TypeVar["T", bound=DataArray|Dataset] (recommended by some mypy people) but then the if isinstance(x Dataset) were causing problems (still not sure if that is a mypy bug or intended, my TypeVar knowledge is not good enough for that)...

So this was the only solution I could get to work.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additionally, the code will actually create a plain e.g. DataArray, so typevars with bounds are actually wrong here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I think the proposed code is a big upgrade, and we can refine towards perfection in the future...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I found this SO answer helpful in clarifying the difference — i saw that I had upvoted it before — but I'm still not confident in how we should design these methods.

...


@overload
def zeros_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
...


def zeros_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
"""Return a new object of zeros with the same shape and
type as a given dataarray or dataset.

Expand Down Expand Up @@ -1821,7 +1876,38 @@ def zeros_like(other, dtype: DTypeLike = None):
return full_like(other, 0, dtype)


def ones_like(other, dtype: DTypeLike = None):
@overload
def ones_like(other: DataArray, dtype: DTypeLikeSave) -> DataArray:
...


@overload
def ones_like(other: Dataset, dtype: DTypeMaybeMapping) -> Dataset:
...


@overload
def ones_like(other: Variable, dtype: DTypeLikeSave) -> Variable:
...


@overload
def ones_like(
other: Dataset | DataArray, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray:
...


@overload
def ones_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
...


def ones_like(
other: Dataset | DataArray | Variable, dtype: DTypeMaybeMapping = None
) -> Dataset | DataArray | Variable:
"""Return a new object of ones with the same shape and
type as a given dataarray or dataset.

Expand Down
4 changes: 2 additions & 2 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1905,7 +1905,7 @@ def polyval(
coeffs = coeffs.reindex(
{degree_dim: np.arange(max_deg + 1)}, fill_value=0, copy=False
)
coord = _ensure_numeric(coord) # type: ignore # https://github.com/python/mypy/issues/1533 ?
coord = _ensure_numeric(coord)

# using Horner's method
# https://en.wikipedia.org/wiki/Horner%27s_method
Expand All @@ -1917,7 +1917,7 @@ def polyval(
return res


def _ensure_numeric(data: T_Xarray) -> T_Xarray:
def _ensure_numeric(data: Dataset | DataArray) -> Dataset | DataArray:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the kind of func that would be nice at some point to make generic; with the proposed code we lose whether it's a Dataset vs. DataArray. (fine to add as a comment / TODO tho)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I failed to make it work with Typevars since it is called with DataArray|Dataset :(

"""Converts all datetime64 variables to float64

Parameters
Expand Down
22 changes: 15 additions & 7 deletions xarray/core/missing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import datetime as dt
import warnings
from functools import partial
from numbers import Number
from typing import Any, Callable, Dict, Hashable, Sequence, Union
from typing import TYPE_CHECKING, Any, Callable, Hashable, Sequence

import numpy as np
import pandas as pd
Expand All @@ -17,8 +19,14 @@
from .utils import OrderedSet, is_scalar
from .variable import Variable, broadcast_variables

if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset


def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
def _get_nan_block_lengths(
obj: Dataset | DataArray | Variable, dim: Hashable, index: Variable
):
"""
Return an object where each NaN element in 'obj' is replaced by the
length of the gap the element is in.
Expand Down Expand Up @@ -48,8 +56,8 @@ def _get_nan_block_lengths(obj, dim: Hashable, index: Variable):
class BaseInterpolator:
"""Generic interpolator class for normalizing interpolation methods"""

cons_kwargs: Dict[str, Any]
call_kwargs: Dict[str, Any]
cons_kwargs: dict[str, Any]
call_kwargs: dict[str, Any]
f: Callable
method: str

Expand Down Expand Up @@ -213,7 +221,7 @@ def _apply_over_vars_with_dim(func, self, dim=None, **kwargs):


def get_clean_interp_index(
arr, dim: Hashable, use_coordinate: Union[str, bool] = True, strict: bool = True
arr, dim: Hashable, use_coordinate: str | bool = True, strict: bool = True
):
"""Return index to use for x values in interpolation or curve fitting.

Expand Down Expand Up @@ -300,10 +308,10 @@ def get_clean_interp_index(
def interp_na(
self,
dim: Hashable = None,
use_coordinate: Union[bool, str] = True,
use_coordinate: bool | str = True,
method: str = "linear",
limit: int = None,
max_gap: Union[int, float, str, pd.Timedelta, np.timedelta64, dt.timedelta] = None,
max_gap: int | float | str | pd.Timedelta | np.timedelta64 | dt.timedelta = None,
keep_attrs: bool = None,
**kwargs,
):
Expand Down
54 changes: 52 additions & 2 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,49 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import TYPE_CHECKING, Any, Literal, Sequence, TypeVar, Union
from typing import (
TYPE_CHECKING,
Any,
List,
Literal,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

import numpy as np
from packaging.version import Version

# Type annotations stubs
try:
from numpy.typing import ArrayLike, DTypeLike
from numpy.typing._dtype_like import _DTypeLikeNested, _ShapeLike, _SupportsDType

# Xarray requires a Mapping[Hashable, dtype] in many places which
# conflics with numpys own DTypeLike (with dtypes for fields).
# https://numpy.org/devdocs/reference/typing.html#numpy.typing.DTypeLike
# This is a copy of this DTypeLike that allows only non-Mapping dtypes.
headtr1ck marked this conversation as resolved.
Show resolved Hide resolved
DTypeLikeSave = Union[
np.dtype,
# default data type (float64)
None,
# array-scalar types and generic types
Type[Any],
# character codes, type strings or comma-separated fields, e.g., 'float64'
str,
# (flexible_dtype, itemsize)
Tuple[_DTypeLikeNested, int],
# (fixed_dtype, shape)
Tuple[_DTypeLikeNested, _ShapeLike],
# (base_dtype, new_dtype)
Tuple[_DTypeLikeNested, _DTypeLikeNested],
# because numpy does the same?
List[Any],
# anything with a dtype attribute
_SupportsDType[np.dtype],
]
except ImportError:
# fall back for numpy < 1.20, ArrayLike adapted from numpy.typing._array_like
from typing import Protocol
Expand All @@ -46,8 +81,14 @@ class _SupportsArray(Protocol):
def __array__(self) -> np.ndarray:
...

class _SupportsDTypeFallback(Protocol):
@property
def dtype(self) -> np.dtype:
...

else:
_SupportsArray = Any
_SupportsDTypeFallback = Any

_T = TypeVar("_T")
_NestedSequence = Union[
Expand All @@ -72,7 +113,16 @@ def __array__(self) -> np.ndarray:
# with the same name (ArrayLike and DTypeLike from the try block)
ArrayLike = _ArrayLikeFallback # type: ignore
# fall back for numpy < 1.20
DTypeLike = Union[np.dtype, str] # type: ignore[misc]
DTypeLikeSave = Union[ # type: ignore[misc]
np.dtype,
str,
None,
Type[Any],
Tuple[Any, Any],
List[Any],
_SupportsDTypeFallback,
]
DTypeLike = DTypeLikeSave # type: ignore[misc]


if Version(np.__version__) >= Version("1.20.0"):
Expand Down
Loading