megnet.models.base module
Implements basic GraphModels.
-
class GraphModel(model: tensorflow.python.keras.engine.training.Model, graph_converter: megnet.data.graph.StructureGraph, target_scaler: megnet.utils.preprocessing.Scaler = <megnet.utils.preprocessing.DummyScaler object>, metadata: Dict = None, **kwargs)
Bases: object
Composition of a keras model and a converter class for transferring a structure object to input tensors. Methods are added to train the model from (structures, targets) pairs.
- Parameters
model – (keras model)
graph_converter – (object) an object that turns a structure into a graph, check megnet.data.crystal
target_scaler – (object) a scaler object for converting targets, check megnet.utils.preprocessing
metadata – (dict) An optional dict of metadata associated with the model. Recommended to incorporate some basic information such as units, MAE performance, etc.
-
check_dimension(graph: Dict) → bool
Check the model dimension against the graph converter dimension
- Parameters
graph – structure graph
- Returns
bool
-
classmethod from_file(filename: str) → megnet.models.base.GraphModel
Class method to load a model from two files: filename for the keras model and filename.json for the additional converters
- Parameters
filename – (str) model file name
- Returns
GraphModel
-
get_all_graphs_targets(structures: List[pymatgen.core.structure.Structure], targets: List[float], scrub_failed_structures: bool = False) → tuple
Compute the graphs from structures and return (graphs, targets), with an option to automatically remove structures whose graph computation failed
- Parameters
structures – (list) pymatgen structure list
targets – (list) target property list
scrub_failed_structures – (bool) whether to scrub those failed structures
- Returns
graphs, targets
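The scrubbing behaviour described above can be illustrated with a minimal, standard-library sketch. It is not megnet's implementation: `collect_graphs_targets` and `toy_convert` are hypothetical names, and the sketch assumes that a failed graph computation surfaces as an exception raised by the converter.

```python
from typing import Any, Callable, Dict, List, Tuple


def collect_graphs_targets(
    structures: List[Any],
    targets: List[float],
    convert: Callable[[Any], Dict],
    scrub_failed_structures: bool = False,
) -> Tuple[List[Dict], List[float]]:
    """Convert each structure to a graph, optionally dropping failures."""
    graphs: List[Dict] = []
    kept_targets: List[float] = []
    for structure, target in zip(structures, targets):
        try:
            graph = convert(structure)
        except Exception:
            if scrub_failed_structures:
                # Drop the structure and its target together so both lists stay aligned.
                continue
            raise
        graphs.append(graph)
        kept_targets.append(target)
    return graphs, kept_targets


def toy_convert(structure: str) -> Dict:
    # Hypothetical converter: fails on an empty "structure".
    if not structure:
        raise ValueError("cannot build a graph from an empty structure")
    return {"atom": list(structure)}


graphs, kept = collect_graphs_targets(
    ["NaCl", "", "Si"], [1.0, 2.0, 3.0], toy_convert, scrub_failed_structures=True
)
```

With scrubbing enabled, the failing entry and its target are removed as a pair. With scrubbing disabled, the sketch re-raises the converter error instead.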
-
predict_graph(graph: Dict) → numpy.ndarray
Predict property from graph
- Parameters
graph – a graph dictionary, see megnet.data.graph
- Returns
predicted target value
-
predict_structure(structure: pymatgen.core.structure.Structure) → numpy.ndarray
Predict property from structure
- Parameters
structure – pymatgen structure or molecule
- Returns
predicted target value
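As a rough illustration of how the converter, the model and the target scaler might compose at prediction time, here is a standard-library sketch. All names (`LinearScaler`, `toy_convert`, `toy_model`) are hypothetical stand-ins, and the assumption that the scaler is inverted after the model call is inferred from the class description rather than confirmed by it.

```python
from typing import Any, Callable, Dict


class LinearScaler:
    """Hypothetical scaler: standardises targets with a fixed mean and std."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def transform(self, target: float) -> float:
        return (target - self.mean) / self.std

    def inverse_transform(self, scaled: float) -> float:
        return scaled * self.std + self.mean


def predict_from_structure(
    structure: Any,
    convert: Callable[[Any], Dict],
    model: Callable[[Dict], float],
    scaler: LinearScaler,
) -> float:
    graph = convert(structure)               # structure -> graph dictionary
    scaled = model(graph)                    # graph -> prediction in scaled units
    return scaler.inverse_transform(scaled)  # back to the original target units


def toy_convert(structure: str) -> Dict:
    # Hypothetical converter: one "atom" per character.
    return {"atom": list(structure)}


def toy_model(graph: Dict) -> float:
    # Hypothetical model: half the number of atoms.
    return len(graph["atom"]) / 2


prediction = predict_from_structure(
    "Si", toy_convert, toy_model, LinearScaler(mean=10.0, std=2.0)
)
```

Under these assumptions, predicting from a graph is the same pipeline with the conversion step skipped.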
-
save_model(filename: str) → None
Save the model to a keras hdf5 model file and a json config file for the additional converters
- Parameters
filename – (str) output file name
- Returns
None
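The two-file layout used by save_model and from_file (a model file plus a json file next to it) can be sketched with the standard library alone. This is only an illustration of the naming convention: `save_pair` and `load_pair` are hypothetical helpers, raw bytes stand in for the hdf5 content, and the `cutoff` key is an invented example of a converter setting.

```python
import json
import os
import tempfile
from typing import Dict, Tuple


def save_pair(filename: str, model_bytes: bytes, converter_config: Dict) -> None:
    # Main file: the model itself (raw bytes stand in for the hdf5 content).
    with open(filename, "wb") as f:
        f.write(model_bytes)
    # Sidecar file: the same name plus ".json", holding the converter settings.
    with open(filename + ".json", "w") as f:
        json.dump(converter_config, f)


def load_pair(filename: str) -> Tuple[bytes, Dict]:
    with open(filename, "rb") as f:
        model_bytes = f.read()
    with open(filename + ".json") as f:
        converter_config = json.load(f)
    return model_bytes, converter_config


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.hdf5")
    save_pair(path, b"\x00\x01", {"cutoff": 4.0})  # "cutoff" is a made-up setting
    written = sorted(os.listdir(tmp))
    restored = load_pair(path)
```

Only the model file name is passed around in this sketch, so both files have to be kept together when a saved model is moved.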
-
train(train_structures: List[pymatgen.core.structure.Structure], train_targets: List[float], validation_structures: List[pymatgen.core.structure.Structure] = None, validation_targets: List[float] = None, epochs: int = 1000, batch_size: int = 128, verbose: int = 1, callbacks: List[tensorflow.python.keras.callbacks.Callback] = None, scrub_failed_structures: bool = False, prev_model: str = None, save_checkpoint: bool = True, automatic_correction: bool = True, lr_scaling_factor: float = 0.5, patience: int = 500, **kwargs) → None
- Parameters
train_structures – (list) list of pymatgen structures
train_targets – (list) list of target values
validation_structures – (list) list of pymatgen structures as validation
validation_targets – (list) list of validation targets
epochs – (int) number of epochs
batch_size – (int) training batch size
verbose – (int) verbosity passed to keras fit: 0 is silent, 1 shows a progress bar and 2 prints one line per epoch
callbacks – (list) megnet or keras callback functions for training
scrub_failed_structures – (bool) whether to scrub structures with failed graph computation
prev_model – (str) file name for previously saved model
save_checkpoint – (bool) whether to save checkpoint
automatic_correction – (bool) whether to correct nan errors automatically
lr_scaling_factor – (float, less than 1) factor by which the learning rate is scaled down when a nan loss is encountered
patience – (int) patience for early stopping
**kwargs –
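The effect of lr_scaling_factor can be pictured with a small standard-library sketch. It is a simplification under stated assumptions, not megnet's training loop: the per-epoch losses are supplied as a plain list, `lr_after_nan_correction` is a hypothetical name, and the only correction modelled is multiplying the learning rate by the factor whenever a nan loss appears.

```python
import math
from typing import List


def lr_after_nan_correction(
    epoch_losses: List[float],
    initial_lr: float,
    lr_scaling_factor: float = 0.5,
) -> List[float]:
    """Return the learning rate in effect after each epoch."""
    if not 0.0 < lr_scaling_factor < 1.0:
        raise ValueError("lr_scaling_factor must be between 0 and 1")
    lr = initial_lr
    lr_history: List[float] = []
    for loss in epoch_losses:
        if math.isnan(loss):
            # A nan loss is treated as a diverged epoch: shrink the learning rate.
            lr *= lr_scaling_factor
        lr_history.append(lr)
    return lr_history


history = lr_after_nan_correction(
    [0.9, float("nan"), 0.7, float("nan")], initial_lr=1e-3
)
```

In this sketch each nan loss halves the learning rate with the default factor of 0.5, and epochs with a finite loss leave it unchanged.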
-
train_from_graphs(train_graphs: List[Dict], train_targets: List[float], validation_graphs: List[Dict] = None, validation_targets: List[float] = None, epochs: int = 1000, batch_size: int = 128, verbose: int = 1, callbacks: List[tensorflow.python.keras.callbacks.Callback] = None, prev_model: str = None, lr_scaling_factor: float = 0.5, patience: int = 500, save_checkpoint: bool = True, automatic_correction: bool = True, **kwargs) → None
- Parameters
train_graphs – (list) list of graph dictionaries
train_targets – (list) list of target values
validation_graphs – (list) list of graphs as validation
validation_targets – (list) list of validation targets
epochs – (int) number of epochs
batch_size – (int) training batch size
verbose – (int) verbosity passed to keras fit: 0 is silent, 1 shows a progress bar and 2 prints one line per epoch
callbacks – (list) megnet or keras callback functions for training
prev_model – (str) file name for previously saved model
lr_scaling_factor – (float, less than 1) factor by which the learning rate is scaled down when a nan loss is encountered
patience – (int) patience for early stopping
save_checkpoint – (bool) whether to save checkpoint
automatic_correction – (bool) whether to correct nan errors automatically
**kwargs –
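The patience parameter follows the usual early-stopping idea, which the following standard-library sketch illustrates. The function `stopping_epoch` is hypothetical, and it assumes the common convention that training stops once the monitored validation loss has failed to improve for `patience` consecutive epochs. How megnet wires this into its callbacks is not shown here.

```python
from typing import List, Optional


def stopping_epoch(val_losses: List[float], patience: int) -> Optional[int]:
    """Return the epoch index at which training would stop, or None."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best validation loss: reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.9, 0.85, 0.95]
stop_short = stopping_epoch(losses, patience=2)
stop_long = stopping_epoch(losses, patience=5)
```

A large patience, such as the default of 500, lets training continue through long plateaus. A small one stops soon after the validation loss stalls.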