Feature request: Bidirectional for RNN layer. #687

Open: NeroBlackstone opened this issue 4 weeks ago

NeroBlackstone commented 4 weeks ago
Bidirectional(Recurrence(LSTMCell(in_dims => hidden_dims)))
# or
Recurrence(LSTMCell(in_dims => hidden_dims), bidirectional=true)

Which one would be better, and easier to implement, in Lux.jl? I'm willing to try implementing it and opening a PR, but I may need some guidance.

avik-pal commented 4 weeks ago

I think Bidirectional RNNs can be implemented as:

   Parallel(
       <fuse_op>, # User provided
       Recurrence(<rnn cell>; return_sequence=true),
       Chain(
           ReverseSequence(), # Needs to be implemented
           Recurrence(<rnn cell>; return_sequence=true)
       )
   )

It would be nice to have a Bidirectional(<fuse_op>, <rnn cell>) constructor which does the above.

I will have to check what other frameworks do, but I am not sure whether the parameters of the two directions are shared between the layers.

> I'm willing to try to implement it and open a PR, but may need some guidance.

For sure!

NeroBlackstone commented 3 weeks ago

> Parallel(
>     <fuse_op>, # User provided
>     Recurrence(<rnn cell>; return_sequence=true),
>     Chain(
>         ReverseSequence(), # Needs to be implemented
>         Recurrence(<rnn cell>; return_sequence=true)
>     )
> )

Hi! I would like to ask what role <fuse_op> plays here. If it is for concatenating the hidden states, shouldn't it be implemented by Lux rather than provided by the user?

NeroBlackstone commented 3 weeks ago

So I think ReverseSequence() would be a helper layer that reverses the input array along a specified dimension?

So it should behave like this:

x = [1 2; 3 4]

# ps, st for each model come from Lux.setup(rng, model)
model = ReverseSequence(1)
y, st_new = model(x, ps, st)
# y == [3 4; 1 2]   (rows reversed)

model = ReverseSequence(2)
y, st_new = model(x, ps, st)
# y == [2 1; 4 3]   (columns reversed)

Is my understanding correct? Are there any suggestions for a high-performance implementation?
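For plain arrays, Base Julia's `reverse(x; dims=d)` already does this (as a copy), so a layer could simply wrap it. A minimal sketch, assuming the sequence is stored along one dimension of a dense array (for a (features, time, batch) layout that would be `d = 2`); `reverse_along` and `reverse_view_along` are hypothetical helper names, not Lux API. The view-based variant avoids allocating a copy:

```julia
# Copying reversal along dimension `d` (allocates a new array).
reverse_along(x::AbstractArray, d::Integer) = reverse(x; dims=d)

# Lazy reversal along dimension `d`: a view with a descending index range
# on `d` and `Colon()` on every other dimension, so no data is copied.
function reverse_view_along(x::AbstractArray, d::Integer)
    idx = ntuple(i -> i == d ? (size(x, d):-1:1) : Colon(), ndims(x))
    return view(x, idx...)
end

x = [1 2; 3 4]
reverse_along(x, 1)        # [3 4; 1 2]
reverse_along(x, 2)        # [2 1; 4 3]
reverse_view_along(x, 2)   # same values as reverse_along(x, 2), without copying
```

Whether the lazy view is actually faster depends on what consumes it: it skips the copy, but downstream code then reads with a negative stride, and the matrix products inside an RNN cell may prefer a contiguous array.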

NeroBlackstone commented 3 weeks ago

I think a ReverseSequence should also be applied to the output of the backward RNN, so that the forward RNN's output at the first time step is paired with the backward RNN's output at its last time step:

   Parallel(
       <fuse_op>, # User provided
       Recurrence(<rnn cell>; return_sequence=true),
       Chain(
           ReverseSequence(), # Needs to be implemented
           Recurrence(<rnn cell>; return_sequence=true),
           ReverseSequence()  # Re-align the backward outputs with the forward ones
       )
   )

And what <fuse_op> should we use? With return_sequence=true, the RNN output is a sequence of seq_len arrays, each of size hidden_state_len × batch_size, so <fuse_op> should actually be a broadcasted vcat, but Parallel's connection does not support broadcasting.
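To see why both reversals matter and why the fuse has to act per time step, here is a framework-free sketch in plain Julia. It assumes a hypothetical `TinyRNNCell` (a plain tanh RNN standing in for `LSTMCell`/`GRUCell`) and a hypothetical `bidirectional` function that mirrors the `Parallel`/`Chain` structure above; it is not Lux code. The two directions get independent parameters here, which matches PyTorch and Keras, where each direction has its own weights.

```julia
using Random

# Hypothetical minimal Elman RNN cell: h_t = tanh.(W * x_t .+ U * h_{t-1} .+ b).
struct TinyRNNCell
    W::Matrix{Float64}   # hidden × in
    U::Matrix{Float64}   # hidden × hidden
    b::Vector{Float64}   # hidden
end

# Run the cell over a sequence given as a Vector of (in × batch) matrices and
# return every hidden state (the analogue of return_sequence=true).
function run_sequence(cell::TinyRNNCell, xs::Vector{<:AbstractMatrix})
    h = zeros(size(cell.U, 1), size(first(xs), 2))
    hs = Matrix{Float64}[]
    for x in xs
        h = tanh.(cell.W * x .+ cell.U * h .+ cell.b)
        push!(hs, h)
    end
    return hs
end

# Forward branch reads x_1..x_T; backward branch reads x_T..x_1, and its
# outputs are reversed again so that index t refers to input step t in both.
function bidirectional(fuse, fwd::TinyRNNCell, bwd::TinyRNNCell, xs)
    yf = run_sequence(fwd, xs)
    yb = reverse(run_sequence(bwd, reverse(xs)))
    return fuse(yf, yb)
end

rng = MersenneTwister(0)
in_dims, hidden, batch, T = 3, 2, 5, 4
make_cell() = TinyRNNCell(randn(rng, hidden, in_dims), randn(rng, hidden, hidden), randn(rng, hidden))
fwd, bwd = make_cell(), make_cell()   # separate parameters per direction
xs = [randn(rng, in_dims, batch) for _ in 1:T]

ys = bidirectional(Base.BroadcastFunction(vcat), fwd, bwd, xs)
length(ys), size(ys[1])   # (4, (4, 5)): T steps of (2 * hidden) × batch
```

Without the final `reverse`, `ys[1]` would pair the forward state after x_1 with the backward state after reading only x_T, which is exactly the misalignment described above.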

avik-pal commented 3 weeks ago

TensorFlow lets you choose the fuse op, which is why I want it to be the user's choice. For broadcasting, just wrap the user-provided op (defaulting to vcat) with BroadcastFunction(<op>). BroadcastFunction is in Base.

NeroBlackstone commented 2 weeks ago

Keras provides some built-in merge modes ({"sum", "mul", "concat", "ave", None}); should we provide them too?

Or should we provide concat as the default implementation, plus an option for a custom <fuse_op>?
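The Keras merge modes map fairly directly onto broadcast functions over the two output sequences. A minimal sketch, assuming each branch returns a `Vector` of hidden × batch matrices (one per time step); `merge_op` is a hypothetical helper, not Lux API. Note that `mul` must be elementwise per time step, so it wraps `.*` rather than broadcasting `*` (which would attempt a matrix product).

```julia
# Hypothetical helper: translate a Keras-style merge mode into a connection
# function of the two output sequences (each a Vector of hidden × batch matrices).
function merge_op(mode::Union{Symbol,Nothing})
    mode === :concat && return Base.BroadcastFunction(vcat)   # (2 * hidden) × batch per step
    mode === :sum    && return Base.BroadcastFunction(+)
    mode === :mul    && return Base.BroadcastFunction((a, b) -> a .* b)
    mode === :ave    && return Base.BroadcastFunction((a, b) -> (a .+ b) ./ 2)
    mode === nothing && return (a, b) -> (a, b)                # keep both sequences unmerged
    throw(ArgumentError("unknown merge mode: $mode"))
end

yf = [ones(2, 3), 2 .* ones(2, 3)]        # forward outputs, 2 time steps
yb = [3 .* ones(2, 3), 4 .* ones(2, 3)]   # backward outputs, already re-aligned
merge_op(:sum)(yf, yb)[2]                 # every entry is 6.0
```

Keras uses `"concat"` as its default, so offering `:concat` as the default with a custom callable as an escape hatch would cover both options raised here.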

NeroBlackstone commented 2 weeks ago

I can't find Base.BroadcastFunction in the Julia documentation. Could you please give me a detailed code example of how to use a broadcast function in a Parallel layer? I still think this is impossible.

NeroBlackstone commented 2 weeks ago

OK, I figured it out:

using Lux

# Broadcast vcat over the per-time-step outputs of the two branches
bvcat(a, b) = vcat.(a, b)
model = Parallel(bvcat,
    Recurrence(GRUCell(3 => 2); return_sequence=true),
    Recurrence(GRUCell(3 => 2); return_sequence=true))

NeroBlackstone commented 2 weeks ago

I have roughly figured out how to build a bidirectional RNN using Parallel layers. I will start implementing it once #698 is merged.

avik-pal commented 2 weeks ago
help?> Broadcast.BroadcastFunction
  BroadcastFunction{F} <: Function

  Represents the "dotted" version of an operator, which broadcasts the operator over its arguments, so BroadcastFunction(op) is functionally equivalent to (x...) -> (op).(x...).

  Can be created by just passing an operator preceded by a dot to a higher-order function.


  julia> a = [[1 3; 2 4], [5 7; 6 8]];

  julia> b = [[9 11; 10 12], [13 15; 14 16]];

  julia> map(.*, a, b)
  2-element Vector{Matrix{Int64}}:
   [9 33; 20 48]
   [65 105; 84 128]

  julia> Base.BroadcastFunction(+)(a, b) == a .+ b
  true

  │ Julia 1.6
  │  BroadcastFunction and the standalone .op syntax are available as of Julia 1.6.

This is basically what you wrote; just use Broadcast.BroadcastFunction(vcat).
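A quick check in plain Julia that the Base type and the hand-written `bvcat` from earlier in the thread do the same thing on per-time-step outputs (a `Vector` of matrices). The arrays reuse the ones from the docstring example; nothing here depends on Lux.

```julia
# Two branch outputs, each a Vector of (hidden × batch) matrices, one per time step.
a = [[1 3; 2 4], [5 7; 6 8]]
b = [[9 11; 10 12], [13 15; 14 16]]

bvcat(x, y) = vcat.(x, y)                # the hand-written version
fuse = Base.BroadcastFunction(vcat)      # the Base version (Julia >= 1.6)

fuse(a, b) == bvcat(a, b)                # true
size(fuse(a, b)[1])                      # (4, 2)
```

Since `Base.BroadcastFunction(vcat)` is an ordinary callable, it should be usable directly as the connection argument, e.g. `Parallel(Base.BroadcastFunction(vcat), ...)`, without defining `bvcat`.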