"""Abstract classes and utility operations for building graph representations and
data loaders (known as Sequence objects in Keras).
Most users will not need to interact with this module."""
from abc import abstractmethod
from operator import itemgetter
from tensorflow.keras.utils import Sequence
import numpy as np
from megnet.utils.general import expand_1st, to_list
from megnet.utils.data import get_graphs_within_cutoff
from monty.json import MSONable
from megnet.data import local_env
from inspect import signature
from pymatgen.analysis.local_env import NearNeighbors
from pymatgen import Structure
from typing import Union, Dict, List, Any
class Converter(MSONable):
    """
    Base class for atom or bond converters.

    Concrete converters override :meth:`convert`; this base implementation
    only signals that the method is missing.
    """

    def convert(self, d: Any) -> Any:
        """
        Convert the given data. Must be overridden by subclasses.

        Args:
            d: (Any) data to convert
        Returns:
            (Any) converted data
        Raises:
            NotImplementedError: always, in this base class
        """
        raise NotImplementedError()
class StructureGraph(MSONable):
    """
    This is a base class for converting structures into graphs or model inputs.

    Methods to be implemented are as follows:
        1. convert(self, structure)
            This is to convert a structure into a graph dictionary
        2. get_input(self, structure)
            This method converts a structure directly to a model input
        3. get_flat_data(self, graphs, targets)
            This method processes graphs and targets pairs and outputs a model input list.
    """

    # TODO (wardlt): Consider making "num_*_features" funcs to simplify making a MEGNet model
    def __init__(self,
                 nn_strategy: Union[str, NearNeighbors] = None,
                 atom_converter: Converter = None,
                 bond_converter: Converter = None,
                 **kwargs):
        """
        Args:
            nn_strategy: (str or NearNeighbors) name of a strategy looked up
                through ``local_env.get``, an already constructed
                NearNeighbors object, or None
            atom_converter: (Converter) atom feature converter; a
                DummyConverter is used when None
            bond_converter: (Converter) bond feature converter; a
                DummyConverter is used when None
            **kwargs: keyword arguments; those matching a parameter name of
                the named strategy are forwarded to it and also stored as
                attributes on this object
        Raises:
            RuntimeError: if nn_strategy is not a str, NearNeighbors or None
        """
        if isinstance(nn_strategy, str):
            strategy = local_env.get(nn_strategy)
            parameters = signature(strategy).parameters
            # Start from the strategy's own defaults, then overlay matching kwargs
            param_dict = {i: j.default for i, j in parameters.items()}
            for i, j in kwargs.items():
                if i in param_dict:
                    setattr(self, i, j)
                    param_dict.update({i: j})
            self.nn_strategy = strategy(**param_dict)
        elif isinstance(nn_strategy, NearNeighbors):
            self.nn_strategy = nn_strategy
        elif nn_strategy is None:
            self.nn_strategy = None
        else:
            raise RuntimeError("Strategy not valid")

        self.atom_converter = atom_converter
        self.bond_converter = bond_converter
        if self.atom_converter is None:
            self.atom_converter = self._get_dummy_converter()
        if self.bond_converter is None:
            self.bond_converter = self._get_dummy_converter()

    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
        """
        Take a pymatgen structure and convert it to a index-type graph representation
        The graph will have node, distance, index1, index2, where node is a vector of Z number
        of atoms in the structure, index1 and index2 mark the atom indices forming the bond and separated by
        distance.
        For state attributes, you can set structure.state = [[xx, xx]] beforehand or the algorithm would
        take default [[0, 0]]

        Args:
            state_attributes: (list) state attributes
            structure: (pymatgen structure)
        Returns:
            (dictionary) with keys 'atom', 'bond', 'state', 'index1', 'index2'
        Raises:
            RuntimeError: if no neighbor strategy is set, or if some atom
                never appears as the first index of a bond
        """
        state_attributes = state_attributes or getattr(structure, 'state', None) or [[0, 0]]
        index1 = []
        index2 = []
        bonds = []
        if self.nn_strategy is None:
            raise RuntimeError("NearNeighbor strategy is not provided!")
        for n, neighbors in enumerate(self.nn_strategy.get_all_nn_info(structure)):
            index1.extend([n] * len(neighbors))
            for neighbor in neighbors:
                index2.append(neighbor['site_index'])
                bonds.append(neighbor['weight'])
        atoms = self.get_atom_features(structure)
        # Every atom must start at least one bond, otherwise it is isolated
        if np.size(np.unique(index1)) < len(atoms):
            raise RuntimeError("Isolated atoms found in the structure")
        return {'atom': atoms,
                'bond': bonds,
                'state': state_attributes,
                'index1': index1,
                'index2': index2
                }

    @staticmethod
    def get_atom_features(structure) -> List[int]:
        """
        Get atom features from structure, may be overwritten

        Args:
            structure: (Pymatgen.Structure) pymatgen structure
        Returns:
            List of atomic numbers
        """
        return np.array([i.specie.Z for i in structure],
                        dtype='int32').tolist()

    def __call__(self, structure: Structure) -> Dict:
        """Shorthand for ``self.convert(structure)``."""
        return self.convert(structure)

    def get_flat_data(self, graphs: List[Dict], targets: List = None) -> tuple:
        """
        Expand the graph dictionary to form a list of features and targets tensors.
        This is useful when the model is trained on assembled graphs on the fly.

        Args:
            graphs: (list of dictionary) list of graph dictionary for each structure
            targets: (list of float or list) Optional: corresponding target
                values for each structure

        Returns:
            tuple(node_features, edges_features, global_values, index1, index2, targets)
            The trailing ``targets`` entry is present only when targets is given.
        """
        output = []  # Will be a list of arrays

        # Convert the graphs to matrices
        for feature in ['atom', 'bond', 'state', 'index1', 'index2']:
            output.append([np.array(x[feature]) for x in graphs])

        # If needed, add the targets
        if targets is not None:
            output.append([to_list(t) for t in targets])

        return tuple(output)

    @staticmethod
    def _get_dummy_converter() -> 'DummyConverter':
        """Return a fresh DummyConverter used as the default converter."""
        return DummyConverter()

    def as_dict(self) -> Dict:
        """
        Serialize to a dictionary, replacing the neighbor strategy object by
        the output of ``local_env.serialize``.
        """
        all_dict = super().as_dict()
        if 'nn_strategy' in all_dict:
            all_dict['nn_strategy'] = local_env.serialize(all_dict.pop('nn_strategy'))
        return all_dict

    @classmethod
    def from_dict(cls, d: Dict) -> 'StructureGraph':
        """
        Rebuild an instance from a dictionary produced by :meth:`as_dict`.

        Args:
            d: (dict) serialized representation; it is not modified
        Returns:
            StructureGraph (or subclass) instance
        """
        if 'nn_strategy' in d:
            # Work on a shallow copy so the caller's dictionary keeps its
            # serialized 'nn_strategy' entry and can be reused.
            d = dict(d)
            d['nn_strategy'] = local_env.deserialize(d['nn_strategy'])
        return super().from_dict(d)
