megnet.models package

Module contents

Models package. This package contains various graph-based models.

class GraphModel(model: tensorflow.python.keras.engine.training.Model, graph_converter: megnet.data.graph.StructureGraph, target_scaler: megnet.utils.preprocessing.Scaler = <megnet.utils.preprocessing.DummyScaler object>, metadata: Dict = None, **kwargs)

Bases: object

Composition of a Keras model and a converter class that transforms structure objects into input tensors. Methods are provided to train the model from (structures, targets) pairs.

Parameters
  • model – (keras model)

  • graph_converter – (object) an object that turns a structure into a graph; see megnet.data.crystal

  • target_scaler – (object) a scaler object for converting targets; see megnet.utils.preprocessing

  • metadata – (dict) an optional dict of metadata associated with the model. It is recommended to include basic information such as units and MAE performance.
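The target_scaler is only described here as an object for converting targets. As a minimal sketch, assuming a scaler needs nothing more than paired transform and inverse_transform methods (the interface named in the MEGNetModel parameters), a hypothetical standardizing scaler and an identity scaler could look like this. The class names are illustrative, not megnet classes, and the real megnet.utils.preprocessing scalers may take additional arguments.

```python
from statistics import mean, pstdev
from typing import List


class IdentityScaler:
    """Hypothetical no-op scaler: targets pass through unchanged."""

    def transform(self, target: float) -> float:
        return target

    def inverse_transform(self, transformed: float) -> float:
        return transformed


class StandardTargetScaler:
    """Hypothetical scaler: standardize targets to zero mean and unit variance."""

    def __init__(self, targets: List[float]):
        self.mean = mean(targets)
        # Fall back to 1.0 so constant targets do not cause division by zero
        self.std = pstdev(targets) or 1.0

    def transform(self, target: float) -> float:
        return (target - self.mean) / self.std

    def inverse_transform(self, transformed: float) -> float:
        return transformed * self.std + self.mean


targets = [1.0, 2.0, 3.0, 4.0]
scaler = StandardTargetScaler(targets)
scaled = [scaler.transform(t) for t in targets]
restored = [scaler.inverse_transform(s) for s in scaled]
```

The identity variant presumably plays the same role as the default DummyScaler: predictions come back in the original target units without any rescaling.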

check_dimension(graph: Dict) → bool

Check the model dimension against the graph converter dimension.

Parameters

graph – structure graph

Returns

bool

classmethod from_file(filename: str) → megnet.models.base.GraphModel

Class method to load a model from filename (the Keras model) and filename.json (the additional converters).

Parameters

filename – (str) model file name

Returns

GraphModel

get_all_graphs_targets(structures: List[pymatgen.core.structure.Structure], targets: List[float], scrub_failed_structures: bool = False) → tuple

Compute the graphs from structures and return (graphs, targets), with an option to automatically remove structures whose graph computation failed.

Parameters
  • structures – (list) pymatgen structure list

  • targets – (list) target property list

  • scrub_failed_structures – (bool) whether to remove structures whose graph computation failed

Returns

graphs, targets
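The scrub_failed_structures option can be pictured as follows: convert each structure, and on failure either propagate the error or drop the structure together with its target so that the two lists stay aligned. This is a sketch under the assumption that failures surface as exceptions; graphs_and_targets and toy_convert are hypothetical names, and plain strings stand in for pymatgen structures.

```python
from typing import Callable, List, Sequence, Tuple, TypeVar

S = TypeVar("S")
G = TypeVar("G")


def graphs_and_targets(
    structures: Sequence[S],
    targets: Sequence[float],
    convert: Callable[[S], G],
    scrub_failed_structures: bool = False,
) -> Tuple[List[G], List[float]]:
    """Convert structures to graphs, keeping graphs and targets aligned.

    With scrub_failed_structures=False the first conversion error propagates;
    with True, the failing structure and its target are both dropped.
    """
    if len(structures) != len(targets):
        raise ValueError("structures and targets must have the same length")
    graphs: List[G] = []
    kept_targets: List[float] = []
    for structure, target in zip(structures, targets):
        try:
            graph = convert(structure)
        except Exception:
            if not scrub_failed_structures:
                raise
            continue  # skip this structure and its target together
        graphs.append(graph)
        kept_targets.append(target)
    return graphs, kept_targets


def toy_convert(structure: str) -> dict:
    # Hypothetical converter that fails on the input "bad"
    if structure == "bad":
        raise RuntimeError("graph computation failed")
    return {"atom": [structure]}


graphs, kept = graphs_and_targets(
    ["a", "bad", "b"], [1.0, 2.0, 3.0], toy_convert, scrub_failed_structures=True
)
```

Dropping the target in the same step as the failed structure is the key design point: removing one without the other would silently shift every later target onto the wrong graph.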

predict_graph(graph: Dict) → numpy.ndarray

Predict the property from a graph.

Parameters

graph – a graph dictionary, see megnet.data.graph

Returns

predicted target value

predict_structure(structure: pymatgen.core.structure.Structure) → numpy.ndarray

Predict the property from a structure.

Parameters

structure – pymatgen structure or molecule

Returns

predicted target value
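predict_structure and predict_graph read naturally as a pipeline: the converter turns a structure into a graph, the Keras model predicts in scaled target space, and the scaler's inverse_transform maps the result back to target units. The sketch below assumes exactly that flow using plain callables; ToyGraphModel, ShiftScaler, count_atoms and toy_model are hypothetical stand-ins, and the real methods return numpy arrays rather than floats.

```python
from typing import Callable, Dict, List


class ShiftScaler:
    """Hypothetical scaler that subtracts a constant offset from targets."""

    def __init__(self, offset: float):
        self.offset = offset

    def transform(self, target: float) -> float:
        return target - self.offset

    def inverse_transform(self, scaled: float) -> float:
        return scaled + self.offset


class ToyGraphModel:
    """Hypothetical stand-in for the GraphModel composition (not megnet code)."""

    def __init__(
        self,
        model: Callable[[Dict], float],
        converter: Callable[[List[str]], Dict],
        scaler: ShiftScaler,
    ):
        self.model = model          # graph dict -> prediction in scaled space
        self.converter = converter  # structure -> graph dict
        self.scaler = scaler        # maps predictions back to target units

    def predict_graph(self, graph: Dict) -> float:
        # Undo the target scaling applied during training
        return self.scaler.inverse_transform(self.model(graph))

    def predict_structure(self, structure: List[str]) -> float:
        # Structure -> graph -> prediction
        return self.predict_graph(self.converter(structure))


def count_atoms(structure: List[str]) -> Dict:
    # Hypothetical converter: one node per element symbol
    return {"atom": list(structure)}


def toy_model(graph: Dict) -> float:
    # Hypothetical "model": 0.5 per atom, in scaled target space
    return 0.5 * len(graph["atom"])


gm = ToyGraphModel(toy_model, count_atoms, ShiftScaler(offset=-1.0))
prediction = gm.predict_structure(["Na", "Cl"])  # 0.5 * 2 + (-1.0) = 0.0
```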

save_model(filename: str) → None

Save the model to a Keras HDF5 file and a JSON config file for the additional converters.

Parameters

filename – (str) output file name

Returns

None
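save_model and from_file share a file-naming convention: the Keras model goes to filename and the converter configuration to filename.json. The sketch below shows only the JSON side with the standard library; sidecar_path, save_sidecar, load_sidecar and the config keys are hypothetical, and the Keras HDF5 step is left as a comment.

```python
import json
import os
import tempfile
from typing import Any, Dict


def sidecar_path(filename: str) -> str:
    # Converters and other non-Keras settings live next to the model file
    return filename + ".json"


def save_sidecar(filename: str, config: Dict[str, Any]) -> str:
    """Write a JSON config alongside the Keras model file and return its path."""
    path = sidecar_path(filename)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return path


def load_sidecar(filename: str) -> Dict[str, Any]:
    """Read back the JSON config written by save_sidecar."""
    with open(sidecar_path(filename), encoding="utf-8") as f:
        return json.load(f)


with tempfile.TemporaryDirectory() as tmp:
    model_file = os.path.join(tmp, "formation_energy.hdf5")
    # The Keras weights would be saved to model_file at this point.
    config = {"graph_converter": {"cutoff": 5.0}, "metadata": {"units": "eV/atom"}}
    written = save_sidecar(model_file, config)
    loaded = load_sidecar(model_file)
```

Keeping both files keyed on the same base name means a loader only needs one path to recover the model and everything needed to turn new structures into model inputs.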

train(train_structures: List[pymatgen.core.structure.Structure], train_targets: List[float], validation_structures: List[pymatgen.core.structure.Structure] = None, validation_targets: List[float] = None, epochs: int = 1000, batch_size: int = 128, verbose: int = 1, callbacks: List[tensorflow.python.keras.callbacks.Callback] = None, scrub_failed_structures: bool = False, prev_model: str = None, save_checkpoint: bool = True, automatic_correction: bool = True, lr_scaling_factor: float = 0.5, patience: int = 500, **kwargs) → None

Train the model from pymatgen structures and their target values.

Parameters
  • train_structures – (list) list of pymatgen structures

  • train_targets – (list) list of target values

  • validation_structures – (list) list of pymatgen structures for validation

  • validation_targets – (list) list of validation targets

  • epochs – (int) number of epochs

  • batch_size – (int) training batch size

  • verbose – (int) Keras fit verbosity: 0 = silent, 1 = progress bar, 2 = one line per epoch

  • callbacks – (list) megnet or keras callback functions for training

  • scrub_failed_structures – (bool) whether to remove structures whose graph computation failed

  • prev_model – (str) file name of a previously saved model

  • save_checkpoint – (bool) whether to save checkpoint

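  • automatic_correction – (bool) whether to automatically correct NaN loss errors during training

  • lr_scaling_factor – (float, less than 1) factor by which the learning rate is scaled down when a NaN loss is encountered

  • patience – (int) patience for early stopping, i.e. the number of epochs without improvement before training stops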

  • **kwargs
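automatic_correction and lr_scaling_factor together describe a simple policy: when a NaN loss appears, shrink the learning rate by the factor and keep going. The sketch replays a given list of epoch losses to show how the learning rate evolves; run_with_nan_correction is hypothetical, and restoring the last good weights (which a real trainer would likely also need) is only noted in a comment.

```python
import math
from typing import List, Tuple


def run_with_nan_correction(
    losses: List[float],
    lr: float,
    lr_scaling_factor: float = 0.5,
    automatic_correction: bool = True,
) -> Tuple[float, List[float]]:
    """Replay epoch losses, shrinking the learning rate after each NaN loss.

    Returns the final learning rate and the learning rate used at each epoch.
    """
    if not 0.0 < lr_scaling_factor < 1.0:
        raise ValueError("lr_scaling_factor must be between 0 and 1")
    history: List[float] = []
    for loss in losses:
        history.append(lr)
        if math.isnan(loss):
            if not automatic_correction:
                raise FloatingPointError("NaN loss encountered")
            # A real trainer would likely also restore the last good weights here
            lr *= lr_scaling_factor
    return lr, history


final_lr, history = run_with_nan_correction([0.9, float("nan"), 0.7, float("nan")], lr=1e-3)
# history: 1e-3, 1e-3, 5e-4, 5e-4; final_lr: 2.5e-4
```

Requiring a factor strictly below 1 matches the "less than 1" constraint in the parameter list; a factor of 1 or more would never make the NaN condition less likely.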

train_from_graphs(train_graphs: List[Dict], train_targets: List[float], validation_graphs: List[Dict] = None, validation_targets: List[float] = None, epochs: int = 1000, batch_size: int = 128, verbose: int = 1, callbacks: List[tensorflow.python.keras.callbacks.Callback] = None, prev_model: str = None, lr_scaling_factor: float = 0.5, patience: int = 500, save_checkpoint: bool = True, automatic_correction: bool = True, **kwargs) → None

Train the model from precomputed graphs and their target values.

Parameters
  • train_graphs – (list) list of graph dictionaries

  • train_targets – (list) list of target values

  • validation_graphs – (list) list of graphs for validation

  • validation_targets – (list) list of validation targets

  • epochs – (int) number of epochs

  • batch_size – (int) training batch size

  • verbose – (int) Keras fit verbosity: 0 = silent, 1 = progress bar, 2 = one line per epoch

  • callbacks – (list) megnet or keras callback functions for training

  • prev_model – (str) file name of a previously saved model

  • lr_scaling_factor – (float, less than 1) factor by which the learning rate is scaled down when a NaN loss is encountered

  • patience – (int) patience for early stopping, i.e. the number of epochs without improvement before training stops

  • save_checkpoint – (bool) whether to save checkpoint

  • automatic_correction – (bool) whether to automatically correct NaN loss errors during training

  • **kwargs
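patience follows the usual early-stopping rule: training ends once the monitored validation quantity has failed to improve for patience consecutive epochs. A minimal sketch, assuming a lower-is-better metric and no minimum improvement threshold (conventions borrowed from Keras's EarlyStopping callback); early_stopping_epoch is a hypothetical helper.

```python
from typing import List, Optional


def early_stopping_epoch(val_losses: List[float], patience: int) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Stops once `patience` consecutive epochs pass without the validation
    loss improving on the best value seen so far.
    """
    if patience < 1:
        raise ValueError("patience must be at least 1")
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


stop = early_stopping_epoch([1.0, 0.8, 0.85, 0.9, 0.95], patience=3)  # stops at epoch 4
```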

class MEGNetModel(nfeat_edge: int = None, nfeat_global: int = None, nfeat_node: int = None, nblocks: int = 3, lr: float = 0.001, n1: int = 64, n2: int = 32, n3: int = 16, nvocal: int = 95, embedding_dim: int = 16, nbvocal: int = None, bond_embedding_dim: int = None, ngvocal: int = None, global_embedding_dim: int = None, npass: int = 3, ntarget: int = 1, act: Callable = <function softplus2>, is_classification: bool = False, loss: str = 'mse', metrics: List[str] = None, l2_coef: float = None, dropout: float = None, graph_converter: megnet.data.graph.StructureGraph = None, target_scaler: megnet.utils.preprocessing.Scaler = <megnet.utils.preprocessing.DummyScaler object>, optimizer_kwargs: Dict = None, dropout_on_predict: bool = False)

Bases: megnet.models.base.GraphModel

Construct a graph network model with or without explicit atom features. If nfeat_node is specified, a general graph model is assumed; otherwise, a crystal graph model with the atomic number (Z) as the atom feature is assumed.
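When nfeat_node is not given, each atom is represented by its atomic number. Our reading of nvocal and embedding_dim is that Z is used as an integer index into a learned table with nvocal rows and embedding_dim columns; the sketch below shows that lookup in plain Python with random initial values. make_embedding_table and embed_atoms are hypothetical, and in the real model the lookup would be a trainable Keras layer.

```python
import random
from typing import List


def make_embedding_table(nvocal: int, embedding_dim: int, seed: int = 0) -> List[List[float]]:
    # Random initial values; in a trained model these rows are learned weights
    rng = random.Random(seed)
    return [[rng.uniform(-0.05, 0.05) for _ in range(embedding_dim)] for _ in range(nvocal)]


def embed_atoms(atomic_numbers: List[int], table: List[List[float]]) -> List[List[float]]:
    """Map each atomic number Z to row Z of the embedding table."""
    for z in atomic_numbers:
        if not 0 <= z < len(table):
            raise ValueError(f"Z={z} is outside the vocabulary of size {len(table)}")
    return [table[z] for z in atomic_numbers]


table = make_embedding_table(nvocal=95, embedding_dim=16)
features = embed_atoms([11, 17], table)  # Na (Z=11) and Cl (Z=17)
```

With nfeat_node set instead, each atom would arrive as a ready-made feature vector of length nfeat_node and no lookup would be needed.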

Parameters
  • nfeat_edge – (int) number of bond features

  • nfeat_global – (int) number of state features

  • nfeat_node – (int) number of atom features

  • nblocks – (int) number of MEGNetLayer blocks

  • lr – (float) learning rate

  • n1 – (int) number of hidden units in layer 1 in MEGNetLayer

  • n2 – (int) number of hidden units in layer 2 in MEGNetLayer

  • n3 – (int) number of hidden units in layer 3 in MEGNetLayer

  • nvocal – (int) total number of elements (size of the atom vocabulary)

  • embedding_dim – (int) dimension of the atom embedding

  • nbvocal – (int) number of bond types if bond attributes are types

  • bond_embedding_dim – (int) dimension of the bond embedding

  • ngvocal – (int) number of global types if global attributes are types

  • global_embedding_dim – (int) dimension of the global embedding

  • npass – (int) number of recurrent steps in the Set2Set layer

  • ntarget – (int) number of output targets

  • act – (object) activation function

  • l2_coef – (float or None) l2 regularization parameter

  • is_classification – (bool) whether it is a classification task

  • loss – (object or str) loss function

  • metrics – (list or dict) List or dictionary of Keras metrics to be evaluated by the model during training and testing

  • dropout – (float) dropout rate

  • graph_converter – (object) object that exposes a “convert” method for structure to graph conversion

  • target_scaler – (object) object that exposes “transform” and “inverse_transform” methods for transforming the target values

  • optimizer_kwargs – (dict) extra keyword arguments for the optimizer, for example clipnorm and clipvalue
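The default act is softplus2. As far as we understand (verify against megnet's activations module), it is a softplus shifted so that it passes through the origin, softplus2(x) = log(exp(x) + 1) - log(2). A scalar sketch, rewritten in an overflow-safe form:

```python
import math


def softplus2(x: float) -> float:
    """Shifted softplus: log(exp(x) + 1) - log(2), which is 0 at x = 0.

    Computed as max(x, 0) + log(0.5 * exp(-|x|) + 0.5), an algebraically
    equal form that never evaluates exp of a large positive number.
    """
    return max(x, 0.0) + math.log(0.5 * math.exp(-abs(x)) + 0.5)


values = [softplus2(x) for x in (-2.0, 0.0, 2.0)]
```

The shift by log(2) keeps the activation zero at zero input while staying smooth everywhere, unlike a ReLU.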

classmethod from_mvl_models(name: str) → megnet.models.megnet.MEGNetModel

classmethod from_url(url: str) → megnet.models.megnet.MEGNetModel

Download and load a model from a URL. E.g. https://github.com/materialsvirtuallab/megnet/blob/master/mvl_models/mp-2019.4.1/formation_energy.hdf5

Parameters

url – (str) url link of the model

Returns

GraphModel