Closed: rana closed this issue 4 years ago
If you look closely at Example B, both forward and backward RNNs are invoked via broadcasting (the backward broadcasting is hidden in `flip`). The underlying convention for RNNs in Flux is that the data should come in as a vector of vectors, where each inner vector is of size `in` and the outer vector represents the sequence. It follows that the `dense` layer and `vcat` also have to be applied element-wise (via broadcasting), and the data needs to be transformed accordingly.
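For concreteness, here is a minimal sketch of that convention; the `flip` helper and all sizes are illustrative assumptions, not taken from this thread's code:

```julia
using Flux

flip(f, xs) = reverse(f.(reverse(xs)))   # broadcast f over the reversed sequence, then restore order

insize, hidden, len = 5, 3, 10                 # assumed dimensions
xs = [rand(Float32, insize) for _ in 1:len]    # outer vector = sequence, inner vectors = timesteps

forward  = RNN(insize, hidden)
backward = RNN(insize, hidden)
dense    = Dense(2hidden, 1)

fwd = forward.(xs)             # broadcast the forward RNN over the timesteps
bak = flip(backward, xs)       # the backward broadcasting is hidden in flip
ys  = dense.(vcat.(fwd, bak))  # vcat and dense are applied element-wise too
```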
@tanhevg Thanks for your thoughtful comments. You helped me move this forward.
Modifying `(m::BRNN)(xs)` to
```julia
function (m::BRNN)(xs)
    @info "xs", xs
    fwd = m.forward(xs)
    rev = reverse(xs)
    @info "reverse(xs)", rev
    bak = m.backward(rev)
    @info "m.backward(reverse(xs))", bak
    revBak = reverse(bak)
    @info "reverse(m.backward(reverse(xs)))", revBak
    fwdBak = vcat(fwd, revBak)
    @info "vcat(m.forward(xs), reverse(m.backward(reverse(xs))))", fwdBak
    ŷ = m.dense(fwdBak)
    @info "m.dense(vcat(m.forward(xs), reverse(m.backward(reverse(xs)))))", ŷ
    ŷ
end
```
outputs:

```
[ Info: ("xs", [0.4024241926911283, 0.0031736700674680485, 0.5874856115265599, 0.6248590811960677, 0.6107094805087905])
[ Info: ("reverse(xs)", [0.6107094805087905, 0.6248590811960677, 0.5874856115265599, 0.0031736700674680485, 0.4024241926911283])
[ Info: ("m.backward(reverse(xs))", [0.0, 0.0, 0.0])
[ Info: ("reverse(m.backward(reverse(xs)))", [0.0, 0.0, 0.0])
[ Info: ("vcat(m.forward(xs), reverse(m.backward(reverse(xs))))", [0.920278546440162, 0.8759832558156826, 1.3535489921199844, 0.0, 0.0, 0.0])
[ Info: ("m.dense(vcat(m.forward(xs), reverse(m.backward(reverse(xs)))))", Float32[0.0])
```
making the data shapes readable. The data is in the correct shape. The modified code, however, produces a different error:

```
ERROR: LoadError: Compiling Tuple{BRNN{Flux.Recur{Flux.RNNCell{typeof(relu),Array{Float32,2},Array{Float32,1}}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}}},Array{Float64,1}}: try/catch is not supported.
```
From what I can tell, this indicates a defect in Zygote. Maybe the multi-input layer, parallel layer, and "Add Parallel and Bi layer" pull requests would address this?
In the snippet, both `forward` and `backward` are invoked on a single data point, which is a vector of 5 elements. This does not really make sense for RNNs; their recurrent nature only comes into play when they are invoked multiple times between `reset!`s. So either change the dimension of your RNNs to 1 and broadcast them over the existing data, or change the data to be a vector of vectors, as I outlined above. Either way, broadcasting is needed for the inner state of the RNN to accumulate, or some other way to repeatedly invoke it, like a `for` loop (see the sketch below). Broadcasting `BRNN` itself will not do, because `forward` and `backward` must run in opposite directions.
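For example (hypothetical `rnn` and sequence; `Flux.reset!` clears the hidden state between sequences), these two forms accumulate state the same way:

```julia
using Flux

rnn = RNN(5, 3)                          # assumed sizes
xs  = [rand(Float32, 5) for _ in 1:10]   # assumed sequence

Flux.reset!(rnn)
out1 = rnn.(xs)                # broadcasting: one call per timestep

Flux.reset!(rnn)
out2 = [rnn(x) for x in xs]    # explicit loop, same effect
```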
I am not sure why there are zeros in the output; it could be because of the `relu` activation. You could try it with `identity` and see if it works.
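One quick check, using the sizes implied by the logged shapes above (the swap is just a guess at the cause, not a fix):

```julia
using Flux

# If the zeros disappear, the relu was clamping negative pre-activations to 0.
backward = RNN(5, 3, identity)   # instead of RNN(5, 3, relu)
```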
Re this error:

```
ERROR: LoadError: Mutating arrays is not supported
```

Indeed, Zygote does not currently support mutating arrays; there is some work in progress that has not yet been merged. One possible workaround is to roll your own `flip` that does not involve `reverse` (where the mutation takes place):

```julia
rev(x) = x[end:-1:1]
flip(f, x) = rev(f.(rev(x)))
```
Disclaimer: performance not tested. I imagine there would be a penalty. Maybe try using views.
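A possible view-based variant (untested; whether Zygote differentiates through `SubArray`s here is an open question):

```julia
rev(x) = @view x[end:-1:1]    # reversed view of the input, no copy
flip(f, x) = rev(f.(rev(x)))  # f.(...) still allocates its own output
```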
Otherwise, RNNs seem to work OK with Zygote; just tested with speech-blstm.
This seems fixed on master, so it should probably be closed.
Hi,

How would you build a bidirectional RNN with Flux? Drawing from a few examples, I've written an implementation that errors with

```
ERROR: LoadError: Mutating arrays is not supported
```

Flux does not like reversing the input data. Other examples (A, B) use `flip`, but that does not produce a compatible output shape for this example.

How would you approach it?
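Putting the pieces from this thread together, here is a minimal sketch of one possible `BRNN`; the field names follow the snippets above, while the constructor call and dimensions are assumptions:

```julia
using Flux

# Mutation-free flip, as suggested above
rev(x) = x[end:-1:1]
flip(f, x) = rev(f.(rev(x)))

struct BRNN{R,D}
    forward::R
    backward::R
    dense::D
end

Flux.@functor BRNN   # make the wrapper's parameters visible to Flux (>= 0.10)

function (m::BRNN)(xs)
    fwd = m.forward.(xs)          # forward pass, broadcast over timesteps
    bak = flip(m.backward, xs)    # backward pass via flip, order restored
    m.dense.(vcat.(fwd, bak))     # per-timestep concatenation and projection
end

model = BRNN(RNN(5, 3, relu), RNN(5, 3, relu), Dense(6, 1, relu))
xs = [rand(Float32, 5) for _ in 1:10]
ŷ  = model(xs)   # call Flux.reset! on the RNNs between independent sequences
```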