Source code for megnet.models.base

"""
Implements basic GraphModels.
"""

import os
from warnings import warn
from typing import Dict, List, Union

import numpy as np
from monty.serialization import dumpfn, loadfn

from tensorflow.keras.backend import int_shape
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model

from megnet.callbacks import ModelCheckpointMAE, ManualStop, ReduceLRUponNan
from megnet.data.graph import GraphBatchDistanceConvert, GraphBatchGenerator, StructureGraph
from megnet.utils.preprocessing import DummyScaler, Scaler

from pymatgen import Structure


class GraphModel:
    """
    Composition of a keras model and a converter class for transferring
    structure objects into input tensors. We add methods to train the model
    from (structures, targets) pairs.
    """

    def __init__(self,
                 model: Model,
                 graph_converter: StructureGraph,
                 target_scaler: Scaler = DummyScaler(),
                 metadata: Dict = None,
                 **kwargs):
        """
        Args:
            model: (keras model)
            graph_converter: (object) an object that turns a structure into a
                graph, check `megnet.data.crystal`
            target_scaler: (object) a scaler object for converting targets,
                check `megnet.utils.preprocessing`
            metadata: (dict) an optional dict of metadata associated with the
                model. It is recommended to include basic information such as
                units and MAE performance.
        """
        self.model = model
        self.graph_converter = graph_converter
        self.target_scaler = target_scaler
        self.metadata = metadata or {}

    def __getattr__(self, p):
        # attributes not found on GraphModel are forwarded to the keras model
        return getattr(self.model, p)
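    # A minimal construction sketch (hedged: ``keras_model`` stands for any
    # compiled keras Model whose inputs match the converter outputs; it is not
    # defined in this module):
    #
    #     from megnet.data.crystal import CrystalGraph
    #     gm = GraphModel(model=keras_model, graph_converter=CrystalGraph())
    #     gm.summary()  # resolved through __getattr__, i.e. forwarded to gm.model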
    def train(self,
              train_structures: List[Structure],
              train_targets: List[float],
              validation_structures: List[Structure] = None,
              validation_targets: List[float] = None,
              epochs: int = 1000,
              batch_size: int = 128,
              verbose: int = 1,
              callbacks: List[Callback] = None,
              scrub_failed_structures: bool = False,
              prev_model: str = None,
              save_checkpoint: bool = True,
              automatic_correction: bool = True,
              lr_scaling_factor: float = 0.5,
              patience: int = 500,
              **kwargs) -> None:
        """
        Args:
            train_structures: (list) list of pymatgen structures
            train_targets: (list) list of target values
            validation_structures: (list) list of pymatgen structures used for
                validation
            validation_targets: (list) list of validation targets
            epochs: (int) number of epochs
            batch_size: (int) training batch size
            verbose: (int) keras fit verbosity: 0 = silent, 1 = progress bar,
                2 = one line per epoch
            callbacks: (list) megnet or keras callback functions for training
            scrub_failed_structures: (bool) whether to drop structures with
                failed graph computation
            prev_model: (str) file name of a previously saved model
            save_checkpoint: (bool) whether to save checkpoints
            automatic_correction: (bool) whether to correct nan losses
            lr_scaling_factor: (float, less than 1) factor by which to scale the
                learning rate down when a nan loss is encountered
            patience: (int) patience for early stopping
            **kwargs: additional keyword arguments passed to `train_from_graphs`
        """
        train_graphs, train_targets = self.get_all_graphs_targets(
            train_structures, train_targets,
            scrub_failed_structures=scrub_failed_structures)
        if validation_structures is not None:
            val_graphs, validation_targets = self.get_all_graphs_targets(
                validation_structures, validation_targets,
                scrub_failed_structures=scrub_failed_structures)
        else:
            val_graphs = None

        self.train_from_graphs(train_graphs,
                               train_targets,
                               validation_graphs=val_graphs,
                               validation_targets=validation_targets,
                               epochs=epochs,
                               batch_size=batch_size,
                               verbose=verbose,
                               callbacks=callbacks,
                               prev_model=prev_model,
                               lr_scaling_factor=lr_scaling_factor,
                               patience=patience,
                               save_checkpoint=save_checkpoint,
                               automatic_correction=automatic_correction,
                               **kwargs)
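    # A training sketch from (structures, targets) pairs; the names are
    # illustrative, assuming ``structures``/``val_structures`` are lists of
    # pymatgen Structure objects with matching lists of float targets:
    #
    #     gm.train(structures, targets,
    #              validation_structures=val_structures,
    #              validation_targets=val_targets,
    #              epochs=100, batch_size=32,
    #              scrub_failed_structures=True)  # drop (with a warning) any
    #                                             # structure whose graph fails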
    def train_from_graphs(self,
                          train_graphs: List[Dict],
                          train_targets: List[float],
                          validation_graphs: List[Dict] = None,
                          validation_targets: List[float] = None,
                          epochs: int = 1000,
                          batch_size: int = 128,
                          verbose: int = 1,
                          callbacks: List[Callback] = None,
                          prev_model: str = None,
                          lr_scaling_factor: float = 0.5,
                          patience: int = 500,
                          save_checkpoint: bool = True,
                          automatic_correction: bool = True,
                          **kwargs) -> None:
        """
        Args:
            train_graphs: (list) list of graph dictionaries
            train_targets: (list) list of target values
            validation_graphs: (list) list of graphs used for validation
            validation_targets: (list) list of validation targets
            epochs: (int) number of epochs
            batch_size: (int) training batch size
            verbose: (int) keras fit verbosity: 0 = silent, 1 = progress bar,
                2 = one line per epoch
            callbacks: (list) megnet or keras callback functions for training
            prev_model: (str) file name of a previously saved model
            lr_scaling_factor: (float, less than 1) factor by which to scale the
                learning rate down when a nan loss is encountered
            patience: (int) patience for early stopping
            save_checkpoint: (bool) whether to save checkpoints
            automatic_correction: (bool) whether to correct nan losses
            **kwargs: additional keyword arguments passed to keras `fit`;
                `dirname` is consumed here as the checkpoint directory
        """
        # load a previously saved model's weights
        if prev_model:
            self.load_weights(prev_model)
        is_classification = 'entropy' in self.model.loss
        monitor = 'val_acc' if is_classification else 'val_mae'
        mode = 'max' if is_classification else 'min'
        dirname = kwargs.pop('dirname', 'callback')
        if not os.path.isdir(dirname):
            os.makedirs(dirname)
        if callbacks is None:
            # with this callback you can stop the model training by `touch STOP`
            callbacks = [ManualStop()]
        train_nb_atoms = [len(i['atom']) for i in train_graphs]
        train_targets = [self.target_scaler.transform(i, j)
                         for i, j in zip(train_targets, train_nb_atoms)]

        if validation_graphs is not None:
            filepath = os.path.join(
                dirname, '%s_{epoch:05d}_{%s:.6f}.hdf5' % (monitor, monitor))
            val_nb_atoms = [len(i['atom']) for i in validation_graphs]
            validation_targets = [self.target_scaler.transform(i, j)
                                  for i, j in zip(validation_targets, val_nb_atoms)]
            val_inputs = self.graph_converter.get_flat_data(validation_graphs,
                                                            validation_targets)

            val_generator = self._create_generator(*val_inputs,
                                                   batch_size=batch_size)
            steps_per_val = int(np.ceil(len(validation_graphs) / batch_size))

            if automatic_correction:
                callbacks.extend([ReduceLRUponNan(filepath=filepath,
                                                  monitor=monitor,
                                                  mode=mode,
                                                  factor=lr_scaling_factor,
                                                  patience=patience,
                                                  )])
            if save_checkpoint:
                callbacks.extend([ModelCheckpointMAE(filepath=filepath,
                                                     monitor=monitor,
                                                     mode=mode,
                                                     save_best_only=True,
                                                     save_weights_only=False,
                                                     val_gen=val_generator,
                                                     steps_per_val=steps_per_val,
                                                     target_scaler=self.target_scaler)])
                # avoid running validation twice in an epoch
                val_generator = None
                steps_per_val = None
        else:
            val_generator = None
            steps_per_val = None

        train_inputs = self.graph_converter.get_flat_data(train_graphs, train_targets)
        # check that the data dimensions match the model inputs
        self.check_dimension(train_graphs[0])
        train_generator = self._create_generator(*train_inputs, batch_size=batch_size)
        steps_per_train = int(np.ceil(len(train_graphs) / batch_size))
        self.fit(train_generator,
                 steps_per_epoch=steps_per_train,
                 validation_data=val_generator,
                 validation_steps=steps_per_val,
                 epochs=epochs,
                 verbose=verbose,
                 callbacks=callbacks,
                 **kwargs)
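    # Notes on the loop above: checkpoints are written to ``dirname`` (default
    # 'callback') with names like val_mae_00010_0.123456.hdf5, and the default
    # ManualStop callback ends training gracefully when a file named STOP
    # appears in the working directory (e.g. via ``touch STOP``). A sketch
    # that overrides the checkpoint directory, consumed from **kwargs:
    #
    #     gm.train_from_graphs(train_graphs, train_targets,
    #                          validation_graphs=val_graphs,
    #                          validation_targets=val_targets,
    #                          dirname='my_checkpoints')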
    def check_dimension(self, graph: Dict) -> None:
        """
        Check the model dimensions against the graph converter dimensions.

        Args:
            graph: structure graph

        Raises:
            ValueError: if the data dimensions do not match the model inputs
        """
        test_inp = self.graph_converter.graph_to_input(graph)
        input_shapes = [i.shape for i in test_inp]
        model_input_shapes = [int_shape(i) for i in self.model.inputs]

        def _check_match(real_shape, tensor_shape):
            if len(real_shape) != len(tensor_shape):
                return False
            # a None in the model tensor shape is a wildcard matching any size
            return all(j is None or i == j
                       for i, j in zip(real_shape, tensor_shape))

        for i, j, k in zip(['atom features', 'bond features', 'state features'],
                           input_shapes[:3], model_input_shapes[:3]):
            if not _check_match(j, k):
                raise ValueError("The data dimension for %s is %s and does not "
                                 "match the model required shape of %s"
                                 % (i, str(j), str(k)))
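    # Matching semantics illustrated: a None in a model input shape is a
    # wildcard, so a model input of shape (None, None, 27) accepts atom
    # features of shape (1, 5, 27) but rejects (1, 5, 16). A hedged sketch:
    #
    #     graph = gm.graph_converter.convert(structure)
    #     gm.check_dimension(graph)  # raises ValueError on any mismatch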
    def get_all_graphs_targets(self,
                               structures: List[Structure],
                               targets: List[float],
                               scrub_failed_structures: bool = False) -> tuple:
        """
        Compute the graphs from structures and return (graphs, targets), with
        the option to automatically drop structures whose graph computation
        fails.

        Args:
            structures: (list) pymatgen structure list
            targets: (list) target property list
            scrub_failed_structures: (bool) whether to drop failed structures

        Returns:
            graphs, targets
        """
        graphs_valid = []
        targets_valid = []

        for i, (s, t) in enumerate(zip(structures, targets)):
            try:
                graph = self.graph_converter.convert(s)
                graphs_valid.append(graph)
                targets_valid.append(t)
            except Exception as e:
                if scrub_failed_structures:
                    warn("structure with index %d failed the graph computation"
                         % i, UserWarning)
                    continue
                else:
                    raise RuntimeError(str(e))
        return graphs_valid, targets_valid
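    # A scrubbing sketch: with scrub_failed_structures=True, a structure whose
    # conversion raises is skipped with a UserWarning, so the returned lists
    # may be shorter than the inputs but stay aligned with each other:
    #
    #     graphs, targets = gm.get_all_graphs_targets(
    #         structures, targets, scrub_failed_structures=True)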
    def predict_structure(self, structure: Structure) -> np.ndarray:
        """
        Predict property from structure

        Args:
            structure: pymatgen structure or molecule

        Returns:
            predicted target value
        """
        graph = self.graph_converter.convert(structure)
        return self.predict_graph(graph)
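    # Inference sketch (``structure`` is any pymatgen Structure the converter
    # can handle); the result comes back on the original target scale because
    # predict_graph applies the scaler's inverse transform:
    #
    #     value = gm.predict_structure(structure)  # np.ndarray, e.g. shape (1,)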
    def predict_graph(self, graph: Dict) -> np.ndarray:
        """
        Predict property from graph

        Args:
            graph: a graph dictionary, see megnet.data.graph

        Returns:
            predicted target value
        """
        inp = self.graph_converter.graph_to_input(graph)
        return self.target_scaler.inverse_transform(self.predict(inp).ravel(),
                                                    len(graph['atom']))
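    # The equivalent graph-level call, useful when graphs are precomputed:
    #
    #     graph = gm.graph_converter.convert(structure)
    #     value = gm.predict_graph(graph)  # same result as predict_structure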
    def _create_generator(self, *args, **kwargs) \
            -> Union[GraphBatchDistanceConvert, GraphBatchGenerator]:
        if hasattr(self.graph_converter, 'bond_converter'):
            kwargs.update({'distance_converter': self.graph_converter.bond_converter})
            return GraphBatchDistanceConvert(*args, **kwargs)
        else:
            return GraphBatchGenerator(*args, **kwargs)
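    # Dispatch note: converters that expose a ``bond_converter`` (e.g. a
    # CrystalGraph with a GaussianDistance expansion) get a
    # GraphBatchDistanceConvert generator, which expands raw bond distances
    # into basis features batch by batch; other converters fall back to the
    # plain GraphBatchGenerator. Mirroring the usage in train_from_graphs:
    #
    #     inputs = gm.graph_converter.get_flat_data(graphs, targets)
    #     gen = gm._create_generator(*inputs, batch_size=128)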
    def save_model(self, filename: str) -> None:
        """
        Save the model to a keras model hdf5 file and a json config for the
        additional converters.

        Args:
            filename: (str) output file name

        Returns:
            None
        """
        self.model.save(filename)
        dumpfn(
            {
                'graph_converter': self.graph_converter,
                'target_scaler': self.target_scaler,
                'metadata': self.metadata
            },
            filename + '.json'
        )
    @classmethod
    def from_file(cls, filename: str) -> 'GraphModel':
        """
        Class method to load the model from
            filename for the keras model and
            filename.json for the additional converters.

        Args:
            filename: (str) model file name

        Returns:
            GraphModel
        """
        configs = loadfn(filename + '.json')
        from tensorflow.keras.models import load_model
        from megnet.layers import _CUSTOM_OBJECTS
        model = load_model(filename, custom_objects=_CUSTOM_OBJECTS)
        configs.update({'model': model})
        return GraphModel(**configs)
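    # A save/load round-trip sketch: save_model writes two files, the keras
    # hdf5 at ``filename`` plus the converter/scaler/metadata config at
    # ``filename + '.json'``; from_file expects both to be present:
    #
    #     gm.save_model('my_model.hdf5')  # writes my_model.hdf5 and my_model.hdf5.json
    #     gm2 = GraphModel.from_file('my_model.hdf5')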