FluxML / Flux.jl

Relax! Flux is the ML library that doesn't make you tensor
https://fluxml.ai/

Flux's new explicit API does not work, but the old implicit API works, for a simple RNN #2341

Open · liuyxpp opened this issue 1 year ago

liuyxpp commented 1 year ago

I am trying to reproduce the tutorial A Basic RNN using Flux.jl v0.14.6. With the old implicit Flux API, as used in the tutorial, the model trains successfully. The code is:

using Flux

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    # Each sample is a 2×T Float32 matrix (T ∈ 1:5) with entries drawn from 1:10;
    # each label is the sum of all entries of the corresponding sample.
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2 => 1, x -> x)

function eval_model(x)
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(x, y) = abs(sum(eval_model(x) .- y))

evalcb() = @show(sum(loss.(test_data, test_labels)))

ps = Flux.params(model)

opt = Flux.Adam(0.1)

for epoch in 1:num_epochs
    Flux.train!(loss, ps, zip(train_data, train_labels), opt, cb = Flux.throttle(evalcb, 1))
end

However, after refactoring the above code to use the new explicit API, Zygote complains:

ERROR: LoadError: MethodError: no method matching +(::@NamedTuple{cell::@NamedTuple{σ::Nothing, Wi::Matrix{Float32}, Wh::Matrix{Float32}, b::Vector{Float32}, state0::Nothing}, state::Matrix{Float32}}, ::Base.RefValue{Any})

The code is as follows:

using Flux
using Statistics

num_samples = 1000
num_epochs = 50

function generate_data(num_samples)
    train_data = [reshape(rand(Float32.(1.0:10.0), rand(2:2:10)), 2, :) for i in 1:num_samples]
    train_labels = (v -> sum(v)).(train_data)

    test_data = 2 .* train_data
    test_labels = 2 .* train_labels

    train_data, train_labels, test_data, test_labels
end

train_data, train_labels, test_data, test_labels = generate_data(num_samples)

model = Flux.RNN(2 => 1, x -> x)

function eval_model(model, x)
    # Commenting out the following line makes the code run.
    # However, the Flux docs say this call is required.
    Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

loss(model, x, y) = abs(sum(eval_model(model, x) .- y))

opt_state = Flux.setup(Flux.Adam(0.1), model)

for epoch in 1:num_epochs
    for (x, y) in zip(train_data, train_labels)
        train_loss, grads = Flux.withgradient(model) do m
            loss(m, x, y)
        end
        Flux.update!(opt_state, model, grads[1])
    end
    test_loss = mean(loss.(Ref(model), test_data, test_labels))
    println("Epoch $epoch, loss = $test_loss")
end

# The following code also fails to run.
# for epoch in 1:num_epochs
#     Flux.train!(model, zip(train_data, train_labels), opt_state) do m, x, y
#         loss(m, x, y)
#     end
# end

Julia version info

julia> versioninfo()
Julia Version 1.10.0-beta2
Commit a468aa198d0 (2023-08-17 06:27 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 128 × Intel(R) Xeon(R) Platinum 8362 CPU @ 2.80GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-15.0.7 (ORCJIT, icelake-server)
  Threads: 1 on 128 virtual cores
ToucheSir commented 1 year ago

If you're ok with the initial state being non-trainable, then using one of the functions under https://juliadiff.org/ChainRulesCore.jl/stable/api.html#Ignoring-gradients on the reset! line should work. e.g. @ignore_derivatives Flux.reset!(model). Moving the call to reset! outside of the loss function would also do the trick.
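
For concreteness, a minimal sketch of both options, reusing the eval_model/loss/opt_state names from the code above (@ignore_derivatives comes from ChainRulesCore, which Flux already depends on):

using Flux
using ChainRulesCore: @ignore_derivatives

# Option 1: hide the state reset from Zygote, so the mutation of
# model.state is not differentiated.
function eval_model(model, x)
    @ignore_derivatives Flux.reset!(model)
    out = [model(view(x, :, t)) for t in axes(x, 2)]
    out[end] |> first
end

# Option 2: reset the state outside the differentiated code
# (this assumes the reset! line has been removed from eval_model).
for (x, y) in zip(train_data, train_labels)
    Flux.reset!(model)
    train_loss, grads = Flux.withgradient(m -> loss(m, x, y), model)
    Flux.update!(opt_state, model, grads[1])
end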

liuyxpp commented 1 year ago

Ah, thanks! Can you explain a bit more why this fails in explicit mode but not in implicit mode?

BTW, if I have extra data to train the initial state for each time sequence, how should I do that?

ToucheSir commented 1 year ago

I'm not sure why it fails. The RNN API is a weird one because it uses some of the implicit mode machinery even when you use explicit mode.

if I have extra data to train the initial state for each time sequence, how should I do that?

If you want to have separate initial states for each sample like you mentioned in https://github.com/FluxML/Flux.jl/issues/2185#issuecomment-1736563421, the best bet would be to use the underlying RNN cell API (e.g. RNN -> RNNCell) and write your own loop over the timesteps. It'll be more manual work than using the Recur-based API, but it should just work and also avoid the MethodError shown above.
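
Roughly, a minimal sketch of what I mean, using Flux v0.14's cell calling convention cell(h, x) -> (h′, out); eval_cell and h0 are just illustrative names:

using Flux

cell = Flux.RNNCell(2 => 1, identity)

# Hand-rolled unroll: thread the hidden state h through the timesteps
# explicitly, instead of letting Recur mutate a hidden state field.
function eval_cell(cell, h0, x)
    h = h0
    local out
    for t in axes(x, 2)
        h, out = cell(h, view(x, :, t))
    end
    first(out)
end

loss(cell, h0, x, y) = abs(sum(eval_cell(cell, h0, x) .- y))

# Both the cell parameters and the per-sequence initial state get gradients.
h0 = zeros(Float32, 1)
x, y = train_data[1], train_labels[1]
val, grads = Flux.withgradient((c, h) -> loss(c, h, x, y), cell, h0)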

liuyxpp commented 1 year ago

Got it. I will report back once I figure it out. Many thanks!