class StructureGraphFixedRadius(StructureGraph):
    """
    This one uses a short cut to call find_points_in_spheres cython function in
    pymatgen. It is orders of magnitude faster than previous implementations
    """

    def convert(self, structure: Structure, state_attributes: List = None) -> Dict:
        """
        Take a pymatgen structure and convert it to a index-type graph representation
        The graph will have node, distance, index1, index2, where node is a vector of Z number
        of atoms in the structure, index1 and index2 mark the atom indices forming the bond and separated by
        distance.
        For state attributes, you can set structure.state = [[xx, xx]] beforehand or the algorithm would
        take default [[0, 0]]

        Args:
            state_attributes: (list) state attributes
            structure: (pymatgen structure)
        Returns:
            (dictionary) with keys 'atom', 'bond', 'state', 'index1', 'index2'
        Raises:
            RuntimeError: if no neighbor strategy is set, or if some atom
                never appears as the first index of a bond
        """
        # Same guard and message as StructureGraph.convert: the cutoff is read
        # from the strategy, so a missing strategy must fail with a clear error
        # instead of an AttributeError on None.
        if self.nn_strategy is None:
            raise RuntimeError("NearNeighbor strategy is not provided!")
        state_attributes = state_attributes or getattr(structure, 'state', None) or [[0, 0]]
        atoms = self.get_atom_features(structure)
        # The third value returned by the helper is not used here
        index1, index2, _, bonds = get_graphs_within_cutoff(structure, self.nn_strategy.cutoff)
        if np.size(np.unique(index1)) < len(atoms):
            raise RuntimeError("Isolated atoms found in the structure")
        return {'atom': atoms,
                'bond': bonds,
                'state': state_attributes,
                'index1': index1,
                'index2': index2
                }

    @classmethod
    def from_structure_graph(cls, structure_graph: StructureGraph) -> 'StructureGraphFixedRadius':
        """
        Build a fixed-radius graph converter that reuses the neighbor strategy
        and the atom/bond converters of an existing StructureGraph.

        Args:
            structure_graph: (StructureGraph) source object
        Returns:
            StructureGraphFixedRadius
        """
        return cls(nn_strategy=structure_graph.nn_strategy,
                   atom_converter=structure_graph.atom_converter,
                   bond_converter=structure_graph.bond_converter)
