megnet.utils.layer module

TensorFlow layer utilities

repeat_with_index(x: tensorflow.python.framework.ops.Tensor, index: tensorflow.python.framework.ops.Tensor, axis: int = 1)

Given a tensor x of shape (N, M, K), repeat the slices along the middle axis (axis=1) according to the index tensor index of shape (G,). For example, if axis=1 and index = Tensor([0, 0, 0, 1, 2, 2]), then M = 3 (3 unique values) and the final tensor has shape (N, 6, K): the first slice along M is repeated 3 times, the second once and the third twice.
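To make the repetition rule concrete, here is a minimal NumPy sketch of the behavior described above; it is not the megnet implementation. The helper name repeat_with_index_np is hypothetical, and the sketch assumes index is sorted (so equal values are contiguous, as in the example) and has exactly one distinct value per slice of x along axis.

```python
import numpy as np


def repeat_with_index_np(x: np.ndarray, index: np.ndarray, axis: int = 1) -> np.ndarray:
    """Hypothetical NumPy illustration of repeat_with_index, not the library code.

    Assumes `index` is sorted (equal values contiguous) and contains exactly
    x.shape[axis] distinct values.
    """
    index = np.asarray(index).reshape(-1)
    # np.unique returns the distinct values in sorted order; counts[i] is the
    # number of times the i-th distinct value occurs in `index`.
    _, counts = np.unique(index, return_counts=True)
    if counts.shape[0] != x.shape[axis]:
        raise ValueError("index must have one distinct value per slice along axis")
    # Slice i along `axis` is repeated counts[i] times, so the output length
    # along `axis` is counts.sum() == len(index) == G.
    return np.repeat(x, counts, axis=axis)


# The example from the text, with N = 2, M = 3, K = 4.
x = np.arange(2 * 3 * 4).reshape(2, 3, 4)
index = np.array([0, 0, 0, 1, 2, 2])
out = repeat_with_index_np(x, index)
print(out.shape)              # (2, 6, 4)
print(out[0, :, 0].tolist())  # [0, 0, 0, 4, 8, 8]
```

The last axis K is untouched; only the axis being repeated grows from M to G.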

Args:

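x: (3d Tensor) tensor to be augmented
index: (1d Tensor) group index tensor of shape (G,); the number of occurrences of each value sets how many times the corresponding slice along axis is repeated
axis: (int) axis for repetition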

Returns:

(3d Tensor) tensor after repetition; with axis=1 its shape is (N, G, K)