
Register LMUCell with Keras #34

Open
bmorcos opened this issue Feb 19, 2021 · 0 comments


bmorcos commented Feb 19, 2021

If the `LMUCell` is wrapped in another layer (e.g. `RNN`), that layer cannot be round-tripped through `get_config`/`from_config`, since `LMUCell` is a custom object unknown to Keras. For example:

```python
import keras_lmu
from tensorflow.keras.layers import RNN, Dense

# Build an LMU layer
dt = 1e-3
activation = "tanh"
dropout = 0.2

lmu_layer = RNN(
    keras_lmu.LMUCell(
        memory_d=10,
        order=8,
        theta=10 / dt,
        hidden_cell=Dense(1024, activation),
        hidden_to_memory=False,
        memory_to_memory=False,
        input_to_hidden=False,
        dropout=dropout,
    ),
    return_sequences=True,
)

# Test serialization: rebuild the layer from its own config
lmu_layer.from_config(
    lmu_layer.get_config(),
)
```

This fails with `ValueError: Unknown layer: LMUCell`.

The quick fix is to tell Keras about the `LMUCell` via the `custom_objects` argument:

```python
# Test serialization
lmu_layer.from_config(
    lmu_layer.get_config(),
    custom_objects={"LMUCell": keras_lmu.LMUCell},  # <-- This is key
)
```

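Although this allows the `LMUCell` to be properly (de)serialized, it requires direct access to every `from_config` (or model-loading) call so that `custom_objects` can be passed in, which may be challenging when additional scripts are built on top of the `RNN`.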
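One partial workaround is to wrap deserialization in `tf.keras.utils.custom_object_scope`, which makes the name-to-class mapping visible to any Keras deserialization run inside the `with` block, not just a single `from_config` call. A minimal sketch, assuming TensorFlow 2.x; `ToyCell` is a hypothetical stand-in for `LMUCell` so the example runs without `keras_lmu`:

```python
import tensorflow as tf


class ToyCell(tf.keras.layers.Layer):
    """Hypothetical minimal RNN cell standing in for keras_lmu.LMUCell."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units  # RNN requires the cell to expose state_size

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.units)
        )
        super().build(input_shape)

    def call(self, inputs, states):
        output = tf.tanh(tf.matmul(inputs, self.kernel) + states[0])
        return output, [output]

    def get_config(self):
        # Include constructor arguments so from_config can rebuild the cell
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = tf.keras.layers.RNN(ToyCell(4), return_sequences=True)
config = layer.get_config()

# Inside the scope, Keras deserialization can resolve the "ToyCell" name
# without an explicit custom_objects argument on each call.
with tf.keras.utils.custom_object_scope({"ToyCell": ToyCell}):
    restored = tf.keras.layers.RNN.from_config(config)
```

This still requires the deserializing code to run inside the scope, so it moves the problem rather than removing it.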

It seems like there is a way to register custom objects with Keras, which may be the proper general solution; I just don't have time to test that out right now!
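The registration route could look something like the sketch below, which uses `tf.keras.utils.register_keras_serializable` to record the class in Keras' global custom-object registry when the class is defined. It assumes TensorFlow 2.x and again uses a hypothetical `ToyCell` in place of `LMUCell`; whether the same decorator on `LMUCell` fully fixes this issue is the untested idea above, not something confirmed here.

```python
import tensorflow as tf


# The decorator registers the class under "<package>>ClassName",
# i.e. "example>ToyCell" here; the package string is an arbitrary namespace.
@tf.keras.utils.register_keras_serializable(package="example")
class ToyCell(tf.keras.layers.Layer):
    """Hypothetical minimal RNN cell standing in for keras_lmu.LMUCell."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.state_size = units  # RNN requires the cell to expose state_size

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.units)
        )
        super().build(input_shape)

    def call(self, inputs, states):
        output = tf.tanh(tf.matmul(inputs, self.kernel) + states[0])
        return output, [output]

    def get_config(self):
        # Include constructor arguments so from_config can rebuild the cell
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = tf.keras.layers.RNN(ToyCell(4), return_sequences=True)

# No custom_objects and no scope: the global registry resolves the cell.
restored = tf.keras.layers.RNN.from_config(layer.get_config())
```

If applied to `LMUCell`, the decorator would live in `keras_lmu` itself, so any script that imports `keras_lmu` could deserialize the wrapping `RNN` without extra arguments.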


Aside: for completeness/reference, using the `LMU` layer (instead of, for example, the `LMUCell` wrapped in an `RNN`) serializes fine:

```python
lmu_layer_builtin = keras_lmu.LMU(
    memory_d=10,
    order=8,
    theta=10 / dt,
    hidden_cell=Dense(1024, activation),
    hidden_to_memory=False,
    memory_to_memory=False,
    input_to_hidden=False,
    dropout=dropout,
    return_sequences=True,
)
lmu_layer_builtin.from_config(
    lmu_layer_builtin.get_config(),
)
```