Source code for megnet.utils.data

"""
Data utilities
"""
from typing import Tuple

import numpy as np
from pymatgen.optimization.neighbors import find_points_in_spheres
from pymatgen import Structure, Molecule

from megnet.utils.typing import StructureOrMolecule


def get_graphs_within_cutoff(
        structure: StructureOrMolecule,
        cutoff: float = 5.0,
        numerical_tol: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Find all pairs of sites in ``structure`` that lie within ``cutoff`` of
    each other and return them as parallel arrays describing a graph.

    Args:
        structure (pymatgen Structure or Molecule): object whose sites are
            used both as sphere centers and as candidate neighbors
        cutoff (float): radius passed to the neighbor search
        numerical_tol (float): tolerance forwarded to the neighbor search and
            also used to decide whether a site paired with itself is a
            zero-distance self match

    Returns:
        center_indices, neighbor_indices, images, distances, exactly as
        produced by ``find_points_in_spheres`` but with zero-distance
        self-pairs removed

    Raises:
        ValueError: if ``structure`` is neither a Structure nor a Molecule
    """
    if isinstance(structure, Structure):
        # Use the structure's own lattice, periodic along all three axes.
        cell = np.ascontiguousarray(
            np.array(structure.lattice.matrix), dtype=float)
        periodic_flags = np.array([1, 1, 1], dtype=int)
    elif isinstance(structure, Molecule):
        # A molecule has no lattice: hand the search a 1000 x 1000 x 1000
        # cubic cell and switch periodicity off on every axis.
        cell = np.eye(3, dtype=float) * 1000.0
        periodic_flags = np.array([0, 0, 0], dtype=int)
    else:
        raise ValueError('structure type not supported')

    radius = float(cutoff)
    coords = np.ascontiguousarray(
        np.array(structure.cart_coords), dtype=float)

    # The same coordinate array serves as both the centers and the points
    # being searched, so every site is also tested against itself.
    centers, neighbors, images, dists = find_points_in_spheres(
        coords, coords, r=radius, pbc=periodic_flags,
        lattice=cell, tol=numerical_tol)

    # Keep a pair when it joins two different sites, or when it joins a site
    # to itself at a distance above the tolerance. Only self-pairs at
    # (numerically) zero distance are dropped.
    keep = np.logical_or(centers != neighbors, dists > numerical_tol)
    return tuple(arr[keep] for arr in (centers, neighbors, images, dists))