FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Zygote errors building bidirectional RNN #962

Closed. rana closed this issue 4 years ago.

rana commented 4 years ago

Hi,

How would you build a bidirectional RNN with Flux? Drawing from a few examples:

  1. Bidirectional LSTM example A
  2. Bidirectional LSTM example B
  3. Knet bidirectional RNN source

I've written

using Pkg; for p in ["Flux"] Pkg.add(p) end
using Flux

# Bidirectional RNN
struct BRNN{L,D}
  forward  :: L
  backward :: L
  dense    :: D
end

Flux.@functor BRNN

function BRNN(in::Integer, hidden::Integer, out::Integer, σ = relu)
  return BRNN(
    RNN(in, hidden, σ), # forward
    RNN(in, hidden, σ), # backward
    Dense(2hidden, out, σ)
  )
end

function (m::BRNN)(xs)
  m.dense(vcat(m.forward(xs), reverse(m.backward(reverse(xs)))))
end

inSize = 5
hiddenSize = 3
outSize = 1

trn = [(rand(inSize), rand(outSize)) for _ in 1:8]
@info "trn", summary(trn)

m = BRNN(inSize, hiddenSize, outSize)
loss(x, y) = Flux.mse(m(x), y)
ps = Flux.params(m)
opt = ADAM()

Flux.train!(loss, ps, trn, opt)

errors with

ERROR: LoadError: Mutating arrays is not supported

Flux does not like reversing the input data.

reverse(m.backward(reverse(xs)))

Other examples (A, B) use

Flux.flip(m.backward, xs)

but that does not produce a compatible output shape in this example.

How would you approach it?

tanhevg commented 4 years ago

If you look closely at Example B, both the forward and backward RNNs are invoked via broadcasting (the backward broadcasting is hidden inside flip). The underlying convention for RNNs in Flux is that the data should come in as a vector of vectors, where each inner vector is of size in and the outer vector represents the sequence. It follows that the dense layer and vcat also have to be applied element-wise (via broadcasting), and the data needs to be transformed accordingly.
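
To make that concrete, here is a minimal sketch of that convention (the sizes, and the vcat of hs with itself, are made up purely for illustration; it uses the same RNN constructor as the snippet above):

using Flux

inSize, hiddenSize = 5, 3
xs = [rand(Float32, inSize) for _ in 1:8]   # a sequence of 8 timesteps

rnn = RNN(inSize, hiddenSize)
hs = rnn.(xs)                  # one call per timestep; the hidden state carries over,
                               # and `hs` is again a vector of hiddenSize-element vectors

dense = Dense(2hiddenSize, 1)
ys = dense.(vcat.(hs, hs))     # vcat and the dense layer are applied element-wise too
                               # (vcat-ing hs with itself is just a stand-in for the
                               # forward/backward outputs of a real bidirectional RNN)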

rana commented 4 years ago

@tanhevg Thanks for your thoughtful comments. You helped me move this forward.

Modifying (m::BRNN)(xs) to

function (m::BRNN)(xs)
  @info "xs", xs
  fwd = m.forward(xs)
  rev = reverse(xs)
  @info "reverse(xs)", rev
  bak = m.backward(rev)
  @info "m.backward(reverse(xs))", bak
  revBak = reverse(bak)
  @info "reverse(m.backward(reverse(xs)))", revBak
  fwdBak = vcat(fwd, revBak)
  @info "vcat(m.forward(xs), reverse(m.backward(reverse(xs))))", fwdBak
  ŷ = m.dense(fwdBak)
  @info "m.dense(vcat(m.forward(xs), reverse(m.backward(reverse(xs)))))", ŷ
  ŷ
end

outputs

[ Info: ("xs", [0.4024241926911283, 0.0031736700674680485, 0.5874856115265599, 0.6248590811960677, 0.6107094805087905])
[ Info: ("reverse(xs)", [0.6107094805087905, 0.6248590811960677, 0.5874856115265599, 0.0031736700674680485, 0.4024241926911283])
[ Info: ("m.backward(reverse(xs))", [0.0, 0.0, 0.0])
[ Info: ("reverse(m.backward(reverse(xs)))", [0.0, 0.0, 0.0])
[ Info: ("vcat(m.forward(xs), reverse(m.backward(reverse(xs))))", [0.920278546440162, 0.8759832558156826, 1.3535489921199844, 0.0, 0.0, 0.0])
[ Info: ("m.dense(vcat(m.forward(xs), reverse(m.backward(reverse(xs)))))", Float32[0.0])

making the data shapes readable.

The data is in the correct shape.

The change in code produces a different error

ERROR: LoadError: Compiling Tuple{BRNN{Flux.Recur{Flux.RNNCell{typeof(relu),Array{Float32,2},Array{Float32,1}}},Dense{typeof(relu),Array{Float32,2},Array{Float32,1}}},Array{Float64,1}}: try/catch is not supported.

From what I can tell, this indicates a defect in Zygote.

rana commented 4 years ago

Maybe the multi input layer, parallel layer, and Add Parallel and Bi layer pull requests would address this?

tanhevg commented 4 years ago

In the snippet, both forward and backward are invoked on a single data point, which is a vector of 5 elements. This does not really make sense for RNNs; their recurrence only comes into play when they are invoked multiple times between resets. So either change the input dimension of your RNNs to 1 and broadcast them over the existing data, or change the data to be a vector of vectors, as I outlined above. Either way, broadcasting (or some other way of invoking the RNN repeatedly, such as a for loop) is needed for the RNN's internal state to accumulate. Broadcasting the whole BRNN will not do, because forward and backward must run over the sequence in opposite directions.
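
For illustration, one way to restructure the training data from the original post into sequences (the sequence length and the per-timestep targets are assumptions for this sketch, not something from the thread):

seqLen = 4
# each sample: a sequence of `seqLen` timestep vectors, paired with a per-timestep
# target so that the element-wise output of the dense layer can be compared against it
trn = [([rand(Float32, inSize) for _ in 1:seqLen],
        [rand(Float32, outSize) for _ in 1:seqLen]) for _ in 1:8]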

I am not sure why there are zeros in the output; it could be because of the relu activation. You could try it with identity and see if it works.
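
For example, re-using the constructor from the original post (which threads the activation through all three layers):

m = BRNN(inSize, hiddenSize, outSize, identity)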

Re this error:

ERROR: LoadError: Mutating arrays is not supported

Indeed, Zygote does not currently support mutating arrays; there is some work in progress that has not yet been merged. One possible workaround is to roll your own flip that does not involve reverse (which is where the mutation takes place):

rev(x) = x[end:-1:1]          # non-mutating reverse via indexing
flip(f, x) = rev(f.(rev(x)))  # apply f over the reversed sequence, then restore the original order

Disclaimer: I have not tested the performance; I imagine there would be a penalty. Maybe try using views.
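
For illustration, here is one way this flip could be wired into the BRNN from the original post, assuming the input xs is a sequence given as a vector of timestep vectors (this sketch is not from the thread and is untested):

function (m::BRNN)(xs::AbstractVector{<:AbstractVector})
  fwd = m.forward.(xs)        # run the forward RNN across the sequence
  bak = flip(m.backward, xs)  # run the backward RNN across the reversed sequence
  m.dense.(vcat.(fwd, bak))   # concatenate and project at each timestep
end

With per-timestep targets, the loss would then compare outputs and targets timestep by timestep, e.g. sum(Flux.mse.(m(x), y)), and Flux.reset!(m) would need to be called between sequences so hidden state does not leak across samples.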

Otherwise, RNNs seem to work fine with Zygote; I just tested with speech-blstm.

AzamatB commented 4 years ago

This seems to be fixed on master, so it should probably be closed.