google / objax


RNN redesign #211

Open aterzis-google opened 3 years ago

AlexeyKurakin commented 3 years ago

Requirements

While choosing a design, I would recommend that the following conditions be satisfied:

  1. A user can easily customize an RNN cell or add new types of RNN cells, without needing to think about how to write efficient RNN code (see the sketch after this list).
  2. The design can represent relatively complex RNN architectures (let's say a deep LSTM).
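To make requirement 1 concrete, here is a minimal sketch of what a user-defined cell could look like if the cell only describes a single time step and the RNN wrapper (not shown) runs the loop. `VanillaRNNCell` and its signature are illustrative, not a proposed API:

```python
import objax

# A user-defined cell under requirement 1: only the per-step math, no loop.
# Shapes and the (output, new_state) return convention are assumptions.
class VanillaRNNCell(objax.Module):
    def __init__(self, nin: int, nstate: int):
        self.w_x = objax.nn.Linear(nin, nstate, use_bias=False)
        self.w_h = objax.nn.Linear(nstate, nstate)

    def __call__(self, x, state):
        # x: (nin,) input at one time step; state: (nstate,) carried state.
        new_state = objax.functional.tanh(self.w_x(x) + self.w_h(state))
        return new_state, new_state  # output, new state
```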

Discussion of proposed designs

Three possible designs of RNN are proposed here:

FactorizedRNN pushes most of the work onto the RNN cell and requires the cell to essentially implement the for-loop or scan itself, so I would say it does not really satisfy requirement 1.
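For illustration only, this is roughly the burden the FactorizedRNN design would place on every cell author: writing the time loop themselves (e.g. via `jax.lax.scan`), which is exactly the efficiency-sensitive code requirement 1 tries to hide. The function name and shapes here are assumptions:

```python
import jax

# What a FactorizedRNN-style cell would have to contain internally:
# cell_step: (state, x_t) -> (new_state, output_t); xs: (time, features).
def run_cell_over_time(cell_step, initial_state, xs):
    final_state, outputs = jax.lax.scan(cell_step, initial_state, xs)
    return final_state, outputs
```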

I would say both `RNN` and `VectorizedRNN` satisfy these requirements (the cell implementation stays simple, and more complex RNNs can be built). The remaining question is whether we want to adopt the variant that accepts a batch or the one that does not.

I read @ebrevdo's comments in the other PRs, and I agree that the RNN cell should operate on a single example (instead of a batch). I'm not entirely sure yet whether the RNN as a whole should operate on batches or single examples. I had the impression that our layers always operate on a batch of examples, in which case the RNN should also operate on a batch. But I might be wrong.
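These two choices are compatible: a rough sketch, assuming the cell operates on a single example as suggested, is that `jax.vmap` lifts the per-example step to a batch, so a batched RNN does not force the cell itself to know about batches. (In Objax, `objax.Vectorize` is the module-aware equivalent; plain `jax.vmap` is used here for brevity, and the function name is an assumption.)

```python
import jax

# Lift a single-example cell to a batch without changing the cell.
def batched_cell_step(cell, x_batch, state_batch):
    # x_batch: (batch, nin), state_batch: (batch, nstate).
    return jax.vmap(cell)(x_batch, state_batch)
```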

Other details

get_initial_state

After reading some examples of RNN code, I think get_initial_state actually is a good idea. However, I would use the name create_initial_state because it is more descriptive.

One way or another, the initial state has to be created somewhere, and putting this code into the definition of the cell is probably the clearest and least error-prone approach.
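A sketch of that idea, using the create_initial_state name suggested above (the signature and zero initialization are assumptions):

```python
import jax.numpy as jn
import objax

class LSTMCellSketch(objax.Module):
    def __init__(self, nin: int, nstate: int):
        self.nstate = nstate
        self.linear = objax.nn.Linear(nin + nstate, 4 * nstate)

    def create_initial_state(self):
        # The cell knows its own state shapes, so state creation lives next
        # to the state's definition: hidden and cell state for one example.
        return jn.zeros(self.nstate), jn.zeros(self.nstate)
```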

For comparison, TensorFlow expects the RNN cell to have a get_initial_state method, which the RNN implementation uses to create the initial state. For the PyTorch approach, check out the RNN example at https://d2l.ai/chapter_recurrent-neural-networks/rnn-concise.html (see the PyTorch code in section 8.6.1), which to me essentially pushes creation of the initial state into user code.

output_layer

I would probably agree with David's and Eugene's comments in the other pull request that we should drop output_layer; the user can decide whether to just use the RNN state or do any additional processing on top of it.
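A sketch of the composition argument: with no built-in output_layer, a user who wants a projection simply stacks one after the RNN. Here `RNN` is the hypothetical wrapper discussed in this issue and `VanillaRNNCell` is the illustrative cell from the earlier sketch, so both names are assumptions:

```python
import objax

# User-side composition instead of a built-in output_layer.
model = objax.nn.Sequential([
    RNN(VanillaRNNCell(nin=32, nstate=128)),  # hypothetical, per this issue
    objax.nn.Linear(128, 10),                 # user-chosen output processing
])
```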