class DummyConverter(Converter):
    """
    Placeholder converter that hands its input back untouched.
    """

    def convert(self, d: Any) -> Any:
        """
        Identity conversion.

        Args:
            d: (Any) data to pass through
        Returns:
            (Any) the very same object that was given
        """
        unchanged = d
        return unchanged
class EmbeddingMap(Converter):
    """
    Look up rows of a feature matrix by integer index.

    Args:
        feature_matrix: (np.ndarray) A matrix of shape (N, M)
    """

    def __init__(self, feature_matrix: np.ndarray):
        matrix = np.array(feature_matrix)
        self.feature_matrix = matrix

    def convert(self, int_array: np.ndarray) -> np.ndarray:
        """
        Map each integer (e.g. an atomic number) to the matching row of the
        feature matrix.

        Args:
            int_array: (1d array) number array of length L
        Returns
            (matrix) L*M matrix with L the length of int_array and M the
            number of columns of the feature matrix
        """
        rows = self.feature_matrix[int_array]
        return rows
class GaussianDistance(Converter):
    """
    Expand distance with Gaussian basis sit at centers and with width 0.5.

    Args:
        centers: (np.array) positions of the Gaussian centers
        width: (float) width of the Gaussians
    """

    def __init__(self, centers: np.ndarray = np.linspace(0, 5, 100), width=0.5):
        # np.array copies, so the default array (created once at definition
        # time) is never shared between instances, and list inputs become
        # arrays that support the 2-D indexing used in convert().
        self.centers = np.array(centers)
        self.width = width

    def convert(self, d: np.ndarray) -> np.ndarray:
        """
        expand distance vector d with given parameters

        Args:
            d: (1d array) distance array
        Returns
            (matrix) N*M matrix with N the length of d and M the length of centers
        """
        d = np.array(d)
        # Broadcast (N, 1) distances against (1, M) centers
        return np.exp(-(d[:, None] - self.centers[None, :]) ** 2 / self.width ** 2)
