This is kind of a meta-issue for optimizing for a use case I saw in Keras that is not easy to express in Axon at the moment. Consider this model:
```python
import tensorflow as tf


class FeedBack(tf.keras.Model):
    def __init__(self, units, out_steps):
        super().__init__()
        self.out_steps = out_steps
        self.units = units
        self.lstm_cell = tf.keras.layers.LSTMCell(units)
        # Also wrap the LSTMCell in an RNN to simplify the `warmup` method.
        self.lstm_rnn = tf.keras.layers.RNN(self.lstm_cell, return_state=True)
        # `num_features` is defined earlier in the Keras tutorial this
        # model comes from.
        self.dense = tf.keras.layers.Dense(num_features)

    def warmup(self, inputs):
        # inputs.shape => (batch, time, features)
        # x.shape => (batch, lstm_units)
        x, *state = self.lstm_rnn(inputs)
        # prediction.shape => (batch, features)
        prediction = self.dense(x)
        return prediction, state

    def call(self, inputs, training=None):
        # Use a Python list to capture the dynamically unrolled outputs.
        predictions = []
        # Initialize the LSTM state with a warmup pass over the inputs.
        prediction, state = self.warmup(inputs)
        # Insert the first prediction.
        predictions.append(prediction)
        # Run the rest of the prediction steps.
        for n in range(1, self.out_steps):
            # Use the last prediction as input.
            x = prediction
            # Execute one lstm step.
            x, state = self.lstm_cell(x, states=state, training=training)
            # Convert the lstm output to a prediction.
            prediction = self.dense(x)
            # Add the prediction to the output.
            predictions.append(prediction)
        # predictions.shape => (time, batch, features)
        predictions = tf.stack(predictions)
        # predictions.shape => (batch, time, features)
        predictions = tf.transpose(predictions, [1, 0, 2])
        return predictions
```
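
This is roughly the equivalent in Elixir (a sketch rather than a faithful copy: the hand-rolled cell, parameter names, and shapes below are illustrative, and the `Axon.param`/`Axon.layer` usage assumes a recent Axon API):

```elixir
defmodule FeedBack do
  # One hand-written LSTM cell step over Nx tensors. Schematically,
  # `wx` is {features, 4 * units}, `wh` is {units, 4 * units}, and
  # `b` is {4 * units}, with the four gates packed along axis 1.
  defp cell(x, {c, h}, wx, wh, b, units) do
    gates = x |> Nx.dot(wx) |> Nx.add(Nx.dot(h, wh)) |> Nx.add(b)
    i = Nx.sigmoid(Nx.slice_along_axis(gates, 0, units, axis: 1))
    f = Nx.sigmoid(Nx.slice_along_axis(gates, units, units, axis: 1))
    g = Nx.tanh(Nx.slice_along_axis(gates, 2 * units, units, axis: 1))
    o = Nx.sigmoid(Nx.slice_along_axis(gates, 3 * units, units, axis: 1))
    c = Nx.add(Nx.multiply(f, c), Nx.multiply(i, g))
    h = Nx.multiply(o, Nx.tanh(c))
    {c, h}
  end

  # Unrolled decoding loop: feed each prediction back in as the next
  # input. The step count is static, so Enum.reduce unrolls at trace
  # time (assumes out_steps > 1, matching range(1, out_steps) above).
  defp decode(prediction, state, wx, wh, b, wd, bd, units, out_steps) do
    {preds, _state} =
      Enum.reduce(1..(out_steps - 1), {[prediction], state}, fn _n, {preds, state} ->
        {c, h} = cell(hd(preds), state, wx, wh, b, units)
        pred = h |> Nx.dot(wd) |> Nx.add(bd)
        {[pred | preds], {c, h}}
      end)

    # preds is accumulated newest-first => reverse, then stack to
    # {batch, time, features}.
    preds |> Enum.reverse() |> Nx.stack(axis: 1)
  end

  def model(units, out_steps, num_features) do
    input = Axon.input("input", shape: {nil, nil, num_features})

    # Warmup: run the whole input sequence through an LSTM and project
    # the final hidden state to the first prediction (state layout here
    # is schematic; Axon's RNN state shapes may differ).
    {_seq, {c, h}} = Axon.lstm(input, units)
    prediction = Axon.dense(h, num_features)

    # The pain point: to reuse the cell and the projection inside the
    # loop, we have to declare their parameters again, duplicating what
    # Axon.lstm and Axon.dense created above.
    wx = Axon.param("wx", {num_features, 4 * units})
    wh = Axon.param("wh", {units, 4 * units})
    b = Axon.param("b", {4 * units})
    wd = Axon.param("wd", {units, num_features})
    bd = Axon.param("bd", {num_features})

    Axon.layer(
      fn pred, c, h, wx, wh, b, wd, bd, _opts ->
        decode(pred, {c, h}, wx, wh, b, wd, bd, units, out_steps)
      end,
      [prediction, c, h, wx, wh, b, wd, bd],
      name: "feedback"
    )
  end
end
```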
The Axon implementation is close to being as succinct; however, we have to declare parameters for use inside the implementation, and we duplicate the params generated from `Axon.lstm` and `Axon.dense`. One option is to introduce an `Axon.scan` primitive that would remove the need for the custom layer entirely, but I can still see it being useful to add some way to say "I want dense parameters" or similar (a rough sketch of how a scan-style API might read follows).
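
Purely to illustrate the shape such a primitive could take (nothing below exists in Axon today; the names and signature are invented for this sketch):

```elixir
# Hypothetical -- Axon.scan, Axon.lstm_cell, and warmup are invented
# names. The idea: Axon threads a carry (the RNN state plus the last
# prediction) through a fixed number of applications of a subgraph,
# reusing that subgraph's parameters instead of making us redeclare them.
# (`warmup` stands in for the lstm + dense warmup pass above.)
{prediction, state} = warmup(input, units, num_features)

Axon.scan({prediction, state}, steps: out_steps - 1, fn {pred, state} ->
  {x, state} = Axon.lstm_cell(pred, state, units)
  {Axon.dense(x, num_features), state}
end)
```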
Working with RNNs in general is not great in Axon right now, so this is another issue to consider when optimizing for RNN-based implementations.