FluxML / FluxTraining.jl

A flexible neural net training library inspired by fast.ai
https://fluxml.ai/FluxTraining.jl
MIT License

recurrent example for docs #144

Open ExpandingMan opened 1 year ago

ExpandingMan commented 1 year ago

Motivation and description

Recurrent networks raise a lot of questions because they work rather differently from the stateless case.

I think it would be extremely helpful to have explicit examples: one for sequence-to-sequence and one for sequence-to-one.

Possible Implementation

I might come back and contribute this, but as I'm posting this I still don't think I'm doing this the intended way...

lorenzoh commented 1 year ago

I haven’t used the library for recurrent nets, so would be interested to see how this works and am open to changes of API if necessary 👍

ExpandingMan commented 1 year ago

I've messed around with it more since writing this. Recurrent nets seem to require a fair amount of dedicated code, and I'm not sure FluxTraining.jl would be the place for all of it. In particular, I've found myself needing to write functions to:

Additionally, I wonder whether the way sequences are stored is uniform in the ecosystem. The Flux documentation itself strongly suggests that sequences should be nested arrays rather than rank-3 arrays.

darsnack commented 1 year ago

I have used Flux + FluxTraining quite a bit for recurrent models in the past. In general, you shouldn't need to do anything special. Most of the work is related primarily to Flux and how it expects the data. Here is the situation I almost always end up in, and it might be useful for you.

  1. There is a function generate(T) that creates a D x T matrix, where T is the number of time steps and D is the feature dimension.
  2. I can generate a vector of samples as samples = [generate(T) for _ in 1:nsamples].
  3. I use Flux.batchseq(samples) to turn this into a sequence of batches (from a batch of sequences). This is the key step.
  4. You can generate many batches by repeating the above steps with nsamples = batch_size for nbatches iterations.

In general, as long as you think of a single sample in your dataset as a single sequence, you can adapt the steps above to get your data into the sequence of batches (of samples) that Flux wants.
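To make the "batch of sequences" versus "sequence of batches" layouts concrete, here is a minimal sketch in plain Julia with no Flux dependency. The body of `generate` and the helper `to_batch_sequence` are hypothetical stand-ins for illustration only; in practice `Flux.batchseq` plays this role, and this sketch assumes all samples have the same length, so it does no padding.

```julia
# Hypothetical stand-in for the `generate(T)` described above:
# returns a D x T matrix (features x time steps).
generate(T; D = 2) = rand(Float32, D, T)

# Hypothetical helper: convert a batch of sequences (a vector of D x T
# matrices) into a sequence of batches (a length-T vector of D x N matrices).
# Assumes every sample has the same size; no padding is performed.
function to_batch_sequence(samples::AbstractVector{<:AbstractMatrix})
    T = size(first(samples), 2)
    # for each time step, stack that step's column from every sample
    return [reduce(hcat, [s[:, t] for s in samples]) for t in 1:T]
end

T, nsamples = 5, 3
samples = [generate(T) for _ in 1:nsamples]  # batch of sequences
xs = to_batch_sequence(samples)              # sequence of batches
```

Here `xs[t][:, n]` holds time step `t` of sample `n`, which is the layout a recurrent model is iterated over one step at a time.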

From there, achieving the different tasks is all in the loss function.

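using Statistics: mean  # provides `mean`

# seq to seq prediction
function seq2seq_loss(loss_fn)
    function _loss(m, xs, ys)
        # run the model over every time step, collecting the outputs
        yhats = [m(xi) for xi in xs]
        # average the per-step loss over the whole sequence
        return mean(loss_fn(yhat, yi) for (yhat, yi) in zip(yhats, ys))
    end

    return _loss
end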

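# seq to one prediction
function seq2one_loss(loss_fn)
    function _loss(m, xs, ys)
        # run the full sequence so the hidden state sees every step
        yhats = [m(xi) for xi in xs]
        # only the final output is compared against the final target
        return loss_fn(yhats[end], ys[end])
    end

    return _loss
end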

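# samplers for mapping the previous token to the next token
# used below in sample_model
# (Categorical comes from Distributions.jl; softmax is exported by Flux)
sample_softmax(y::AbstractVector) =
    Flux.onehot(rand(Categorical(softmax(y))), 1:length(y))
function sample_softmax(y::AbstractMatrix)
    # sample one class index per column (i.e. per batch element)
    ŷs = [rand(Categorical(p)) for p in eachcol(softmax(y))]

    return Flux.onehotbatch(ŷs, 1:size(y, 1))
end

# greedy alternative: pick the most likely class instead of sampling
sample_best(y::AbstractVector) = Flux.onehot(argmax(y), 1:length(y))
sample_best(ys::AbstractMatrix) = Flux.onehotbatch(Flux.onecold(ys), 1:size(ys, 1))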

# recurrently predict a sequence given a primer input sequence
function sample_model(model, nseq, primer = [], sampler = identity)
    Flux.reset!(model)
    tokens = [model(x) for x in primer]
    ncurrent = length(tokens)
    while ncurrent < nseq
        nexttoken = model(sampler(last(tokens)))
        push!(tokens, nexttoken)
        ncurrent += 1
    end

    return tokens
end

Note that batching does not affect any of the functions above. As long as you get the "sequence of batches" format right, you should be good.

If you still want to express all this using FluxTraining, then the following is something I've used in the past.

get_inout_seq(xs::AbstractVector) = xs[1:(end - 1)], xs[2:end]
get_inout_seq(xs::NTuple{2}) = xs[1], xs[2]

struct BPTTTrainingPhase <: AbstractTrainingPhase end

function FluxTraining.step!(learner, phase::BPTTTrainingPhase, batch)
    xs, ys = get_inout_seq(batch)
    FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do handle, state
        Flux.reset!(learner.model)
        state.grads = gradient(learner.params) do
            state.ŷs = [learner.model(xi) for xi in state.xs]
            handle(FluxTraining.LossBegin())
            state.loss = learner.lossfn(state.ŷs, state.ys)

            handle(FluxTraining.BackwardBegin())
            return state.loss
        end
        handle(FluxTraining.BackwardEnd())
        Flux.update!(learner.optimizer, learner.params, state.grads)
    end
end

struct BPTTValidationPhase <: AbstractValidationPhase
    nfeedback::Int
    sampler
end
BPTTValidationPhase() = BPTTValidationPhase(0, identity)
BPTTValidationPhase(nfeedback) = BPTTValidationPhase(nfeedback, identity)

function FluxTraining.step!(learner, phase::BPTTValidationPhase, batch)
    xs, ys = get_inout_seq(batch)
    FluxTraining.runstep(learner, phase, (xs = xs, ys = ys)) do _, state
        Flux.reset!(learner.model)
        n = length(state.xs) - phase.nfeedback
        # n steps where input drives model
        state.ŷs = [learner.model(state.xs[i]) for i in 1:n]
        # nfeedback steps where the model drives itself
        for _ in (n + 1):length(state.xs)
            ŷ = phase.sampler(state.ŷs[end])
            push!(state.ŷs, learner.model(ŷ))
        end
        state.loss = learner.lossfn(state.ŷs, state.ys)
    end
end

I don't need to do this for training recurrent models, but I found it nice for a particular project where BPTT was the thing I was comparing against. Specifically, BPTTValidationPhase is useful for evaluating models in the recurrently driven mode, where they feed their own output back in as input.

darsnack commented 1 year ago

If your data is already in a big rank-3 array, then you can make your axis order feature x samples x time, and use Base.Iterators or MLUtils.jl to partition it along the second axis into a vector of feature x batch x time chunks. A Recur model in Flux should consume these chunks correctly.
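As a rough sketch of that partitioning, using only `Base.Iterators` (the array sizes and the names `data`, `batch_size`, and `batches` are made up for illustration):

```julia
# Assumed toy data: 4 features, 10 samples, 6 time steps.
data = rand(Float32, 4, 10, 6)
batch_size = 4

# Split the sample (second) axis into index ranges of at most `batch_size`.
index_chunks = Iterators.partition(axes(data, 2), batch_size)

# Each chunk is a feature x batch x time array; the last one may be smaller.
batches = [data[:, idx, :] for idx in index_chunks]
```

With 10 samples and a batch size of 4, this yields three chunks, the last holding the remaining 2 samples. Whether a short final batch should be kept or dropped depends on the training setup.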

Otherwise, I find the approach of treating each sample as a self-contained time series is the most intuitive and compatible with existing data wrangling/loading packages like MLUtils.jl. Just remember to batchseq before passing to the Flux model.

ToucheSir commented 1 year ago

> Additionally I wonder if the way sequences are stored is uniform in the ecosystem. The Flux documentation itself strongly suggests that sequences should be nested arrays rather than rank-3 arrays.

Note that we actually do support 3D arrays of shape (features, batch, timesteps) as inputs to RNN layers. The reason it's not documented/advertised is that we're not sure whether the API makes sense. For example, how do you differentiate between a batched sequence input to a normal RNN and one timestep of input to a conv-based RNN? The current implementation also internally does the same partitioning by timesteps that you'd do by hand, so it should be slower than Kyle's suggestion above.
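For reference, partitioning a (features, batch, timesteps) array by timesteps by hand looks roughly like this. It is a minimal sketch with made-up dimensions that only illustrates the slicing; it is not Flux's internal implementation.

```julia
# Assumed toy input: 4 features, 3 batch elements, 6 time steps.
x = rand(Float32, 4, 3, 6)

# One features x batch matrix per time step (each slice is a copy).
steps = [x[:, :, t] for t in axes(x, 3)]

# The same slices as non-copying views.
step_views = collect(eachslice(x; dims = 3))
```

Either `steps` or `step_views` can then be iterated one time step at a time, which is the nested-array layout discussed earlier in the thread.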

darsnack commented 1 year ago

Note I edited my comments from the original to correct a mistake in the order of the axis dimensions. Clearly, the time I've been spending with Jax recently is leaking...

fujiehuang commented 1 year ago

I'm trying to understand how Zygote does gradient accumulation in the case of an RNN. In the following, I compare Zygote's result with a manual gradient accumulation, and the results differ. What could be the reason? The code is self-contained and runnable.

using Flux 
using Random
Random.seed!(149)

# x in format (feature, samples, timesteps)
x = reshape([0.84147096, 0.9092974, 0.14112], 1, 1, 3)
y = -0.7568025

layer1 = Flux.Recur(Flux.RNNCell(1 => 5, tanh))
layer2 = Flux.Dense( 5 => 1 )
model = Flux.Chain(layer1, layer2)

Flux.reset!(model)
e, g = Flux.withgradient(model, x, y) do m, xi, yi
    yhat = [m(xi[:,:,i]) for i in 1:3]    # timesteps = 3
    return Flux.mse(yhat[3], yi)
end
println("flux gradient dWx: ", g[1][1][1].cell.Wi)

#-------- get individual gradients at each step -----------------
c1 = deepcopy(layer1.cell)
c2 = deepcopy(c1)
c3 = deepcopy(c2)

h0 = zeros(5, 1)  # initial state zero 
e3, f = Flux.withgradient(c1.Wi, c2.Wi, c3.Wi, 
    c1.Wh, c2.Wh, c3.Wh, 
    c1.b, c2.b, c3.b) do Wi1, Wi2, Wi3,   Wh1, Wh2, Wh3,  b1, b2, b3

    h1 = tanh.( Wi1 * x[:,:,1] + Wh1 * h0 + b1);  y1 = layer2(h1)
    h2 = tanh.( Wi2 * x[:,:,2] + Wh2 * h1 + b2);  y2 = layer2(h2)
    h3 = tanh.( Wi3 * x[:,:,3] + Wh3 * h2 + b3);  y3 = layer2(h3)

    Flux.mse(y3, y)
end
println("accumulated dWx:   ", f[1]+f[2]+f[3])