class BaseGraphBatchGenerator(Sequence):
    """Base class for classes that generate batches of training data for MEGNet.
    Based on the Sequence class, which is the data loader equivalent for Keras.

    Implementations of this base class must implement the :meth:`_generate_inputs`,
    which generates the lists of graph descriptions for a batch.

    The :meth:`process_atom_feature` function and related functions are used to modify
    the features for each atom, bond, and global features when creating a batch.
    """

    def __init__(self, dataset_size: int, targets: np.ndarray,
                 batch_size: int = 128, shuffle: bool = True):
        """
        Args:
            dataset_size (int): Number of entries in dataset
            targets (ndarray): Feature to be predicted for each network
            batch_size (int): Maximum batch size
            shuffle (bool): Whether to shuffle the data after each step
        """
        if targets is not None:
            self.targets = np.array(targets)
        else:
            self.targets = None
        self.batch_size = batch_size
        self.total_n = dataset_size
        self.is_shuffle = shuffle
        # Number of batches per epoch; the last batch may be smaller
        self.max_step = int(np.ceil(self.total_n / batch_size))
        self.mol_index = np.arange(self.total_n)
        if self.is_shuffle:
            self.mol_index = np.random.permutation(self.mol_index)

    def __len__(self) -> int:
        """Return the number of batches per epoch."""
        return self.max_step

    def _combine_graph_data(self,
                            feature_list_temp: List[np.ndarray],
                            connection_list_temp: List[np.ndarray],
                            global_list_temp: List[np.ndarray],
                            index1_temp: List[np.ndarray],
                            index2_temp: List[np.ndarray]) -> List:
        """Compile the matrices describing each graph into single matrices for the entire graph

        Beyond concatenating the graph descriptions, this operation updates the indices of each
        node to be sequential across all graphs so they are not duplicated between graphs

        Args:
            feature_list_temp ([ndarray]): List of features for each node
            connection_list_temp ([ndarray]): List of features for each connection
            global_list_temp ([ndarray]): List of global state for each graph
            index1_temp ([ndarray]): List of indices for the start of each bond
            index2_temp ([ndarray]): List of indices for the end of each bond
        Returns:
            (list): Input arrays describing the entire batch of networks:
                - ndarray: Features for each node
                - ndarray: Features for each connection
                - ndarray: Global state for each graph
                - ndarray: Indices for the start of each bond
                - ndarray: Indices for the end of each bond
                - ndarray: Index of graph associated with each node
                - ndarray: Index of graph associated with each connection
        """
        # Node count of every graph, captured before the features are merged
        node_counts = [len(j) for j in feature_list_temp]

        # get atom's structure id
        gnode = []
        for i, count in enumerate(node_counts):
            gnode += [i] * count

        # get bond's structure id
        gbond = []
        for i, j in enumerate(connection_list_temp):
            gbond += [i] * len(j)

        # assemble atom features together
        feature_list_temp = np.concatenate(feature_list_temp, axis=0)
        feature_list_temp = self.process_atom_feature(feature_list_temp)

        # assemble bond feature together
        connection_list_temp = np.concatenate(connection_list_temp, axis=0)
        connection_list_temp = self.process_bond_feature(connection_list_temp)

        # assemble state feature together
        global_list_temp = np.concatenate(global_list_temp, axis=0)
        global_list_temp = self.process_state_feature(global_list_temp)

        # assemble bond indices
        index1 = []
        index2 = []
        offset_ind = 0
        for ind1, ind2, count in zip(index1_temp, index2_temp, node_counts):
            index1 += [i + offset_ind for i in ind1]
            index2 += [i + offset_ind for i in ind2]
            # Shift by the graph's node count (the same count gnode uses).
            # max(ind1) + 1 would fail on a graph without bonds and would
            # under-count when the last node never starts a bond.
            offset_ind += count

        # Compile the inputs in needed order
        inputs = [expand_1st(feature_list_temp),
                  expand_1st(connection_list_temp),
                  expand_1st(global_list_temp),
                  expand_1st(np.array(index1, dtype=np.int32)),
                  expand_1st(np.array(index2, dtype=np.int32)),
                  expand_1st(np.array(gnode, dtype=np.int32)),
                  expand_1st(np.array(gbond, dtype=np.int32))]
        return inputs

    def on_epoch_end(self):
        """Reshuffle the sample order when shuffling is enabled."""
        if self.is_shuffle:
            self.mol_index = np.random.permutation(self.mol_index)

    def process_atom_feature(self, x: np.ndarray) -> np.ndarray:
        """Hook for modifying the batched atom features; identity by default."""
        return x

    def process_bond_feature(self, x: np.ndarray) -> np.ndarray:
        """Hook for modifying the batched bond features; identity by default."""
        return x

    def process_state_feature(self, x: np.ndarray) -> np.ndarray:
        """Hook for modifying the batched state features; identity by default."""
        return x

    def __getitem__(self, index: int) -> tuple:
        """
        Assemble the batch with the given number.

        Args:
            index (int): Batch number within the epoch
        Returns:
            The list of input arrays when no targets are set, otherwise a
            tuple (inputs, targets)
        """
        # Get the indices for this batch
        batch_index = self.mol_index[index * self.batch_size:(index + 1) * self.batch_size]

        # Get the inputs for each batch
        inputs = self._generate_inputs(batch_index)

        # Make the graph data
        inputs = self._combine_graph_data(*inputs)

        # Return the batch
        if self.targets is None:
            return inputs

        # get targets
        target_temp = itemgetter_list(self.targets, batch_index)
        target_temp = np.atleast_2d(target_temp)
        return inputs, expand_1st(target_temp)

    @abstractmethod
    def _generate_inputs(self, batch_index: List[int]) -> tuple:
        """Get the graph descriptions for each batch

        Args:
            batch_index ([int]): List of indices for training batch
        Returns:
            (tuple): Input arrays describing each network:
                - [ndarray]: List of features for each node
                - [ndarray]: List of features for each connection
                - [ndarray]: List of global state for each graph
                - [ndarray]: List of indices for the start of each bond
                - [ndarray]: List of indices for the end of each bond
        """
        pass
