Source code for megnet.layers.readout.linear
from tensorflow.keras.layers import Layer
import tensorflow as tf
[docs]class LinearWithIndex(Layer):
"""
Sum or average the node/edge attributes to get a structure-level vector
Args:
mode: (str) 'mean' or 'sum'
"""
def __init__(self, mode='mean', **kwargs):
super(LinearWithIndex, self).__init__(**kwargs)
self.mode = mode
if self.mode == 'mean':
self.reduce_method = tf.math.segment_mean
elif self.mode == 'sum':
self.reduce_method = tf.math.segment_sum
else:
raise ValueError('Only sum and mean are supported at the moment!')
[docs] def build(self, input_shape):
self.built = True
[docs] def call(self, inputs, mask=None):
prop, index = inputs
index = tf.reshape(index, (-1,))
prop = tf.transpose(a=prop, perm=[1, 0, 2])
out = self.reduce_method(prop, index)
out = tf.transpose(a=out, perm=[1, 0, 2])
return out
[docs] def compute_output_shape(self, input_shape):
prop_shape = input_shape[0]
return prop_shape[0], None, prop_shape[-1]
[docs] def get_config(self):
config = {'mode': self.mode}
base_config = super(LinearWithIndex, self).get_config()
return dict(list(base_config.items()) + list(config.items()))