nengo / keras-lmu

Keras implementation of Legendre Memory Units
https://www.nengo.ai/keras-lmu/

Register LMUCell with Keras #34

Open bmorcos opened 3 years ago

bmorcos commented 3 years ago

If the LMUCell is wrapped in another layer (e.g. RNN) then it cannot be serialized since LMUCell is a custom object unknown to Keras. For example:

import keras_lmu
from tensorflow.keras.layers import RNN, Dense

# Build an LMU layer
dt = 1e-3
activation = "tanh"
dropout = 0.2

lmu_layer = RNN(
    keras_lmu.LMUCell(
        memory_d=10,
        order=8,
        theta=10 / dt,
        hidden_cell=Dense(1024, activation),
        hidden_to_memory=False,
        memory_to_memory=False,
        input_to_hidden=False,
        dropout=dropout,
    ),
    return_sequences=True,
)

# Test serialization
lmu_layer.from_config(
    lmu_layer.get_config(),
)

This fails with ValueError: Unknown layer: LMUCell.

The quick fix is to tell Keras about the LMUCell via custom_objects:

# Test serialization
lmu_layer.from_config(
    lmu_layer.get_config(),
    custom_objects={"LMUCell": keras_lmu.LMUCell},  # <-- This is key
)

Although this allows the LMUCell to be properly (de)serialized, it requires direct access to the from_config call, which may be hard to arrange if other scripts or tools sit on top of the RNN and do the deserialization themselves.

It seems like there is a way to register custom objects with Keras, and that may be the proper general solution. I just don't have time to test that out right now!


Aside: for completeness/reference, using the LMU layer (instead of the LMUCell wrapped in an RNN, for example) serializes fine:

lmu_layer_builtin = keras_lmu.LMU(
    memory_d=10,
    order=8,
    theta=10 / dt,
    hidden_cell=Dense(1024, activation),
    hidden_to_memory=False,
    memory_to_memory=False,
    input_to_hidden=False,
    dropout=dropout,
    return_sequences=True,
)
lmu_layer_builtin.from_config(
    lmu_layer_builtin.get_config(),
)