class GraphBatchGenerator(BaseGraphBatchGenerator):
    """
    Batch generator working on pre-computed graph data. It groups several
    structures (as many as batch_size) into (x, y) pairs for model training.

    Args:
        atom_features: (list of np.array) list of atom feature matrix,
        bond_features: (list of np.array) list of bond features matrix
        state_features: (list of np.array) list of [1, G] state features,
            where G is the global state feature dimension
        index1_list: (list of integer) list of (M, ) one side atomic index of the bond,
            M is different for different structures
        index2_list: (list of integer) list of (M, ) the other side atomic
            index of the bond, M is different for different structures,
            but it has to be the same as the corresponding index1.
        targets: (numpy array), N*1, where N is the number of structures
        batch_size: (int) number of samples in a batch
        is_shuffle: (bool) whether to shuffle the structures
    """

    def __init__(self,
                 atom_features: List[np.ndarray],
                 bond_features: List[np.ndarray],
                 state_features: List[np.ndarray],
                 index1_list: List[int],
                 index2_list: List[int],
                 targets: np.ndarray = None,
                 batch_size: int = 128,
                 is_shuffle: bool = True):
        # The dataset size is the number of graphs supplied
        super().__init__(dataset_size=len(atom_features),
                         targets=targets,
                         batch_size=batch_size,
                         shuffle=is_shuffle)
        self.atom_features = atom_features
        self.bond_features = bond_features
        self.state_features = state_features
        self.index1_list = index1_list
        self.index2_list = index2_list

    def _generate_inputs(self, batch_index: int) -> tuple:
        """Pick the stored graph descriptions belonging to one batch

        Args:
            batch_index ([int]): List of indices for training batch
        Returns:
            (tuple): Input arrays describe each network:
                - [ndarray]: List of features for each nodes
                - [ndarray]: List of features for each connection
                - [ndarray]: List of global state for each graph
                - [ndarray]: List of indices for the start of each bond
                - [ndarray]: List of indices for the end of each bond
        """
        # Same selection applied to each stored list, in the order expected
        # by the batch assembly step
        sources = (self.atom_features,
                   self.bond_features,
                   self.state_features,
                   self.index1_list,
                   self.index2_list)
        return tuple(itemgetter_list(source, batch_index) for source in sources)
class GraphBatchDistanceConvert(GraphBatchGenerator):
    """
    Generate batch of structures with bond distance being expanded using a Expansor

    Args:
        atom_features: (list of np.array) list of atom feature matrix,
        bond_features: (list of np.array) list of bond features matrix
        state_features: (list of np.array) list of [1, G] state features, where G is the global state feature dimension
        index1_list: (list of integer) list of (M, ) one side atomic index of the bond, M is different for different
            structures
        index2_list: (list of integer) list of (M, ) the other side atomic index of the bond, M is different for
            different structures, but it has to be the same as the corresponding index1.
        targets: (numpy array), N*1, where N is the number of structures
        batch_size: (int) number of samples in a batch
        is_shuffle: (bool) whether to shuffle the structure, default to True
        distance_converter: (Converter) converter for processing the distances.
            It may be assigned after construction, but it must be set before
            the first batch is generated.
    """

    def __init__(self,
                 atom_features: List[np.ndarray],
                 bond_features: List[np.ndarray],
                 state_features: List[np.ndarray],
                 index1_list: List[int],
                 index2_list: List[int],
                 targets: np.ndarray = None,
                 batch_size: int = 128,
                 is_shuffle: bool = True,
                 distance_converter: Converter = None):
        super().__init__(atom_features=atom_features,
                         bond_features=bond_features,
                         state_features=state_features,
                         index1_list=index1_list,
                         index2_list=index2_list,
                         targets=targets,
                         batch_size=batch_size,
                         is_shuffle=is_shuffle)
        self.distance_converter = distance_converter

    def process_bond_feature(self, x) -> np.ndarray:
        """
        Expand the batched bond features with the distance converter.

        Args:
            x: (np.ndarray) concatenated bond features of the batch
        Returns:
            (np.ndarray) output of ``distance_converter.convert(x)``
        Raises:
            RuntimeError: if no distance converter has been set
        """
        # Fail with a clear message instead of an AttributeError on None
        if self.distance_converter is None:
            raise RuntimeError("Distance converter is not provided!")
        return self.distance_converter.convert(x)
def itemgetter_list(l, indices: List) -> tuple:
    """
    Get the items of l at the given indices and return them as a tuple.

    The result is always a tuple, whether indices holds zero, one or many
    entries.

    Args:
        l: (list) any indexable container
        indices: (list) indices
    Returns:
        (tuple) the selected items, in the order of indices
    """
    # A plain comprehension covers the empty case (itemgetter() with no
    # arguments raises TypeError) and the single-index case (itemgetter
    # returns a bare item instead of a 1-tuple) without special handling.
    return tuple(l[i] for i in indices)