-
Notifications
You must be signed in to change notification settings - Fork 42
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge remote-tracking branch 'origin/streamlined-backend' into stream…
…lined-backend
- Loading branch information
Showing
37 changed files
with
547 additions
and
179 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|
||
from .lstnet import LSTNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
|
||
import keras | ||
from keras import layers, Sequential | ||
from keras.saving import register_keras_serializable | ||
|
||
from bayesflow.experimental.types import Tensor | ||
from bayesflow.experimental.utils import keras_kwargs | ||
|
||
from .skip_recurrent import SkipRecurrentNet | ||
from ...networks import MLP | ||
|
||
|
||
@register_keras_serializable(package="bayesflow.networks.lstnet") | ||
class LSTNet(keras.Model): | ||
""" | ||
Implements a LSTNet Architecture as described in [1] | ||
[1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow, | ||
2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM), | ||
Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190. | ||
TODO: Add proper docstring | ||
""" | ||
|
||
def __init__( | ||
self, | ||
summary_dim: int = 16, | ||
filters: int | list | tuple = 32, | ||
kernel_sizes: int | list | tuple = 3, | ||
strides: int | list | tuple = 1, | ||
activation: str = "relu", | ||
kernel_initializer: str = "glorot_uniform", | ||
groups: int = 8, | ||
recurrent_type: str | keras.Layer = "gru", | ||
recurrent_dim: int = 128, | ||
bidirectional: bool = True, | ||
dropout: float = 0.05, | ||
skip_steps: int = 4, | ||
**kwargs | ||
): | ||
|
||
super().__init__(**keras_kwargs(kwargs)) | ||
|
||
# Convolutional backbone -> can be extended with inception-like structure | ||
if not isinstance(filters, (list, tuple)): | ||
filters = (filters, ) | ||
if not isinstance(kernel_sizes, (list, tuple)): | ||
kernel_sizes = (kernel_sizes, ) | ||
if not isinstance(strides, (list, tuple)): | ||
strides = (strides, ) | ||
self.conv = Sequential() | ||
for f, k, s in zip(filters, kernel_sizes, strides): | ||
self.conv.add( | ||
layers.Conv1D( | ||
filters=f, | ||
kernel_size=k, | ||
strides=s, | ||
activation=activation, | ||
kernel_initializer=kernel_initializer, | ||
) | ||
) | ||
self.conv.add( | ||
layers.GroupNormalization(groups=groups) | ||
) | ||
|
||
# Recurrent and feedforward backbones | ||
self.recurrent = SkipRecurrentNet( | ||
hidden_dim=recurrent_dim, | ||
recurrent_type=recurrent_type, | ||
bidirectional=bidirectional, | ||
input_channels=filters[-1], | ||
skip_steps=skip_steps, | ||
dropout=dropout | ||
) | ||
self.feedforward = MLP(**kwargs.get("mlp_kwargs", {})) | ||
|
||
self.output_projector = layers.Dense(summary_dim) | ||
|
||
def call(self, time_series: Tensor, **kwargs) -> Tensor: | ||
summary = self.conv(time_series, **kwargs) | ||
summary = self.recurrent(summary, **kwargs) | ||
summary = self.feedforward(summary, **kwargs) | ||
summary = self.output_projector(summary) | ||
return summary | ||
|
||
def build(self, input_shape): | ||
self.call(keras.ops.zeros(input_shape)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
|
||
import keras | ||
from keras.saving import register_keras_serializable | ||
|
||
from bayesflow.experimental.types import Tensor | ||
from bayesflow.experimental.utils import keras_kwargs, find_recurrent_net | ||
|
||
@register_keras_serializable(package="bayesflow.networks") | ||
class SkipRecurrentNet(keras.Model): | ||
""" | ||
Implements a Skip recurrent layer as described in [1], but allowing a more flexible | ||
recurrent backbone and a more flexible implementation. | ||
[1] Y. Zhang and L. Mikelsons, Solving Stochastic Inverse Problems with Stochastic BayesFlow, | ||
2023 IEEE/ASME International Conference on Advanced Intelligent Mechatronics (AIM), | ||
Seattle, WA, USA, 2023, pp. 966-972, doi: 10.1109/AIM46323.2023.10196190. | ||
TODO: Add proper docstring | ||
""" | ||
def __init__( | ||
self, | ||
hidden_dim: int = 256, | ||
recurrent_type: str | keras.Layer = "gru", | ||
bidirectional: bool = True, | ||
input_channels: int = 64, | ||
skip_steps: int = 4, | ||
dropout: float = 0.05, | ||
**kwargs | ||
): | ||
super().__init__(**keras_kwargs(kwargs)) | ||
|
||
recurrent_constructor = find_recurrent_net(recurrent_type) | ||
|
||
self.recurrent = recurrent_constructor( | ||
units=hidden_dim // 2 if bidirectional else hidden_dim, | ||
dropout=dropout, | ||
recurrent_dropout=dropout | ||
) | ||
self.skip_conv = keras.layers.Conv1D( | ||
filters=input_channels*skip_steps, | ||
kernel_size=skip_steps, | ||
strides=skip_steps | ||
) | ||
self.skip_recurrent = recurrent_constructor( | ||
units=hidden_dim // 2 if bidirectional else hidden_dim, | ||
dropout=dropout, | ||
recurrent_dropout=dropout | ||
) | ||
self.input_channels = input_channels | ||
|
||
def call(self, time_series: Tensor, **kwargs) -> Tensor: | ||
direct_summary = self.recurrent(time_series, **kwargs) | ||
skip_summary = self.skip_recurrent(self.skip_conv(time_series), **kwargs) | ||
return keras.ops.concatenate((direct_summary, skip_summary), axis=-1) | ||
|
||
def build(self, input_shape): | ||
self.call(keras.ops.zeros(input_shape)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,2 @@ | ||
|
||
|
||
class SetTransformer: | ||
pass | ||
from .set_transformer import SetTransformer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.