elixir-nx / axon

Nx-powered Neural Networks
Apache License 2.0

Add mechanism for a custom layer "inheriting" parameters from built-in layers #459

Closed: seanmor5 closed this issue 11 months ago

seanmor5 commented 1 year ago

This is kind of a meta-issue about optimizing for a use case I saw in Keras that is not easy to express in Axon at the moment. Consider this model:

import tensorflow as tf

class FeedBack(tf.keras.Model):
  def __init__(self, units, out_steps):
    super().__init__()
    self.out_steps = out_steps
    self.units = units
    self.lstm_cell = tf.keras.layers.LSTMCell(units)
    # Also wrap the LSTMCell in an RNN to simplify the `warmup` method.
    self.lstm_rnn = tf.keras.layers.RNN(self.lstm_cell, return_state=True)
    self.dense = tf.keras.layers.Dense(num_features)  # num_features is assumed to be defined in the enclosing scope

  def warmup(self, inputs):
    # inputs.shape => (batch, time, features)
    # x.shape => (batch, lstm_units)
    x, *state = self.lstm_rnn(inputs)

    # predictions.shape => (batch, features)
    prediction = self.dense(x)
    return prediction, state

  def call(self, inputs, training=None):
    # Use a Python list to capture dynamically unrolled outputs.
    predictions = []
    # Initialize the LSTM state.
    prediction, state = self.warmup(inputs)

    # Insert the first prediction.
    predictions.append(prediction)

    # Run the rest of the prediction steps.
    for n in range(1, self.out_steps):
      # Use the last prediction as input.
      x = prediction
      # Execute one lstm step.
      x, state = self.lstm_cell(x, states=state,
                                training=training)
      # Convert the lstm output to a prediction.
      prediction = self.dense(x)
      # Add the prediction to the output.
      predictions.append(prediction)

    # predictions.shape => (time, batch, features)
    predictions = tf.stack(predictions)
    # predictions.shape => (batch, time, features)
    predictions = tf.transpose(predictions, [1, 0, 2])
    return predictions

This is the equivalent in Elixir:

defmodule Autoregressive do
  import Nx.Defn

  def feedback_layer(input, units, num_features, out_steps) do
    # shapes for the shared dense head; these duplicate what Axon.dense would declare
    kernel_shape = fn _, _ -> {units, num_features} end
    bias_shape = fn _, _ -> {num_features} end

    # shapes for the LSTM cell params; these duplicate what Axon.lstm would declare
    input_kernel_shape = fn inp, _ -> Axon.Shape.rnn_input_kernel(inp, units, :lstm) end
    hidden_kernel_shape = fn inp, _ -> Axon.Shape.rnn_hidden_kernel(inp, units, :lstm) end
    lstm_bias_shape = fn inp, _ -> Axon.Shape.rnn_bias(inp, units, :lstm) end

    dense_kernel = Axon.param("dense_kernel", kernel_shape, initializer: :glorot_uniform)
    dense_bias = Axon.param("dense_bias", bias_shape, initializer: :zeros)
    lstm_input_kernel =
      Axon.param("input_kernel", {:tuple, List.duplicate(input_kernel_shape, 4)},
        initializer: :glorot_uniform
      )

    lstm_hidden_kernel =
      Axon.param("hidden_kernel", {:tuple, List.duplicate(hidden_kernel_shape, 4)},
        initializer: :glorot_uniform
      )

    lstm_bias =
      Axon.param("bias", {:tuple, List.duplicate(lstm_bias_shape, 4)}, initializer: :zeros)

    # warm-up: run the full sequence through the LSTM, then project the last
    # timestep with the shared dense parameters
    {prediction, state} = Axon.lstm(input, units, name: "feedback")

    prediction =
      Axon.layer(
        fn inp, _state, kernel, bias, _opts ->
          last_seq = inp[[0..-1//1, -1, 0..-1//1]] |> Nx.new_axis(1)
          Axon.Layers.dense(last_seq, kernel, bias)
        end,
        [prediction, Axon.container(state), dense_kernel, dense_bias],
        name: "feedback"
      )

    # feedback loop: a custom layer that reuses the same dense and LSTM params
    Axon.layer(&feedback_impl/8, [
      prediction,
      Axon.container(state),
      dense_kernel,
      dense_bias,
      lstm_input_kernel,
      lstm_hidden_kernel,
      lstm_bias
    ], name: "feedback", output_length: out_steps)
  end

  defnp feedback_impl(
    input,
    state,
    dense_kernel,
    dense_bias,
    lstm_input_kernel,
    lstm_hidden_kernel,
    lstm_bias,
    opts \\ []
  ) do
    opts = keyword!(opts, [:output_length, mode: :train])

    batch_size = Nx.axis_size(input, 0)
    seq_length = opts[:output_length]

    # infer the number of output features from the shared dense bias
    num_features = Nx.axis_size(dense_bias, 0)
    predictions = Nx.broadcast(0.0, {batch_size, seq_length, num_features})

    # unroll seq_length autoregressive steps, feeding each prediction back in
    acc =
      while {input, state, predictions, dense_kernel, dense_bias,
             lstm_input_kernel, lstm_hidden_kernel, lstm_bias},
            index <- Nx.iota({seq_length}) do
        # one LSTM step on the previous prediction
        {seq, state} =
          Axon.Layers.lstm_cell(input, state, lstm_input_kernel, lstm_hidden_kernel, lstm_bias)

        # project the step output and store it at position `index`
        last_seq = seq[[0..-1//1, -1, 0..-1//1]] |> Nx.new_axis(1)
        prediction = Axon.Layers.dense(last_seq, dense_kernel, dense_bias)

        {prediction, state, Nx.put_slice(predictions, [0, index, 0], prediction),
         dense_kernel, dense_bias, lstm_input_kernel, lstm_hidden_kernel, lstm_bias}
      end

    # return the accumulated predictions tensor
    elem(acc, 2)
  end
end

The Axon implementation is close to being as succinct; however, we have to declare parameters by hand for use inside the implementation, and those declarations duplicate the params already generated by Axon.lstm and Axon.dense. One option is to introduce an Axon.scan primitive that would remove the need for the custom layer entirely, but I can still see it being useful to add some way to say "I want dense parameters" or similar; a rough sketch of the scan idea follows.
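To make that concrete, here is a purely illustrative sketch of what an Axon.scan primitive might look like. Nothing below exists in Axon: Axon.scan, step_lstm_cell, and step_dense are invented names and the callback shape is an assumption. The point is only that the library would own the unrolling loop and the parameter bookkeeping, so the hand-written feedback_impl/8 layer above would not be needed:

# Hypothetical API sketch only; none of these functions exist in Axon today.
{prediction, state} = Axon.lstm(input, units, name: "feedback")

Axon.scan({prediction, state}, fn last_prediction, carry ->
  # one autoregressive step: shared LSTM cell, then the shared dense head
  {x, carry} = step_lstm_cell(last_prediction, carry)
  {step_dense(x, num_features), carry}
end, length: out_steps)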

Working with RNNs in general is not great in Axon right now, so this is another issue to consider when optimizing for RNN-based implementations.

seanmor5 commented 11 months ago

This is possible with blocks now.
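For reference, a rough sketch of what that looks like (layer sizes here are made up, and the exact Axon.block signature should be checked against the current docs): Axon.block wraps a model-building function, and every application of the returned block reuses the same parameters, which is the sharing this issue was asking for.

# Sketch only: sizes are arbitrary; see the Axon.block docs for the exact API.
input = Axon.input("features", shape: {nil, 4})

# the dense kernel and bias are created once and shared by every application
shared_dense = Axon.block(&Axon.dense(&1, 4))

model =
  input
  |> shared_dense.()
  |> Axon.relu()
  |> shared_dense.()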