Source code for megnet.utils.general
"""
Utility operations on lists and arrays.
"""
from collections.abc import Iterable
from typing import Union, List

import numpy as np
def to_list(x: Union[Iterable, np.ndarray]) -> List:
    """
    Convert x to a list.

    NumPy arrays are converted with ``ndarray.tolist``: elements become native
    Python scalars and multi-dimensional arrays become nested lists. A 0-d
    array becomes a one-element list. Any other iterable is materialised
    with ``list(x)``; note that this splits a string into its characters.
    Non-iterable objects are wrapped as ``[x]``.

    Args:
        x: object to convert

    Returns:
        (list) list representation of x
    """
    # ndarray defines __iter__ and therefore also satisfies the Iterable
    # check below; test it first so arrays go through tolist().
    if isinstance(x, np.ndarray):
        # atleast_1d reshapes a 0-d array to shape (1,), so tolist() returns a
        # list rather than a bare scalar (list() on a 0-d array would raise).
        return np.atleast_1d(x).tolist()
    if isinstance(x, Iterable):
        return list(x)
    return [x]
def expand_1st(x: np.ndarray) -> np.ndarray:
    """
    Prepend a new leading axis of length one to x.

    Args:
        x: (np.array) input array

    Returns:
        (np.array) array with shape ``(1,) + x.shape``
    """
    leading_axis = 0
    expanded = np.expand_dims(x, leading_axis)
    return expanded
def fast_label_binarize(value: List, labels: List) -> List[int]:
    """
    Faster version of label binarize.

    scikit-learn's `label_binarize` is slow when called on one label at a
    time; it is designed for large numbers of classes, which are uncommon in
    `megnet`.

    With exactly two labels a single indicator is returned. It is 1 when
    value equals ``labels[0]``; note that scikit-learn flags the second class
    instead. Otherwise a one-hot list of length ``len(labels)`` is returned.
    It is all zeros when value is not among labels.

    Args:
        value: Value to encode
        labels (list): Possible class values

    Returns:
        ([int]): List of integers
    """
    n_labels = len(labels)

    # Binary problems collapse to a single indicator column.
    if n_labels == 2:
        return [int(value == labels[0])]

    encoded = [0] * n_labels
    if value not in labels:
        # Unknown value: leave every position unset.
        return encoded
    encoded[labels.index(value)] = 1
    return encoded