Skip to content

Commit

Permalink
added check for input.key() to be layer inastance
Browse files Browse the repository at this point in the history
  • Loading branch information
ReyhaneAskari committed Jul 31, 2017
1 parent 34af5fb commit eb8b3f1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
9 changes: 8 additions & 1 deletion lasagne/layers/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,14 @@ def get_output(layer_or_layers, inputs=None, **kwargs):
both, while the latter will use two different dropout masks.
"""
from .input import InputLayer
from .base import MergeLayer
from .base import MergeLayer, Layer
# check if the keys of the dictionary are valid
if isinstance(inputs, dict):
for input_key in inputs.keys():
if (input_key is not None) and (not isinstance(input_key, Layer)):
raise TypeError("The inputs dictionary keys must be"
" lasagne layers not %s." %
type(input_key))
# track accepted kwargs used by get_output_for
accepted_kwargs = {'deterministic'}
# obtain topological ordering of all layers the output layer(s) depend on
Expand Down
8 changes: 7 additions & 1 deletion lasagne/tests/layers/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,13 +259,14 @@ def layer_from_shape(self):

def test_layer_from_shape_invalid_get_output(self, layer_from_shape,
get_output):
from lasagne.layers.base import Layer
layer = layer_from_shape
with pytest.raises(ValueError):
get_output(layer)
with pytest.raises(ValueError):
get_output(layer, [1, 2])
with pytest.raises(ValueError):
get_output(layer, {Mock(): [1, 2]})
get_output(layer, {Mock(spec=Layer): [1, 2]})

def test_layer_from_shape_valid_get_output(self, layer_from_shape,
get_output):
Expand Down Expand Up @@ -454,6 +455,11 @@ def test_layer_from_shape_valid_get_output(self, layer_from_shape,
layer.get_output_for.assert_called_with(
[inputs[None], layer.input_layers[1].input_var])

def test_invalid_input_key(self, layer_from_shape, get_output):
layer = layer_from_shape
with pytest.raises(TypeError):
get_output(layer, {Mock(): [1, 2]})


class TestGetOutputShape_InputLayer:
@pytest.fixture
Expand Down

0 comments on commit eb8b3f1

Please sign in to comment.