-
Notifications
You must be signed in to change notification settings - Fork 0
/
_util.py
104 lines (87 loc) · 3.27 KB
/
_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import numpy as np
def _normalize_axis(
axis: None | int | tuple[int, ...],
ndim: int,
) -> tuple[int, ...]:
if axis is None:
axis = tuple(range(ndim))
axis = np.core.numeric.normalize_axis_tuple(axis, ndim=ndim)
axis = tuple(~(~np.array(axis) % ndim))
return axis
def _normalize_input_output_coordinates(
coordinates_input: tuple[np.ndarray, ...],
coordinates_output: tuple[np.ndarray, ...],
axis_input: None | int | tuple[int, ...] = None,
axis_output: None | int | tuple[int, ...] = None,
) -> tuple[
tuple[np.ndarray, ...],
tuple[np.ndarray, ...],
tuple[int, ...],
tuple[int, ...],
tuple[int, ...],
tuple[int, ...],
tuple[int, ...],
]:
shape_coordinates_input = np.broadcast(*coordinates_input).shape
shape_coordinates_output = np.broadcast(*coordinates_output).shape
ndim_input = len(shape_coordinates_input)
ndim_output = len(shape_coordinates_output)
axis_input = _normalize_axis(axis_input, ndim=ndim_input)
axis_output = _normalize_axis(axis_output, ndim=ndim_output)
axis_input = sorted(axis_input, reverse=True)
axis_output = sorted(axis_output, reverse=True)
if len(axis_output) != len(axis_input):
raise ValueError(
f"The number of axes in `axis_output`, {axis_output}, "
f"must match the number of axes in `axis_input`, {axis_input}"
)
if len(coordinates_input) != len(axis_input):
raise ValueError(
f"The number of elements in `coordinates_input`, {len(coordinates_input)}, "
f"should match the number of axes in `axis_input`, {axis_input}"
)
if len(coordinates_output) != len(coordinates_input):
raise ValueError(
f"The number of elements in `coordinates_output`, {len(coordinates_output)}, "
f"should match the number of elements in `coordinates_input`, {len(coordinates_input)}"
)
axis_input_orthogonal = tuple(
ax for ax in _normalize_axis(None, ndim_input) if ax not in axis_input
)
axis_output_orthogonal = tuple(
ax for ax in _normalize_axis(None, ndim_output) if ax not in axis_output
)
shape_input_orthogonal = tuple(
shape_coordinates_input[ax] for ax in axis_input_orthogonal
)
shape_output_orthogonal = tuple(
shape_coordinates_output[ax] for ax in axis_output_orthogonal
)
shape_orthogonal = np.broadcast_shapes(
shape_input_orthogonal, shape_output_orthogonal
)
shape_input = list(shape_orthogonal)
for ax in reversed(axis_input):
ax = ax % ndim_input
shape_input.insert(ax, shape_coordinates_input[ax])
shape_input = tuple(shape_input)
shape_output = list(shape_orthogonal)
for ax in reversed(axis_output):
ax = ax % ndim_input
shape_output.insert(ax, shape_coordinates_output[ax])
shape_output = tuple(shape_output)
coordinates_input = tuple(
np.broadcast_to(coord, shape_input) for coord in coordinates_input
)
coordinates_output = tuple(
np.broadcast_to(coord, shape_output) for coord in coordinates_output
)
return (
coordinates_input,
coordinates_output,
axis_input,
axis_output,
shape_input,
shape_output,
shape_orthogonal,
)