Get rid of Sequential
stefanradev93 committed Jun 28, 2024
1 parent 56b8ca8 commit 80695e1
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion bayesflow/networks/mlp/hidden_block.py
@@ -29,7 +29,7 @@ def __init__(
             self.dense = layers.SpectralNormalization(self.dense)
         self.dropout = keras.layers.Dropout(dropout)

-    def call(self, inputs: Tensor, training=False):
+    def call(self, inputs: Tensor, training=False) -> Tensor:
         x = self.dense(inputs, training=training)
         x = self.dropout(x, training=training)
         if self.residual:
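
For context, the call signature annotated above belongs to a residual hidden block: a dense layer (optionally wrapped in spectral normalization), followed by dropout and an optional skip connection. Below is a minimal, self-contained sketch of that pattern; the class name ToyHiddenBlock and its defaults are illustrative only and not the repository's actual ConfigurableHiddenBlock.

import keras
from keras import layers


class ToyHiddenBlock(keras.layers.Layer):
    """Sketch of a dense block with dropout and an optional skip connection."""

    def __init__(self, units=64, dropout=0.05, residual=True,
                 spectral_normalization=False, **kwargs):
        super().__init__(**kwargs)
        self.residual = residual
        self.dense = layers.Dense(units, activation="relu")
        if spectral_normalization:
            # Same wrapping as in the diff above.
            self.dense = layers.SpectralNormalization(self.dense)
        self.dropout = layers.Dropout(dropout)

    def call(self, inputs, training=False):
        x = self.dense(inputs, training=training)
        x = self.dropout(x, training=training)
        if self.residual:
            # The skip connection assumes inputs already have `units` features.
            x = x + inputs
        return x
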
16 changes: 11 additions & 5 deletions bayesflow/networks/mlp/mlp.py
@@ -50,18 +50,18 @@ def __init__(

         super().__init__(**keras_kwargs(kwargs))

-        self.res_blocks = keras.Sequential()
+        self.res_blocks = []
         projector = layers.Dense(
             units=width,
             kernel_initializer=kernel_initializer,
         )
         if spectral_normalization:
             projector = layers.SpectralNormalization(projector)
-        self.res_blocks.add(projector)
-        self.res_blocks.add(layers.Dropout(dropout))
+        self.res_blocks.append(projector)
+        self.res_blocks.append(layers.Dropout(dropout))

         for _ in range(depth):
-            self.res_blocks.add(
+            self.res_blocks.append(
                 ConfigurableHiddenBlock(
                     units=width,
                     activation=activation,
@@ -77,4 +77,10 @@ def build(self, input_shape):
         self.call(keras.ops.zeros(input_shape))

     def call(self, inputs: Tensor, **kwargs) -> Tensor:
-        return self.res_blocks(inputs, training=kwargs.get("training", False))
+        for layer in self.res_blocks:
+            _kwargs = {}
+            if layer._call_has_training_arg:
+                _kwargs["training"] = kwargs.get("training", False)
+            outputs = layer(inputs, **_kwargs)
+            inputs = outputs
+        return outputs
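
The change above replaces the keras.Sequential container with a plain Python list plus an explicit forward loop, so the training flag is only forwarded to sub-layers whose call() accepts it (_call_has_training_arg is a private Keras 3 attribute that the commit relies on). Here is a standalone sketch of the same pattern under that assumption; the names ToyStack, width, and depth are illustrative and not part of the repository.

import keras
from keras import layers


class ToyStack(keras.layers.Layer):
    """Sketch: a plain list of sub-layers applied in sequence instead of keras.Sequential."""

    def __init__(self, width=8, depth=2, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.blocks = []  # plain Python list instead of keras.Sequential
        self.blocks.append(layers.Dense(width, activation="relu"))
        self.blocks.append(layers.Dropout(dropout))
        for _ in range(depth):
            self.blocks.append(layers.Dense(width, activation="relu"))

    def call(self, inputs, **kwargs):
        outputs = inputs
        for layer in self.blocks:
            _kwargs = {}
            # Forward the training flag only to sub-layers whose call() accepts it;
            # getattr guards against the private attribute being absent.
            if getattr(layer, "_call_has_training_arg", False):
                _kwargs["training"] = kwargs.get("training", False)
            outputs = layer(outputs, **_kwargs)
        return outputs


x = keras.ops.zeros((4, 16))
print(ToyStack()(x, training=True).shape)  # (4, 8)

Keeping the sub-layers in a plain list still lets Keras track them for weights and serialization, while the explicit loop gives full control over which keyword arguments each sub-layer receives.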
