Source code for megnet.utils.general

"""
Operation utilities on lists and arrays
"""
from collections.abc import Iterable
from typing import Union, List

import numpy as np


def to_list(x: Union[Iterable, np.ndarray]) -> List:
    """
    Convert x to a Python list.

    Conversion rules, checked in this order:

    * ``np.ndarray``: converted with ``ndarray.tolist()``, so elements become
      native Python scalars (nested lists when ``x.ndim > 1``). A 0-d array
      is wrapped into a one-element list.
    * any other iterable (list, tuple, dict, generator, str, ...):
      ``list(x)``. NOTE: strings are iterable, so ``"ab"`` becomes
      ``["a", "b"]``.
    * anything else (scalars, None, ...): wrapped as ``[x]``.

    Args:
        x: value to convert
    Returns:
        (list) list representation of x
    """
    # ndarray must be tested before Iterable: ndarray defines __iter__, so it
    # also satisfies isinstance(x, Iterable) and would otherwise never reach
    # the tolist() conversion.
    if isinstance(x, np.ndarray):
        if x.ndim == 0:
            # tolist() on a 0-d array returns a bare scalar, not a list, and
            # list(x) would raise "iteration over a 0-d array".
            return [x.tolist()]
        return x.tolist()
    if isinstance(x, Iterable):
        return list(x)
    return [x]


def expand_1st(x: np.ndarray) -> np.ndarray:
    """
    Prepend a new leading axis of length one.

    Args:
        x: (np.array) input array, or anything ``np.expand_dims`` accepts
    Returns:
        (np.array) array with shape ``(1, *x.shape)``
    """
    first_axis = 0
    expanded = np.expand_dims(x, axis=first_axis)
    return expanded


def fast_label_binarize(value: object, labels: List) -> List[int]:
    """Faster version of label binarize

    `label_binarize` from scikit-learn is slow when run 1 label at a time.
    `label_binarize` is also optimized for large numbers of classes, which
    are not common in `megnet`.

    Encoding rules:

    * exactly two labels: a single-element list, ``[1]`` if ``value`` equals
      ``labels[0]`` and ``[0]`` otherwise (including when ``value`` is not a
      label at all). NOTE(review): this flags ``labels[0]``, whereas
      scikit-learn's binary ``label_binarize`` flags the second class;
      presumably kept for compatibility with existing encoded features —
      confirm before changing.
    * any other number of labels: a one-hot list of length ``len(labels)``
      with a 1 at the first index where ``labels`` equals ``value``, or all
      zeros if ``value`` is not in ``labels``.

    Args:
        value: Value to encode (a single element, compared with ``==``)
        labels (list): Possible class values
    Returns:
        ([int]): List of integers
    """
    if len(labels) == 2:
        return [int(value == labels[0])]
    output = [0] * len(labels)
    # Membership test first so an unknown value yields an all-zero vector
    # instead of the ValueError that list.index would raise.
    if value in labels:
        output[labels.index(value)] = 1
    return output